本章通过训练 CartPole-v1 环境中的 DQN agent,展示了强化学习的基本流程和核心概念。核心内容包括:环境配置、状态和动作空间理解、DQN 算法实现、训练过程及结果可视化。读者将学习如何设置强化学习任务,包括处理 4 维连续状态空间和 2 个离散动作,以及如何应用经验回放、目标网络和神经网络等 DQN 的关键组件。学完后,读者能够独立训练一个 CartPole agent,使其达到平均奖励 ≥ 475 的目标,并掌握调整超参数(如学习率、gamma 值、epsilon 值等)以优化训练效果的方法。此外,本章还介绍了让训练更稳定的 7 个技巧,如奖励缩放、梯度裁剪和 Huber Loss 等。最后,读者将了解如何保存和加载模型,并尝试在更复杂的任务(如 LunarLander)中应用更高级的算法(如 DDPG、PPO 和 SAC)。
实战:CartPole 训练
CartPole 是强化学习的"Hello World"——简单到几分钟能跑通,但又能体现 RL 的所有核心思想。这一章带你端到端训练一个 DQN agent。
项目目标
- 环境:CartPole-v1(OpenAI Gym)
- 状态:4 维连续向量(小车位置、速度、杆子角度、角速度)
- 动作:2 个离散(0=左推, 1=右推)
- 奖励:每平衡 1 步 +1
- 目标:平均奖励 ≥ 475(满分 500)
预计训练时间:5-15 分钟(GPU)/ 10-20 分钟(CPU)
第一步:环境准备
pip install torch gymnasium numpy matplotlib
# 注: Gymnasium 是 Gym 的现代维护版,API 基本一致(注意:reset() 返回 (obs, info),step() 返回 5 个值,和旧版 Gym 不同)
第二步:理解环境
import gymnasium as gym
env = gym.make("CartPole-v1")
print(f"状态空间: {env.observation_space}")  # Box(4,)
print(f"动作空间: {env.action_space}")  # Discrete(2)
# Newer Gymnasium releases no longer define `reward_range` on environments,
# so fall back to the historical default instead of raising AttributeError.
reward_range = getattr(env, "reward_range", (-float("inf"), float("inf")))
print(f"奖励范围: {reward_range}")  # (-inf, inf)

# Play one episode with a random policy.
state, _ = env.reset()
print(f"初始状态: {state}")
total_reward = 0
done = False
while not done:
    action = env.action_space.sample()  # random action
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    # Either end signal finishes the episode.
    done = terminated or truncated
print(f"随机策略总奖励: {total_reward}")
# Release the environment; later sections create their own instance.
env.close()
第三步:实现 DQN
把上一章的 DQN 实现搬过来,加一点针对 CartPole 的优化:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import gymnasium as gym
class QNetwork(nn.Module):
    """MLP that maps a state vector to one Q-value per action.

    Args:
        state_dim: Number of input features.
        n_actions: Number of outputs (one Q-value per action).
        hidden: Width of both hidden layers.
    """

    def __init__(self, state_dim=4, n_actions=2, hidden=128):
        super().__init__()
        # Two ReLU hidden layers followed by a linear output head. The layers
        # stay inside one Sequential named `net` so state_dict keys are stable.
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values produced by the network for input ``x``."""
        q_values = self.net(x)
        return q_values
class ReplayBuffer:
    """Bounded store of ``(s, a, r, s_, done)`` transitions with random sampling."""

    def __init__(self, capacity=10000):
        # A deque with maxlen discards the oldest entry once capacity is hit.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_, done):
        """Append one transition to the buffer."""
        transition = (s, a, r, s_, done)
        self.buffer.append(transition)

    def sample(self, batch_size=64):
        """Draw ``batch_size`` transitions without replacement.

        Returns:
            A 5-tuple of tensors ``(states, actions, rewards, next_states,
            dones)``; actions are ``long``, everything else is ``float32``.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into five per-field tuples.
        states, actions, rewards, next_states, dones = zip(*picked)

        # Stack per-step states into one ndarray before converting to a tensor.
        states_t = torch.tensor(np.array(states), dtype=torch.float32)
        actions_t = torch.tensor(actions, dtype=torch.long)
        rewards_t = torch.tensor(rewards, dtype=torch.float32)
        next_states_t = torch.tensor(np.array(next_states), dtype=torch.float32)
        dones_t = torch.tensor(dones, dtype=torch.float32)
        return (states_t, actions_t, rewards_t, next_states_t, dones_t)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
class DQNAgent:
    """DQN agent: epsilon-greedy acting, replay training, periodic target sync.

    All hyperparameters are keyword-only constructor arguments whose defaults
    reproduce the original hard-coded values.

    Args:
        state_dim: Size of the observation vector fed to the Q-network.
        n_actions: Number of discrete actions.
        gamma: Discount factor used in the TD target.
        batch_size: Transitions sampled per gradient step; training is skipped
            until the buffer holds at least this many.
        epsilon: Initial exploration probability.
        epsilon_min: Lower bound for the exploration probability.
        epsilon_decay: Factor epsilon is multiplied by after every gradient step.
        target_update: Number of gradient steps between target-network syncs.
        lr: Learning rate passed to Adam.
        buffer_capacity: Maximum number of transitions kept for replay.
    """

    def __init__(
        self,
        state_dim=4,
        n_actions=2,
        *,
        gamma=0.99,
        batch_size=64,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        target_update=10,
        lr=1e-3,
        buffer_capacity=10000,
    ):
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.target_update = target_update

        self.q_net = QNetwork(state_dim, n_actions)
        # The target network starts as an exact copy and is only ever updated
        # by copying weights, so it stays in eval mode.
        self.target_net = QNetwork(state_dim, n_actions)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(capacity=buffer_capacity)
        self.steps = 0  # number of gradient steps taken so far

    def select_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        With probability ``self.epsilon`` a random action index is returned;
        otherwise the argmax of the online network's Q-values.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.n_actions)
        with torch.no_grad():
            # Add a leading batch axis: the network is fed a (1, state_dim) input.
            q = self.q_net(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            return q.argmax().item()

    def train_step(self):
        """Run one gradient step on a sampled minibatch.

        Returns:
            The scalar loss as a float, or ``None`` when the buffer does not
            yet contain ``batch_size`` transitions.
        """
        if len(self.buffer) < self.batch_size:
            return None

        s, a, r, s_, done = self.buffer.sample(self.batch_size)

        # squeeze(1) drops only the gathered action axis. A bare squeeze()
        # would also drop the batch axis when batch_size == 1, leaving a
        # 0-dim prediction that no longer matches the target's shape.
        q_pred = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            q_next = self.target_net(s_).max(1)[0]
            # (1 - done) removes the bootstrap term for terminal transitions.
            q_target = r + self.gamma * q_next * (1 - done)

        loss = nn.functional.mse_loss(q_pred, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        # Cap the gradient norm to keep updates bounded.
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)
        self.optimizer.step()

        # Exploration decays once per gradient step (not once per episode).
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        self.steps += 1
        if self.steps % self.target_update == 0:
            # Hard update: copy the online weights into the target network.
            self.target_net.load_state_dict(self.q_net.state_dict())

        return loss.item()
第四步:训练
env = gym.make("CartPole-v1")
agent = DQNAgent(state_dim=4, n_actions=2)

episodes = 500
rewards_history = []  # total reward of every finished episode

for ep in range(episodes):
    state, _ = env.reset()
    total_reward = 0
    done = False

    while not done:
        action = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The rollout stops on either signal...
        done = terminated or truncated

        # ...but only `terminated` is stored as the terminal flag. A
        # time-limit truncation is not a real terminal state, so the TD
        # target must still bootstrap from next_state in that case.
        agent.buffer.push(state, action, reward, next_state, terminated)
        agent.train_step()

        state = next_state
        total_reward += reward

    rewards_history.append(total_reward)

    # Print progress every 20 episodes.
    if (ep + 1) % 20 == 0:
        recent = rewards_history[-20:]
        avg = np.mean(recent)
        print(f"Episode {ep+1:4d} | "
              f"Reward: {total_reward:3.0f} | "
              f"Avg(20): {avg:6.1f} | "
              f"ε: {agent.epsilon:.3f}")
预期输出(数值仅为示意,每次运行都会不同;另外上面的代码是每个训练步衰减一次 ε,实际打印出的 ε 会比下面的示意值下降得快得多):
Episode 20 | Reward: 18 | Avg(20): 14.5 | ε: 0.905
Episode 40 | Reward: 32 | Avg(20): 21.3 | ε: 0.819
Episode 60 | Reward: 45 | Avg(20): 38.7 | ε: 0.741
Episode 80 | Reward: 78 | Avg(20): 65.2 | ε: 0.671
Episode 100 | Reward: 120 | Avg(20): 103.5 | ε: 0.607
Episode 200 | Reward: 500 | Avg(20): 485.0 | ε: 0.299
Episode 300 | Reward: 500 | Avg(20): 500.0 | ε: 0.149
大概 150-250 轮就能稳到 500 分(满分)。
第五步:可视化训练曲线
import matplotlib.pyplot as plt
def smooth(values, window=20):
    """Return the trailing moving average of ``values``.

    Element ``i`` of the result is the mean of the last ``window`` entries
    ending at ``i`` (fewer at the start, where less history exists).

    Args:
        values: Sequence of numbers.
        window: Number of entries averaged per point; values below 1 are
            treated as 1, which returns each point unchanged.

    Returns:
        A list with the same length as ``values``.
    """
    span = max(1, window)
    # The slice starts at i - span + 1 so that exactly `span` entries are
    # averaged; starting at i - span would include span + 1 of them.
    return [
        np.mean(values[max(0, i - span + 1):i + 1])
        for i in range(len(values))
    ]
# Draw the raw per-episode rewards (faded) under their moving average, with a
# dashed line marking the 475 target, then save and show the figure.
fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(rewards_history, alpha=0.3, label='每轮奖励')
ax.plot(smooth(rewards_history, 20), label='滑动平均(20 轮)', linewidth=2)
ax.axhline(475, color='r', linestyle='--', label='目标线 (475)')

ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.set_title('DQN on CartPole')
ax.legend()
ax.grid(alpha=0.3)

fig.savefig('cartpole_training.png', dpi=120)
plt.show()
第六步:看 agent 玩耍
import time
env = gym.make("CartPole-v1", render_mode="human")
try:
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)
        state, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        time.sleep(0.02)  # short pause between steps so playback is watchable
finally:
    # Close the environment even if the loop is interrupted or a step raises.
    env.close()
第七步:保存和加载模型
# Save: bundle both networks' weights and the current exploration rate.
checkpoint_path = 'cartpole_dqn.pth'
snapshot = {
    'q_net': agent.q_net.state_dict(),
    'target_net': agent.target_net.state_dict(),
    'epsilon': agent.epsilon,
}
torch.save(snapshot, checkpoint_path)

# Load: restore the weights and the stored exploration rate.
checkpoint = torch.load(checkpoint_path)
q_weights = checkpoint['q_net']
target_weights = checkpoint['target_net']
agent.q_net.load_state_dict(q_weights)
agent.target_net.load_state_dict(target_weights)
agent.epsilon = checkpoint['epsilon']
agent.epsilon = 0  # turn exploration off for evaluation
调参进阶
跑通之后,可以试试这些改进看能不能更快收敛:
# 1. Double DQN
def train_step_double_dqn(self):
    """Double-DQN variant of ``train_step``.

    The online network chooses the next action and the target network
    evaluates it. Everything after the loss is the same as in ``train_step``.

    Returns:
        The scalar loss as a float, or ``None`` when the buffer does not yet
        contain ``batch_size`` transitions.
    """
    # Sampling more transitions than are stored would raise, so wait until a
    # full batch is available.
    if len(self.buffer) < self.batch_size:
        return None

    s, a, r, s_, done = self.buffer.sample(self.batch_size)

    # squeeze(1) removes only the gathered action axis, so a batch of one
    # keeps its batch dimension.
    q_pred = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Online network selects the action, target network estimates its value.
        best_actions = self.q_net(s_).argmax(1).unsqueeze(1)
        q_next = self.target_net(s_).gather(1, best_actions).squeeze(1)
        q_target = r + self.gamma * q_next * (1 - done)

    loss = nn.functional.mse_loss(q_pred, q_target)

    # Optimisation, epsilon decay and target sync, as in train_step.
    self.optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)
    self.optimizer.step()

    self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    self.steps += 1
    if self.steps % self.target_update == 0:
        self.target_net.load_state_dict(self.q_net.state_dict())

    return loss.item()
# 2. Dueling DQN
class DuelingQNetwork(nn.Module):
    """Dueling Q-network: a shared trunk feeding a value head and an advantage head.

    Args:
        state_dim: Number of input features.
        n_actions: Number of actions (width of the advantage head).
        hidden: Width of the two trunk layers.
    """

    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        trunk = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.feature = nn.Sequential(*trunk)
        self.value = nn.Linear(hidden, 1)  # V(s)
        self.advantage = nn.Linear(hidden, n_actions)  # A(s, a)

    def forward(self, x):
        """Combine both heads as Q = V + (A - mean(A)).

        The mean is taken over dim 1, so ``x`` must carry a batch axis.
        """
        shared = self.feature(x)
        state_value = self.value(shared)
        advantages = self.advantage(shared)
        # Same evaluation order as `v + a - mean(a)`: add first, then subtract.
        combined = state_value + advantages
        return combined - advantages.mean(dim=1, keepdim=True)
实战技巧汇总
进阶:换更难的环境
CartPole 只是开胃菜,试着挑战更难的:
# 1. Acrobot - double pendulum; the goal is to swing it up to the top
env = gym.make("Acrobot-v1")
# 2. MountainCar - the car has to climb the hill
env = gym.make("MountainCar-v0")
# 3. LunarLander - moon landing. The default action space is discrete;
#    continuous=True selects the continuous-control variant, which is the one
#    that needs DDPG/PPO/SAC. Requires the Box2D extra
#    (pip install "gymnasium[box2d]").
#    NOTE(review): newer Gymnasium releases may register this environment as
#    "LunarLander-v3" - confirm against the installed version.
env = gym.make("LunarLander-v2", continuous=True)
LunarLander 默认是 4 个离散动作,DQN 也能训;只有它的连续动作版本(创建环境时传 continuous=True)DQN 才搞不定,需要 DDPG/PPO/SAC。
小结
- CartPole 4 维状态 + 2 离散动作,完美的 DQN 入门
- 经验回放 + 目标网络 + 神经网络 = DQN 三件套
- 150-250 轮基本能训到满分
- 调参:lr、gamma、epsilon、buffer size、target_update
- 想玩更难的环境,得换算法(DDPG/PPO/SAC)
练习思考
- 把 γ 改成 0.5 和 0.999,分别跑一下,看哪个更慢/更不稳。
- 不开经验回放(buffer size = 1,同时把 batch_size 也设为 1,否则 train_step 永远凑不够一个 batch),能训出来吗?为什么?
- 试试换成 Acrobot 环境,改改超参能不能训出来?
章末小测验
检验你对《实战:CartPole 训练》的掌握程度。
以下关于 CartPole 环境的描述,哪些是正确的?
关于 DQN 的核心组件,以下哪些说法是正确的?
以下哪些环境需要使用 DDPG/PPO/SAC 等算法?
在 CartPole 环境中,以下哪些调参选项是常见的?
关于 CartPole 训练,以下哪些说法是正确的?