ML 学习站
跳到正文

实验追踪与模型版本管理

MLflow 三件套: Tracking + Projects + Registry, 让实验可复现。

40 分钟1 / 31,866
加载中...

MLflow 是一个用于机器学习实验追踪与模型版本管理的工具集,旨在解决实验可复现性差和模型管理混乱的问题。其核心概念包括 MLflow Tracking、MLflow Projects 和 MLflow Model Registry。MLflow Tracking 记录实验参数、指标和 artifact,MLflow Projects 打包代码以在任何环境中运行,而 MLflow Model Registry 则管理模型的版本和生命周期。通过安装 MLflow 和 scikit-learn,读者可以快速上手,使用 `mlflow ui` 查看所有实验。MLflow 的 autolog 功能可以自动记录多种机器学习库的实验信息,极大简化了实验追踪过程。读者将学会如何一次训练多个模型并在 UI 中按指标排序以找到最优模型。此外,Model Registry 允许用户将模型从实验阶段推进到生产阶段,并通过更改阶段来轻松回滚模型。结合 DVC 进行数据集版本管理,MLflow 提供了完整的可复现性解决方案,包括代码、数据、依赖和随机种子的管理。读者将能够有效地管理机器学习实验和模型生命周期,确保实验结果的可复现性和模型的可控性。

实验追踪与模型版本管理

每个 ML 工程师都遇到过这个噩梦:

"我上周训练的那个 92% 准确率的模型去哪儿了?超参记不清了,数据好像也不是这个版本..."

MLflow 解决这一切: 把实验参数、代码、数据、模型统一管理,让 ML 实验可复现

MLflow 三件套

  • MLflow Tracking: 记录实验参数 + 指标 + artifact
  • MLflow Projects: 打包代码, 任何环境能跑
  • MLflow Model Registry: 模型版本管理 + 阶段流转

安装: pip install mlflow scikit-learn

5 行开启 Tracking

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 启动一个 experiment
mlflow.set_experiment("iris-classifier")

with mlflow.start_run():
    # 记录超参数
    n_estimators = 100
    max_depth = 5
    mlflow.log_param("n_estimators", n_estimators)
    mlflow.log_param("max_depth", max_depth)
    
    # 训练
    model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
    model.fit(X_train, y_train)
    
    # 评估
    acc = accuracy_score(y_test, model.predict(X_test))
    mlflow.log_metric("accuracy", acc)
    
    # 保存模型
    mlflow.sklearn.log_model(model, "model")
    
    print(f"Run ID: {mlflow.active_run().info.run_id}, Accuracy: {acc:.2%}")

运行后: mlflow ui 打开 http://localhost:5000 看所有实验。

自动 log 一切 (autolog)

不用手动 log_param / log_metric, MLflow 一键全记:

import mlflow

mlflow.sklearn.autolog()  # 训练 sklearn 时自动 log

with mlflow.start_run():
    model = RandomForestClassifier(n_estimators=100, max_depth=5)
    model.fit(X_train, y_train)
    # 全部参数、指标、模型文件都自动保存!

支持的库: sklearn / XGBoost / LightGBM / PyTorch / TensorFlow / Keras / Spark MLlib / Fastai

一次训 50 个模型 + 找最优

import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

mlflow.sklearn.autolog()

# 网格搜索 + 自动 log
param_grid = {
    "n_estimators": [50, 100, 200],
    "max_depth": [3, 5, 10, None],
}
grid = GridSearchCV(
    RandomForestClassifier(),
    param_grid, cv=5, scoring="accuracy"
)
grid.fit(X_train, y_train)

print(f"Best: {grid.best_params_}, {grid.best_score_:.2%}")

autolog 会记录每组参数的结果, 之后在 UI 里按 metric 排序, 找最优。

Model Registry: 模型从实验到生产

光有 tracking 还不够, 还要管理模型生命周期 (Staging → Production → Archived):

from mlflow import MlflowClient

client = MlflowClient()

# 1. 注册模型 (从某个 run)
model_uri = f"runs:/{run_id}/model"
mlflow.register_model(model_uri, "iris-classifier")

# 2. 提升到生产
client.transition_model_version_stage(
    name="iris-classifier",
    version=1,
    stage="Production"
)

# 3. 加载生产模型
prod_model = mlflow.pyfunc.load_model("models:/iris-classifier/Production")
pred = prod_model.predict(X_test)

完整流程图

[开发阶段]                    [上线阶段]
Train v1 → log                Register → Stage: None
Train v2 → log                Register → Stage: Staging
Train v3 → log                Test OK → Stage: Production
                              Bad → Stage: Archived

模型 Registry 让 "哪个模型在线上" 一目了然, 回滚也只是改 stage。

配合 DVC: 数据集版本管理

模型是数据的派生, 还要管数据。DVC (Data Version Control) 是 Git 的搭档:

# 1. 初始化
dvc init

# 2. 把 data.csv 加到 DVC
dvc add data/training.csv
git add data/training.csv.dvc data/.gitignore
git commit -m "Add training data v1"

# 3. 数据变了
dvc add data/training.csv
git commit -m "Update to training data v2"

# 4. 切回老数据
git checkout HEAD~1 -- data/training.csv.dvc
dvc checkout

DVC 把大文件存到 S3 / OSS, .dvc 文件存指针到 Git。

5 个工程实践

  1. 每次跑都写 mlflow.start_run(), 不跑就丢失
  2. 代码 commit hash 记到 run: mlflow.set_tag("git_sha", sha)
  3. 数据 hash 也记: mlflow.set_tag("data_hash", md5(file))
  4. 模型 schema 记: 输入输出列名、类型
  5. 别用默认 SQLite: 生产用 mlflow server --backend-store-uri postgresql://...

与其他工具对比

工具优势劣势
MLflow通用、API 简单、本地优先大规模需要自己部署
Weights & BiasesUI 漂亮、协作强SaaS, 收费
Neptune.ai团队协作、对比视图SaaS, 收费
TensorBoard深度学习集成主打可视化, 不是 tracking
DVC数据 + 模型 + pipeline学习曲线陡

小结

  • MLflow = Tracking (实验) + Projects (代码) + Registry (模型生命周期)
  • autolog() 一键记录 sklearn / XGBoost / PyTorch
  • Model Registry 用 stage (Staging / Production) 管理线上模型
  • 配合 DVC 管数据, Git 管代码
  • 可复现 = 代码 hash + 数据 hash + 依赖 hash + 随机种子

练习思考

  1. 跑 5 个不同超参的 RandomForest, 在 MLflow UI 里比较 accuracy, 哪组最好?
  2. 注册 2 个模型版本, 把 v1 提升到 Production, 加载并预测, 跟 v2 的预测结果对比。
  3. MLflow autolog 跟手动 log_param 有什么区别? 什么时候必须用手动?

章末小测验

检验你对《实验追踪与模型版本管理》的掌握程度。

1

MLflow Tracking 主要记录什么?

2

Model Registry 中模型的生命周期阶段是?

讨论区(0)

加载评论中...