AgentSkillsCN

ml-model-selection

Systematically select the most appropriate machine learning algorithm based on problem characteristics, data properties, and constraints.

SKILL.md
---
name: ml-model-selection
description: Systematic methodology for choosing the most appropriate machine learning algorithm based on problem characteristics, data properties, and constraints
---

ML Model Selection Framework

What It Is

A systematic methodology for choosing the most appropriate machine learning algorithm based on problem characteristics, data properties, computational constraints, and interpretability requirements. Rather than defaulting to the latest deep learning architecture or most familiar model, this framework matches algorithms to problem structure through a series of diagnostic questions and empirical validation.

When to Use It

  • Starting a new ML project before investing in model development
  • Your current model underperforms and you are considering alternatives
  • Need to justify model choice to stakeholders or in research papers
  • Multiple algorithms seem plausible but it is unclear which will work best
  • Balancing accuracy vs interpretability vs computational cost
  • Deploying to production with latency or resource constraints

Execution Steps

1. Define the Problem Type

Classify your ML task into one of the core problem categories. This immediately narrows the algorithm space by eliminating entire families of incompatible methods.

Action: Answer: Is this supervised or unsupervised? If supervised: classification (discrete output) or regression (continuous output)? If classification: binary, multiclass, or multilabel? If unsupervised: clustering, dimensionality reduction, or anomaly detection? Write: "Problem type: {category} with {input_type} → {output_type}."
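One way to make the step-1 diagnosis concrete in code, assuming labels are available as an array `y`: scikit-learn's `type_of_target` distinguishes binary, multiclass, multilabel, and continuous targets. The `problem_type` wrapper and its category strings are illustrative, not part of any library.

```python
# Sketch: map sklearn's target typology onto the step-1 categories.
import numpy as np
from sklearn.utils.multiclass import type_of_target

def problem_type(y):
    """Guess the supervised problem category from the label array."""
    t = type_of_target(np.asarray(y))
    if t in ("binary", "multiclass"):
        return f"classification ({t})"
    if t == "multilabel-indicator":
        return "classification (multilabel)"
    if t in ("continuous", "continuous-multioutput"):
        return "regression"
    return f"other ({t})"

print(problem_type([0, 1, 1, 0]))     # → classification (binary)
print(problem_type([0.1, 2.7, 3.3]))  # → regression
```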

2. Characterize the Data

Examine dataset properties that constrain algorithm choice. Size matters: 100 samples require different approaches than 1M samples. Feature types matter: tabular vs images vs text vs time series.

Action: Document: Sample count (n), feature count (p), n<<p or n>>p? Data types: numeric, categorical, text, images, time series, mixed? Class balance (for classification)? Missing values percentage? Noise level (clean vs noisy)? Outliers present?
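The step-2 audit can be largely automated for tabular data. A minimal sketch with pandas, assuming the data sits in a DataFrame `df` with a label column `target` (both names are placeholders for your own data):

```python
# Sketch: collect the step-2 dataset properties in one dictionary.
import pandas as pd

def characterize(df: pd.DataFrame, target: str) -> dict:
    X = df.drop(columns=[target])
    return {
        "n": len(df),                                   # sample count
        "p": X.shape[1],                                # feature count
        "n_much_less_than_p": len(df) < X.shape[1],     # n << p flag
        "dtypes": X.dtypes.astype(str).value_counts().to_dict(),
        "missing_pct": float(X.isna().mean().mean() * 100),
        "class_balance": df[target].value_counts(normalize=True).to_dict(),
    }

df = pd.DataFrame({"age": [25, 31, None, 52],
                   "plan": ["a", "b", "a", "a"],
                   "churn": [0, 1, 0, 1]})
print(characterize(df, "churn"))
```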

3. Identify Constraints

Real-world deployments impose non-negotiable constraints that override pure accuracy optimization. Identify hard constraints before evaluating models.

Action: List constraints with thresholds: Training time budget (minutes, hours, days)? Inference latency requirement (<1ms, <100ms, <1s)? Memory footprint limit? Interpretability required (black box acceptable vs must explain predictions)? Online learning needed? Edge deployment (mobile, IoT)?
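Recording the constraints as data, not just prose, lets later benchmarking code reject candidates automatically. A sketch with illustrative default thresholds (none of the numbers are recommendations):

```python
# Sketch: hard deployment constraints as a dataclass with an admission check.
from dataclasses import dataclass

@dataclass
class Constraints:
    max_train_hours: float = 4.0
    max_latency_ms: float = 100.0
    max_memory_mb: float = 512.0
    needs_interpretability: bool = False

    def admits(self, train_hours, latency_ms, memory_mb, interpretable):
        """True if a measured candidate satisfies every hard constraint."""
        return (train_hours <= self.max_train_hours
                and latency_ms <= self.max_latency_ms
                and memory_mb <= self.max_memory_mb
                and (interpretable or not self.needs_interpretability))

c = Constraints(needs_interpretability=True)
print(c.admits(train_hours=0.5, latency_ms=3.0, memory_mb=40, interpretable=True))   # → True
print(c.admits(train_hours=0.5, latency_ms=3.0, memory_mb=40, interpretable=False))  # → False
```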

4. Apply Decision Heuristics

Use problem-specific rules to generate candidate models. These heuristics encode decades of practitioner experience and eliminate obviously poor choices.

Action: Apply filters:

  • Small data (n<1000): Linear models, decision trees, k-NN (avoid deep learning)
  • Tabular data: Gradient boosting (XGBoost, LightGBM), random forests (rarely neural nets)
  • Images: CNNs (ResNet, EfficientNet), vision transformers for large datasets
  • Text: Transformers (BERT, GPT), simpler methods (TF-IDF + logistic regression) if n<10K
  • Time series: ARIMA/Prophet (univariate), LSTMs/Transformers (multivariate, long sequences)
  • Imbalanced classes: Use resampling, class weights, or algorithms robust to imbalance (random forests, XGBoost)
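The filters above can be encoded as a simple lookup. The thresholds and candidate lists mirror the bullets; they are heuristics from this framework, not hard rules:

```python
# Sketch: step-4 heuristics as a function from (n, data type) to candidate models.
def candidate_models(n: int, data_type: str) -> list[str]:
    if n < 1000:  # small data: avoid deep learning
        return ["linear/logistic regression", "decision tree", "k-NN"]
    if data_type == "tabular":
        return ["XGBoost", "LightGBM", "random forest"]
    if data_type == "images":
        return ["CNN (ResNet, EfficientNet)", "vision transformer"]
    if data_type == "text":  # simpler methods if n < 10K
        return (["transformer (BERT)", "TF-IDF + logistic regression"]
                if n >= 10_000 else ["TF-IDF + logistic regression"])
    if data_type == "time series":
        return ["ARIMA/Prophet", "LSTM/Transformer"]
    return ["gradient boosting", "random forest"]  # default for mixed data

print(candidate_models(500, "tabular"))
print(candidate_models(50_000, "text"))
```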

5. Shortlist and Benchmark

Narrow the field to 2-4 candidate models and validate them empirically on your data. Theory guides, but data decides. Use proper cross-validation to avoid overfitting to the validation set.

Action: Implement baseline (logistic regression or decision tree), 2-3 sophisticated candidates. Use k-fold cross-validation (k=5 or 10). Measure: accuracy/F1/AUC for classification, RMSE/MAE/R² for regression, training time, inference time, memory usage. Plot learning curves to diagnose high bias vs high variance.
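A minimal sketch of the step-5 benchmark, on synthetic data for the sake of a runnable example: a logistic-regression baseline plus two stronger candidates, all scored on identical CV folds so the comparison is fair.

```python
# Sketch: benchmark a baseline and two candidates on shared 5-fold splits.
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)  # same folds for every model

models = {
    "baseline: logistic regression": LogisticRegression(max_iter=1000),
    "random forest": RandomForestClassifier(random_state=0),
    "gradient boosting": GradientBoostingClassifier(random_state=0),
}
for name, model in models.items():
    scores = cross_validate(model, X, y, cv=cv, scoring=["accuracy", "roc_auc"])
    print(f"{name}: acc={scores['test_accuracy'].mean():.3f} "
          f"auc={scores['test_roc_auc'].mean():.3f} "
          f"fit={scores['fit_time'].mean():.2f}s")
```

If the sophisticated candidates do not clearly beat the baseline here, that is itself a finding: the extra complexity is not paying for itself.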

6. Optimize Top Candidate

Once a winner emerges from benchmarking, invest in hyperparameter tuning. Don't tune all models exhaustively; focus effort on the most promising one.

Action: Use grid search (small spaces), random search (medium spaces), or Bayesian optimization (large/expensive spaces) for hyperparameter tuning. Monitor for overfitting (train vs validation metric gap). If validation metric plateaus, try: more data, feature engineering, ensemble methods, or revisit model choice.
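A minimal random-search sketch for step 6, again on synthetic data. The search space and budget (`n_iter=10`) are illustrative; real values depend on your data and compute.

```python
# Sketch: random search over a random-forest hyperparameter space.
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = make_classification(n_samples=500, random_state=0)
search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={
        "n_estimators": randint(50, 300),
        "max_depth": randint(2, 12),
        "min_samples_leaf": randint(1, 10),
    },
    n_iter=10,       # budget: 10 sampled configurations
    cv=5,
    scoring="f1",
    random_state=0,  # reproducible sampling
)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```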

Real-World Applications

Structured/Tabular Data (Finance, Healthcare, Marketing)

  • Credit scoring: Gradient boosting (XGBoost) for best accuracy, logistic regression for interpretability
  • Customer churn prediction: Random forests for handling mixed data types and non-linearities
  • Fraud detection: Isolation forests or one-class SVM for anomaly detection with severe imbalance

Computer Vision

  • Image classification: Transfer learning with pre-trained CNNs (ResNet, EfficientNet) for most applications
  • Object detection: YOLO or Faster R-CNN depending on speed vs accuracy trade-off
  • Medical imaging: CNNs with attention mechanisms, often requiring domain-specific architectures

Natural Language Processing

  • Sentiment analysis: BERT fine-tuning for high accuracy, TF-IDF + logistic regression for speed/simplicity
  • Named entity recognition: CRF for small datasets, transformer models (BERT, RoBERTa) for large datasets
  • Text generation: GPT-family models, with model size dictated by computational budget

Recommender Systems

  • Collaborative filtering: Matrix factorization (SVD) for explicit feedback, neural collaborative filtering for implicit feedback
  • Content-based: Cosine similarity on TF-IDF or embeddings for cold-start scenarios
  • Hybrid: Two-tower neural networks combining collaborative and content signals
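The content-based cold-start approach above fits in a few lines: rank items by TF-IDF cosine similarity to a query item's description. The item descriptions here are made up for illustration.

```python
# Sketch: content-based ranking via TF-IDF cosine similarity.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

descriptions = [
    "wireless noise cancelling headphones",
    "bluetooth over-ear headphones with noise cancelling",
    "stainless steel kitchen knife set",
]
tfidf = TfidfVectorizer().fit_transform(descriptions)
sims = cosine_similarity(tfidf[0], tfidf).ravel()  # similarity of item 0 to all items
ranking = sims.argsort()[::-1]                     # most similar first
print(ranking)  # → [0 1 2]: self, the other headphones, then the knives
```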

Time Series Forecasting

  • Univariate forecasting: ARIMA, Prophet for interpretability and seasonality handling
  • Multivariate forecasting: LSTMs, Temporal Fusion Transformers for complex dependencies
  • Anomaly detection: Autoencoders, LSTM-based sequence models

Anti-Patterns

Deep learning for small datasets → Overfits dramatically with n<1000; use simpler models with stronger regularization.

Linear models for highly non-linear relationships → Misses complex interactions; try tree-based methods or add polynomial features.

Complex models without baseline → No reference point to judge improvement; always start with simple baseline (mean predictor, logistic regression).

Ignoring class imbalance → Model predicts majority class always; apply resampling, adjust class weights, or use appropriate metrics (F1, AUC, not accuracy).

Optimizing accuracy when other metrics matter → Minimizing false positives vs false negatives has different costs; optimize for the business metric (precision, recall, profit).

One-size-fits-all model selection → Different data requires different models; tabular data ≠ images ≠ text in terms of optimal architecture.

Success Metrics

  • Model performance exceeds baseline by meaningful margin (>5% relative improvement)
  • Meets production constraints (latency, memory, throughput)
  • Generalizes to held-out test set (validation metric ≈ test metric)
  • Learning curves show healthy convergence (no extreme overfitting or underfitting)
  • Cost-benefit analysis favors deployment (accuracy gain justifies complexity)

Related Frameworks

  • Bias-Variance Tradeoff: Diagnosing underfitting vs overfitting to guide model complexity
  • No Free Lunch Theorem: Understanding that no single model is universally best
  • Ensemble Methods: Combining models when single model selection is insufficient
  • AutoML: Automated model selection and hyperparameter tuning (H2O, Auto-sklearn, Google AutoML)

Common Pitfalls

  • Choosing model based on hype rather than problem fit (deep learning everywhere syndrome)
  • Not checking for data leakage (train-test contamination inflates metrics)
  • Overfitting to validation set through excessive hyperparameter tuning
  • Ignoring computational budget until production deployment fails
  • Using default hyperparameters without tuning (especially for tree-based methods)
  • Forgetting to scale features for distance-based algorithms (k-NN, SVM, neural nets)
  • Comparing models on different train/test splits (always use same CV folds)
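Two of these pitfalls have the same mechanical fix, sketched below: put scaling inside a `Pipeline` so the scaler is fit only on training folds (avoiding leakage), and reuse one CV splitter object so every model is scored on identical splits.

```python
# Sketch: leakage-safe scaling and shared CV folds for a fair comparison.
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = make_classification(n_samples=600, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)  # one splitter, shared

for model in (KNeighborsClassifier(), SVC()):       # both are distance-based
    pipe = make_pipeline(StandardScaler(), model)   # scaling is part of the model
    print(type(model).__name__, cross_val_score(pipe, X, y, cv=cv).mean().round(3))
```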

Tools & Resources

  • Benchmarking: scikit-learn (Python), caret (R), MLlib (Spark)
  • AutoML: Auto-sklearn, TPOT, H2O AutoML, Google Vertex AI AutoML
  • Model Registries: MLflow, Weights & Biases, Neptune.ai for experiment tracking
  • References: "Hands-On Machine Learning" (Géron), "AI Engineering" (Chip Huyen), "Elements of Statistical Learning" (Hastie, Tibshirani, Friedman)
  • Decision Aids: scikit-learn algorithm cheat sheet, Azure ML algorithm flowchart, Google ML decision tree

Model Selection Decision Tree

Start Here: What is n (sample count)?

n < 100: Use simple models (linear regression, logistic regression, k-NN with k≈√n)

100 < n < 1K: Linear models with regularization (Ridge, Lasso), decision trees, k-NN

1K < n < 100K, tabular: Gradient boosting (XGBoost, LightGBM, CatBoost), random forests

1K < n < 100K, images/audio: Transfer learning from pre-trained CNNs

1K < n < 100K, text: TF-IDF + logistic regression or small transformer models

n > 100K, tabular: Gradient boosting, deep neural networks for extreme complexity

n > 100K, images: CNNs from scratch or fine-tuned (ResNet, EfficientNet, ViT)

n > 100K, text: Fine-tune transformers (BERT, RoBERTa, GPT) or train from scratch if n>10M

Then consider: Interpretability required? → Use linear models, decision trees, or SHAP for complex models. Latency critical? → Favor simpler/smaller models, avoid deep architectures. Memory constrained? → Avoid ensembles and large neural networks.
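The decision tree above can be written down as a function. The thresholds come straight from the text; they are heuristics, not laws, and real choices should still pass through the constraint checks.

```python
# Sketch: the sample-size / data-type decision tree as a lookup function.
def suggest(n: int, data_type: str = "tabular") -> str:
    if n < 100:
        return "simple models: linear/logistic regression, k-NN"
    if n < 1000:
        return "regularized linear models (Ridge, Lasso), decision trees, k-NN"
    if n < 100_000:
        return {
            "tabular": "gradient boosting (XGBoost, LightGBM, CatBoost), random forests",
            "images": "transfer learning from pre-trained CNNs",
            "audio": "transfer learning from pre-trained CNNs",
            "text": "TF-IDF + logistic regression or small transformers",
        }.get(data_type, "gradient boosting, random forests")
    return {
        "tabular": "gradient boosting; deep nets only for extreme complexity",
        "images": "CNNs (ResNet, EfficientNet, ViT), from scratch or fine-tuned",
        "text": "fine-tune transformers (BERT, RoBERTa)",
    }.get(data_type, "gradient boosting")

print(suggest(500))
print(suggest(50_000, "images"))
```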


Framework Type: Model Selection Methodology
Domain: Machine Learning, Data Science
Practitioner Score: 9/10 - Essential skill, directly impacts project success, learned through experience
Complexity: Medium - Requires understanding of multiple algorithm families and their trade-offs
Prerequisites: Supervised learning basics, cross-validation, bias-variance tradeoff, common ML algorithms