ML 学习站
跳到正文

迁移学习:站在巨人肩膀上

特征提取 vs 微调 vs 渐进解冻, 实战花卉分类。

35 分钟2 / 41,797
加载中...

迁移学习是一种利用在大数据集上预训练模型并在小数据集上进行微调的技术,显著优于从头训练模型。其核心思想在于利用预训练模型学习到的通用特征和任务相关特征。迁移学习的三种主要策略包括特征提取、微调和渐进解冻。特征提取适用于数据极少且与预训练任务相似的情况,通过冻结预训练模型的主干部分,仅训练新的分类头。微调适用于数据量中等且与预训练任务相似的情况,解冻整个模型但使用较小的学习率。渐进解冻则适用于数据极少且需要最大化性能的情况,从后往前逐层解冻模型。读者将学会根据数据量和任务相似性选择合适的迁移学习策略,并掌握数据增强、特征解耦和预训练模型选择等进阶技巧,从而在数据不足的情况下有效训练模型。

迁移学习:站在巨人肩膀上

从头训一个 ResNet-50 在 CIFAR-10 上要几小时, 在 ImageNet 上要几周。但实际项目往往没那么多数据, 这时迁移学习就是救星。

核心思想

在大数据集 (ImageNet 1400 万图) 上预训练一个模型, 然后在小数据上微调 (fine-tune), 效果远好于从头训。

为什么有效:

  • 浅层学的是通用特征 (边缘 / 纹理 / 颜色), 跨任务通用
  • 深层学的是任务相关特征 (猫脸 / 车轮), 需要微调
  • 数据不够时, 预训练权重提供了很好的初始点

三种主流策略

1. 特征提取 (Feature Extraction)

冻结 backbone, 只训新分类头:

  • 把 backbone 输出的特征当成"高级特征向量"
  • 接一个简单的线性层做新分类
  • 适合: 数据极少 (< 1000 张), 与预训练任务相似
import torchvision.models as models
import torch.nn as nn

backbone = models.resnet50(weights="IMAGENET1K_V2")
for param in backbone.parameters():
    param.requires_grad = False  # 冻结

# 替换最后分类层
backbone.fc = nn.Linear(2048, num_classes)
# 只有 fc 层会被训练

2. 微调 (Fine-tuning)

backbone 也跟着训, 但用更小学习率:

  • 全模型解冻, 但 backbone 学习率比 head 小 10-100 倍
  • 适合: 数据中等 (1K-100K), 与预训练任务相似
backbone = models.resnet50(weights="IMAGENET1K_V2")
backbone.fc = nn.Linear(2048, num_classes)

optimizer = torch.optim.AdamW([
    {"params": backbone.fc.parameters(), "lr": 1e-3},
    {"params": backbone.parameters(), "lr": 1e-5},  # backbone 小 100 倍
])

3. 渐进解冻 (Progressive Unfreezing)

从后往前逐层解冻 (NLP 常用):

  • 先冻所有, 只训 head
  • 再解冻最后 1 层 conv
  • 再解冻 2 层, ... 直到全开
  • 适合: 数据极少, 想榨干最后一滴性能

关键决策: 冻结还是微调?

数据量与预训练相似推荐策略
< 1K相似特征提取
< 1K不相似特征提取 + 数据增强
1K-100K相似微调, 小学习率
1K-100K不相似微调 + 大量增强
> 100K任意从头训也行

实战:5 类花卉识别

假设有 5 类花卉, 每类 200 张图 (共 1000 张):

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_data = datasets.ImageFolder("flowers/train", train_transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

model = models.resnet50(weights="IMAGENET1K_V2")
model.fc = nn.Linear(2048, 5)  # 5 类

# 训练 10 epoch 通常就能 95%+ 准确率

进阶技巧

1. 数据增强 (Data Augmentation)

弥补数据不足:

  • 几何: RandomCrop / Flip / Rotate / Scale / Affine
  • 颜色: ColorJitter / RandomGrayscale / GaussianBlur
  • 高级: MixUp (两张图按比例混合) / CutMix (裁一块粘贴) / RandAugment (随机组合)

2. 特征解耦

冻结 backbone 不同部分:

  • BN 层冻结 (requires_grad=False), 防止微调破坏预训练统计量
  • 浅层 (前几层) 冻结, 只训深层

3. 预训练模型选择

  • ImageNet 预训练: 通用, 适合自然图像
  • CLIP 预训练: 跨模态, 适合零样本
  • DINOv2 (Meta 2023): 自监督, 特征极强
  • 领域专用: 医疗 (RadImageNet) / 卫星 (SSL4EO) / 工业 (MVTec)

常见坑

  1. 学习率太大: backbone 权重被破坏, 效果反而比冻结差
  2. 数据增强太弱: 微调容易过拟合
  3. 类别不平衡: 加 class_weight 或 Focal Loss
  4. 预训练数据分布差异: 自然图像 vs 医学图像, 需要更多微调
  5. 小 batch size: BN 不稳, 用 GroupNorm 或 SyncBN

总结

迁移学习 = 预训练 + 微调, 数据不够时的标准武器。现代 CV 项目 99% 用迁移学习, 从头训 ResNet 只在科研 baseline 时才做。

下一章目标检测, 我们看怎么让 CNN 不仅分类, 还能在图上框出物体位置

章末小测验

检验你对《迁移学习:站在巨人肩膀上》的掌握程度。

1

迁移学习中, '特征提取' 与 '微调' 的关键区别是?

2

微调 backbone 时学习率太大, 最可能发生什么?

讨论区(0)

加载评论中...