使用 PyTorch 实现 DQN 算法,控制 AI 在 Atari 游戏 Pong 中的表现。通过经验回放与目标网络实现稳定训练。
存储历史交互经验,随机采样打破样本相关性,显著提升训练稳定性与数据利用效率。
独立的目标网络用于计算 Q 值目标,定期同步参数,避免训练过程中的发散问题。
动态调整探索率,在训练初期大范围探索环境,后期逐渐转向利用已学到的最优策略。
# DQN核心更新逻辑
def update(self, state, action, reward, next_state, done):
# 1. 计算目标Q值
target = reward
if not done:
target += self.gamma * torch.max(self.q_network(next_state))
# 2. 计算当前Q值
current = self.q_network(state)[action]
# 3. 优化损失
loss = self.criterion(current, target)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()