MLflow 是一个用于机器学习实验追踪与模型版本管理的工具集,旨在解决实验可复现性差和模型管理混乱的问题。其核心概念包括 MLflow Tracking、MLflow Projects 和 MLflow Model Registry。MLflow Tracking 记录实验参数、指标和 artifact,MLflow Projects 打包代码以在任何环境中运行,而 MLflow Model Registry 则管理模型的版本和生命周期。通过安装 MLflow 和 scikit-learn,读者可以快速上手,使用 `mlflow ui` 查看所有实验。MLflow 的 autolog 功能可以自动记录多种机器学习库的实验信息,极大简化了实验追踪过程。读者将学会如何一次训练多个模型并在 UI 中按指标排序以找到最优模型。此外,Model Registry 允许用户将模型从实验阶段推进到生产阶段,并通过更改阶段来轻松回滚模型。结合 DVC 进行数据集版本管理,MLflow 提供了完整的可复现性解决方案,包括代码、数据、依赖和随机种子的管理。读者将能够有效地管理机器学习实验和模型生命周期,确保实验结果的可复现性和模型的可控性。
实验追踪与模型版本管理
每个 ML 工程师都遇到过这个噩梦:
"我上周训练的那个 92% 准确率的模型去哪儿了?超参记不清了,数据好像也不是这个版本..."
MLflow 解决这一切: 把实验参数、代码、数据、模型统一管理,让 ML 实验可复现。
MLflow 三件套
- MLflow Tracking: 记录实验参数 + 指标 + artifact
- MLflow Projects: 打包代码, 任何环境能跑
- MLflow Model Registry: 模型版本管理 + 阶段流转
安装: pip install mlflow scikit-learn
5 行开启 Tracking
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载数据
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 启动一个 experiment
mlflow.set_experiment("iris-classifier")
with mlflow.start_run():
# 记录超参数
n_estimators = 100
max_depth = 5
mlflow.log_param("n_estimators", n_estimators)
mlflow.log_param("max_depth", max_depth)
# 训练
model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
model.fit(X_train, y_train)
# 评估
acc = accuracy_score(y_test, model.predict(X_test))
mlflow.log_metric("accuracy", acc)
# 保存模型
mlflow.sklearn.log_model(model, "model")
print(f"Run ID: {mlflow.active_run().info.run_id}, Accuracy: {acc:.2%}")
运行后: mlflow ui 打开 http://localhost:5000 看所有实验。
自动 log 一切 (autolog)
不用手动 log_param / log_metric, MLflow 一键全记:
import mlflow
mlflow.sklearn.autolog() # 训练 sklearn 时自动 log
with mlflow.start_run():
model = RandomForestClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)
# 全部参数、指标、模型文件都自动保存!
支持的库: sklearn / XGBoost / LightGBM / PyTorch / TensorFlow / Keras / Spark MLlib / Fastai
一次训 50 个模型 + 找最优
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
mlflow.sklearn.autolog()
# 网格搜索 + 自动 log
param_grid = {
"n_estimators": [50, 100, 200],
"max_depth": [3, 5, 10, None],
}
grid = GridSearchCV(
RandomForestClassifier(),
param_grid, cv=5, scoring="accuracy"
)
grid.fit(X_train, y_train)
print(f"Best: {grid.best_params_}, {grid.best_score_:.2%}")
autolog 会记录每组参数的结果, 之后在 UI 里按 metric 排序, 找最优。
Model Registry: 模型从实验到生产
光有 tracking 还不够, 还要管理模型生命周期 (Staging → Production → Archived):
from mlflow import MlflowClient
client = MlflowClient()
# 1. 注册模型 (从某个 run)
model_uri = f"runs:/{run_id}/model"
mlflow.register_model(model_uri, "iris-classifier")
# 2. 提升到生产
client.transition_model_version_stage(
name="iris-classifier",
version=1,
stage="Production"
)
# 3. 加载生产模型
prod_model = mlflow.pyfunc.load_model("models:/iris-classifier/Production")
pred = prod_model.predict(X_test)
完整流程图
[开发阶段] [上线阶段]
Train v1 → log Register → Stage: None
Train v2 → log Register → Stage: Staging
Train v3 → log Test OK → Stage: Production
Bad → Stage: Archived
模型 Registry 让 "哪个模型在线上" 一目了然, 回滚也只是改 stage。
配合 DVC: 数据集版本管理
模型是数据的派生, 还要管数据。DVC (Data Version Control) 是 Git 的搭档:
# 1. 初始化
dvc init
# 2. 把 data.csv 加到 DVC
dvc add data/training.csv
git add data/training.csv.dvc data/.gitignore
git commit -m "Add training data v1"
# 3. 数据变了
dvc add data/training.csv
git commit -m "Update to training data v2"
# 4. 切回老数据
git checkout HEAD~1 -- data/training.csv.dvc
dvc checkout
DVC 把大文件存到 S3 / OSS, .dvc 文件存指针到 Git。
5 个工程实践
- 每次跑都写
mlflow.start_run(), 不跑就丢失 - 代码 commit hash 记到 run:
mlflow.set_tag("git_sha", sha) - 数据 hash 也记:
mlflow.set_tag("data_hash", md5(file)) - 模型 schema 记: 输入输出列名、类型
- 别用默认 SQLite: 生产用
mlflow server --backend-store-uri postgresql://...
与其他工具对比
| 工具 | 优势 | 劣势 |
|---|---|---|
| MLflow | 通用、API 简单、本地优先 | 大规模需要自己部署 |
| Weights & Biases | UI 漂亮、协作强 | SaaS, 收费 |
| Neptune.ai | 团队协作、对比视图 | SaaS, 收费 |
| TensorBoard | 深度学习集成 | 主打可视化, 不是 tracking |
| DVC | 数据 + 模型 + pipeline | 学习曲线陡 |
小结
- MLflow = Tracking (实验) + Projects (代码) + Registry (模型生命周期)
autolog()一键记录 sklearn / XGBoost / PyTorch- Model Registry 用 stage (Staging / Production) 管理线上模型
- 配合 DVC 管数据, Git 管代码
- 可复现 = 代码 hash + 数据 hash + 依赖 hash + 随机种子
练习思考
- 跑 5 个不同超参的 RandomForest, 在 MLflow UI 里比较 accuracy, 哪组最好?
- 注册 2 个模型版本, 把 v1 提升到 Production, 加载并预测, 跟 v2 的预测结果对比。
- MLflow autolog 跟手动 log_param 有什么区别? 什么时候必须用手动?
章末小测验
检验你对《实验追踪与模型版本管理》的掌握程度。
MLflow Tracking 主要记录什么?
Model Registry 中模型的生命周期阶段是?