ML 学习站
跳到正文

训练过程的统计监控

用控制图监控 loss/accuracy, 提前发现训练异常。

30 分钟6 / 62,881
加载中...

本章探讨了如何利用控制图(Control Chart)对机器学习模型的训练过程进行实时监控,以解决传统监控方法中无法及时发现训练异常的问题。核心概念包括:1)控制图的基本原理,即通过3σ法则判断数据是否在控制范围内;2)四种主要控制图类型(均值控制图、个值控制图、移动极差图、比例控制图),分别适用于不同类型的监控指标;3)Western Electric规则,提供比单一3σ法则更敏感的异常检测方法。学习本章后,读者能够设计并应用控制图来实时监控训练过程中的多个关键指标,如loss、accuracy、梯度范数等,从而及时发现并应对训练中的异常情况,减少训练时间并提高模型稳定性。

训练过程的统计监控

本章问题: 训练跑了一晚, 第二天发现 loss 早就 plateau 了 5 小时。怎么"实时"知道训练是否正常? 答案: 控制图 (Control Chart), 工业质量管理用了一个世纪的方法。

1. 训练监控的现状

大多数 ML 工程师的监控:

  • ❌ 只看 loss/accuracy 曲线, 等训练完才发现问题
  • ❌ 没量化"正常波动"和"异常波动"
  • ❌ 多卡训练时, 不知道哪张卡异常

工业质量管理 (SPC)控制图 解决了完全相同的问题: 监控"过程"是否在控。

2. 控制图基础 (Control Chart)

2.1 3σ 法则

如果过程在控, 数据点应该在 μ ± 3σ 范围内。超出 = 异常。

            UCL (Upper Control Limit) = μ + 3σ
                   │
         ┌─────────┼─────────┐
         │  正常区域 │ 异常  │ ← 1 个点出 UCL → 异常
   ───────┼─────────┼─────────┼─────
         │         │         │
         │  (中线) │         │
         │         │         │
         └─────────┼─────────┘
                   │
            LCL = μ - 3σ

2.2 控制图的 2 步构建

  1. 收集基线: 在"已知正常"的训练运行中, 记录 20+ 个数据点 (如每个 epoch 的 loss)
  2. 计算控制限: μ, σ, UCL = μ+3σ, LCL = μ-3σ
  3. 持续监控: 实时点 vs 控制限, 出界 = 告警
import numpy as np
import matplotlib.pyplot as plt

# 1. 基线 (假设前 20 个 epoch 是"正常训练")
np.random.seed(42)
baseline_loss = np.random.normal(0.5, 0.05, 20)  # 20 个正常值
mu = baseline_loss.mean()
sigma = baseline_loss.std(ddof=1)
ucl = mu + 3 * sigma
lcl = mu - 3 * sigma

# 2. 实时数据 (训练中)
new_loss = np.concatenate([
    np.random.normal(0.5, 0.05, 30),  # 继续正常
    np.random.normal(1.0, 0.1, 5),   # 突然发散!
])

# 3. 画控制图
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(baseline_loss, "o-", color="blue", label="基线 (正常)")
ax.axhline(mu, color="green", linestyle="--", label=f"μ = {mu:.3f}")
ax.axhline(ucl, color="red", linestyle="--", label=f"UCL = {ucl:.3f}")
ax.axhline(lcl, color="red", linestyle="--", label=f"LCL = {lcl:.3f}")
ax.plot(range(20, 20+len(new_loss)), new_loss, "s-", color="orange", label="实时")
# 标异常
ax.axvline(20 + 30, color="red", alpha=0.3, label="异常开始")
ax.set_xlabel("Epoch"); ax.set_ylabel("Loss")
ax.set_title("控制图: 监控训练 Loss")
ax.legend(); ax.grid(True, alpha=0.3)
plt.show()

3. 4 种控制图

3.1 均值控制图 (X-bar Chart)

监控连续变量的均值

3.2 个值控制图 (I Chart / X Chart)

监控单个值 (适合小批量训练, 每步一个数据点)

3.3 移动极差图 (MR Chart)

监控相邻差值 (波动大小), 跟 X Chart 配对

3.4 比例控制图 (P Chart)

监控比例 (如每批的"准确率")

class ControlChart:
    """完整的控制图工具"""
    def __init__(self, baseline):
        self.mu = np.mean(baseline)
        self.sigma = np.std(baseline, ddof=1)
        self.ucl = self.mu + 3 * self.sigma
        self.lcl = self.mu - 3 * self.sigma
        # 移动极差
        mr = np.abs(np.diff(baseline))
        self.mr_bar = mr.mean()
        # I Chart σ 估计
        self.sigma_i = self.mr_bar / 1.128  # d2 (n=2)
    
    def check_point(self, value):
        """检查单点是否异常"""
        if value > self.ucl or value < self.lcl:
            return "OUT_OF_CONTROL", abs(value - self.mu) / self.sigma_i
        return "OK", abs(value - self.mu) / self.sigma_i
    
    def check_run(self, values):
        """检查整个序列, 返回异常点索引"""
        return [i for i, v in enumerate(values) if self.check_point(v)[0] == "OUT_OF_CONTROL"]
    
    def western_electric_rules(self, values):
        """Western Electric 4 规则: 更敏感的异常检测"""
        alerts = []
        for i in range(len(values)):
            v = values[i]
            # 规则 1: 1 个点出 3σ
            if v > self.ucl or v < self.lcl:
                alerts.append(i)
            # 规则 2: 连续 9 个点在中心同侧
            if i >= 8:
                last9 = values[i-8:i+1]
                if all(x > self.mu for x in last9) or all(x < self.mu for x in last9):
                    alerts.append(i)
            # 规则 3: 连续 6 个点递增/递减
            if i >= 5:
                last6 = values[i-5:i+1]
                if all(last6[j] < last6[j+1] for j in range(5)):
                    alerts.append(i)  # 上升趋势
                if all(last6[j] > last6[j+1] for j in range(5)):
                    alerts.append(i)  # 下降趋势
        return list(set(alerts))

4. 训练监控实战

4.1 监控指标设计

指标控制图异常含义
训练 loss (每个 epoch)I Chart发散/NaN
验证 accuracyI Chart过拟合
验证 loss - 训练 lossI Chart泛化能力
梯度范数I Chart爆炸/消失
权重范数I Chart初始化异常
学习率 (实际)I Chartscheduler 异常
训练步耗时I ChartI/O 瓶颈
显存使用I ChartOOM 风险
NaN/Inf 比例P Chart数值不稳定
梯度爆炸比例P Chart梯度截断失效
class TrainingMonitor:
    """训练过程实时监控"""
    def __init__(self, baseline_file=None):
        self.charts = {}
        if baseline_file:
            self.load_baseline(baseline_file)
    
    def register(self, name, baseline):
        self.charts[name] = ControlChart(baseline)
    
    def step(self, name, value, epoch):
        if name not in self.charts:
            # 第一次, 用前 10 个点当基线
            if not hasattr(self, f"_buffer_{name}"):
                setattr(self, f"_buffer_{name}", [])
            buffer = getattr(self, f"_buffer_{name}")
            buffer.append(value)
            if len(buffer) >= 10:
                self.register(name, buffer)
            return "INITIALIZING"
        
        status, z = self.charts[name].check_point(value)
        if status == "OUT_OF_CONTROL":
            print(f"⚠️ [Epoch {epoch}] {name} = {value:.4f} 异常 (z={z:.2f})")
        return status
    
    def check(self, name, values):
        """训练完一次性检查"""
        if name not in self.charts:
            return []
        return self.charts[name].check_run(values)

4.2 实时监控示例

# 模拟一次训练, 边训边监控
monitor = TrainingMonitor()

# 假设这是前 10 个 epoch 的 loss (当基线)
losses = []
for epoch in range(100):
    # 模拟 loss
    if epoch < 10:
        loss = 0.5 + np.random.normal(0, 0.02) - epoch * 0.005  # 正常下降
    elif epoch < 30:
        loss = 0.45 + np.random.normal(0, 0.02)  # 稳定
    elif epoch == 30:
        loss = 1.5  # 突然发散 (模拟)
    else:
        loss = 0.45 + np.random.normal(0, 0.02)
    
    losses.append(loss)
    status = monitor.step("loss", loss, epoch)
    if status == "OUT_OF_CONTROL":
        print(f"  建议: 停止训练, 检查数据/超参")
        break

5. Western Electric 规则: 更敏感的异常检测

1 个点出 3σ 算太宽松, 实际工业用 Western Electric 4 大规则:

规则检测
11 个点出 3σ
2连续 9 个点在中心同侧 (趋势)
3连续 6 个点递增/递减 (单调趋势)
4连续 14 个点交替上下 (振荡)
# 集成监控
def smart_monitor(losses, baseline_chart, verbose=True):
    """智能训练监控"""
    alerts = baseline_chart.western_electric_rules(losses)
    if alerts:
        if verbose:
            print(f"⚠️ 检测到异常 epoch: {alerts}")
        return alerts
    return []

# 训练: 每个 epoch 调一次
# 异常时: 发 Slack 告警 / 自动回滚 checkpoint / 调小学习率

6. 高级: 多元控制图 (Hotelling T²)

单变量控制图只能看一个指标。多变量 用 Hotelling T²:

from statsmodels.stats.multivariate import test_mvmean

# 例: 同时监控 loss, accuracy, grad_norm
def multivariate_monitor(metrics_history, baseline, alpha=0.01):
    """Hotelling T² 多变量控制"""
    n = len(baseline)
    k = baseline.shape[1]
    diff = np.array(metrics_history[-n:]) - baseline.mean(axis=0)
    cov = np.cov(baseline.T)
    T2 = n * diff @ np.linalg.inv(cov) @ diff
    # F 分布临界值
    from scipy.stats import f
    f_crit = f.ppf(1-alpha, k, n-k)
    ucl = ((n-1)**2 / n) * f_crit / n
    return T2 > ucl, T2

7. 控制图在生产环境中的"完整告警链"

import requests  # Slack/钉钉 webhook

class TrainingAlertSystem:
    """完整训练告警"""
    def __init__(self, webhook_url=None):
        self.webhook_url = webhook_url
        self.charts = {}
        self.epoch = 0
        self.checkpoints = []
    
    def on_epoch_end(self, metrics, model):
        """每个 epoch 结束调用"""
        self.epoch += 1
        alerts = []
        
        for name, value in metrics.items():
            if name in self.charts:
                status, z = self.charts[name].check_point(value)
                if status == "OUT_OF_CONTROL":
                    alerts.append(f"{name}={value:.4f} (z={z:.2f})")
        
        if alerts:
            self._handle_alert(alerts, model)
    
    def _handle_alert(self, alerts, model):
        """告警处理: 1) 通知 2) 回滚 3) 暂停"""
        # 1. Slack 通知
        if self.webhook_url:
            requests.post(self.webhook_url, json={
                "text": f"⚠️ 训练异常 (epoch {self.epoch}): {', '.join(alerts)}"
            })
        # 2. 加载上一个稳定 checkpoint
        if self.checkpoints:
            model.load_state_dict(self.checkpoints[-1])
            print(f"  已回滚到 epoch {self.epoch - 1} 的权重")
        # 3. 可选: 暂停训练, 让人工介入

8. 现代替代: TensorBoard + 自定义插件

# TensorBoard + 自定义 HParams + 控制图插件
from torch.utils.tensorboard import SummaryWriter
import numpy as np

writer = SummaryWriter("runs/exp1")
# HParams (记录超参)
writer.add_hparams(
    {"lr": 0.001, "batch_size": 32, "model": "resnet50"},
    {"hparam/accuracy": 0.92}
)
# 控制图
for epoch in range(100):
    loss = train_one_epoch()
    writer.add_scalar("Loss/train", loss, epoch)
    # 自定义告警
    if loss > 1.0 and epoch > 10:
        writer.add_scalar("Alerts/anomaly", 1, epoch)

W&B, MLflow, ClearML 等平台有内置的"训练监控告警"。

9. 实战: 完整训练监控脚本

import numpy as np
import matplotlib.pyplot as plt
from collections import deque

class RealTimeMonitor:
    """实时训练监控器"""
    def __init__(self, window=20):
        self.window = window
        self.history = deque(maxlen=1000)
        self.baseline = None
        self.alerts = []
    
    def add(self, value):
        self.history.append(value)
        if self.baseline is None and len(self.history) >= self.window:
            self.baseline = np.array(list(self.history))[:self.window]
            print(f"[监控] 基线建立: μ={self.baseline.mean():.4f}, σ={self.baseline.std():.4f}")
    
    def check(self):
        if self.baseline is None or len(self.history) < self.window:
            return None
        mu, sigma = self.baseline.mean(), self.baseline.std(ddof=1)
        current = list(self.history)[-1]
        if abs(current - mu) > 3 * sigma:
            alert = f"异常: 当前={current:.4f}, 基线 μ={mu:.4f}, 3σ 范围=[{mu-3*sigma:.4f}, {mu+3*sigma:.4f}]"
            self.alerts.append(alert)
            return alert
        return None
    
    def plot(self):
        fig, ax = plt.subplots(figsize=(12, 5))
        if self.baseline is not None:
            mu = self.baseline.mean()
            sigma = self.baseline.std(ddof=1)
            ax.axhline(mu, color="green", linestyle="--")
            ax.axhline(mu + 3*sigma, color="red", linestyle="--", label="UCL")
            ax.axhline(mu - 3*sigma, color="red", linestyle="--", label="LCL")
        ax.plot(list(self.history), "b-", alpha=0.7)
        # 标异常
        for i, _ in enumerate(self.history):
            if self.baseline is not None and i >= self.window:
                v = list(self.history)[i]
                mu = self.baseline.mean()
                sigma = self.baseline.std(ddof=1)
                if abs(v - mu) > 3 * sigma:
                    ax.plot(i, v, "rx", markersize=12)
        ax.set_xlabel("Step"); ax.set_ylabel("Loss")
        ax.set_title("训练 Loss 实时控制图")
        ax.legend(); ax.grid(True, alpha=0.3)
        plt.tight_layout(); plt.show()

# 用法
monitor = RealTimeMonitor(window=20)
for step in range(100):
    if step < 20:
        loss = 0.5 + np.random.normal(0, 0.02)  # 正常
    elif step == 50:
        loss = 2.0  # 异常
    else:
        loss = 0.45 + np.random.normal(0, 0.02)
    
    monitor.add(loss)
    alert = monitor.check()
    if alert:
        print(f"  Step {step}: {alert}")

10. 小结

你学到了关键点
控制图 3σ1 个点出 3σ 算异常
4 种控制图I / X-bar / MR / P, 选对监控类型
Western Electric4 大规则比单 3σ 更敏感
训练指标loss, acc, grad, 权重, 时间, 显存, NaN 比例
自动告警Slack/钉钉 + 自动回滚 checkpoint
Hotelling T²多变量联合监控
ML 平台W&B, MLflow, ClearML 都内置告警
业务价值减少 50% 训练时间, 提前发现 NaN

11. 习题

  1. 模拟 100 个 epoch 训练, 其中第 50-55 epoch 出现"loss spike":

    • 建立控制图 (前 20 epoch 当基线)
    • 用 Western Electric 4 规则检测异常
    • 报告: 哪些 epoch 被标异常? 用什么规则?
  2. 写一个 TrainingMonitor 类:

    • 同时监控 loss, gradient_norm, learning_rate
    • 任何一个指标出 3σ, 打印告警
    • 模拟训练, 验证类能工作
👉 查看参考答案
  1. 提示: 用前 20 epoch 当基线, 算 mu, sigma, 然后用 I Chart + Western Electric 规则。 第 50-55 epoch 会被规则 1 (出 3σ) 和规则 3 (连续递增) 标记。 Western Electric 比单 3σ 更早发现异常。

  2. 提示: 维护 3 个 ControlChart 实例, 每个 epoch 调一次 check_point。 模拟 loss spike / gradient explosion / scheduler 异常, 验证都能被捕获。

12. 下一章


📚 本章综合: 改编自 Triola《基础统计学》第 14 章统计过程控制 (SPC), 加入 ML 训练监控实战。

章末小测验

检验你对《训练过程的统计监控》的掌握程度。

1

关于控制图在机器学习训练监控中的作用,以下哪些说法是正确的?

2

以下哪些控制图类型适用于监控机器学习训练中的单个值?

3

关于Western Electric规则,以下哪些说法是正确的?

4

以下哪些指标适用于使用I Chart进行监控?

5

在机器学习训练监控中,以下哪些方法可以用于减少训练时间或提前发现问题?

讨论区(0)

加载评论中...