DDPG 作为一种比较成功的 AC 类算法,已经在很多任务中取得了不俗的表现。但它有一个致命的短板,就是 overestimation 的问题,容易造成 policy 收敛困难。最典型的就是在 gym 的 BipedalWalker 中,如果采用原始的 DDPG 算法,是很难 solve 的,你会发现你的总奖励值始终卡在很低的水平,根本无法进步。这里的主要原因就是对 Q value 的过高估计,导致 agent 一直在糟糕的策略下挣扎,无法找到有效的策略。
为了解决这个问题,TD3 横空出世。它其实就是 DDPG 的升级版,网络架构跟原版基本类似,但是它参考了 Double DQN 里面双 Q 网络的设计,采用了两个 Critic 网络,在计算 Q 值估计时,会让两个网络分别去预测,然后取最小的那个值。这里其实就是经典的双学习思想,具体的原理在 Sutton 的书里有详细讲解。
核心设计
- Twin :双 Critic 网络
DDPG 在训练过程中容易出现 over estimation 的问题,即对 Q value 的估计偏高,而且这种 bias 会随时间积累,导致算法无法收敛。
This is caused by the algorithm continuously over estimating the Q values of the critic (value) network. These estimation errors build up over time and can lead to the agent falling into a local optima or experience catastrophic forgetting
针对这个问题,提出了 DDPG 的改进版本,即TD3 ,它参考了 DQN 中 Double Q-learning 的设计,在原有的架构上使用了两个 Critic 网络。
而 DDPG 原有的 target 网络虽然也是从 DQN 的设计中演化来的,但 target 网络对于解决 over estimation 并不理想。
This is because the policy and target networks are updated so slowly that they look very similar
但为了稳定更新的目的,Target 网络还是需要保留,所以 TD3 实际上采用了两套 Critic 网络,Local 网络和Target 网络都包含两个独立的 Q network。
每次更新,取两个 Q target 网络中对 Q value 估计值较小的那个作为 Q 目标
Q_{target} = min [ Q_{\theta'1}(s',a'),Q_{\theta'2}(s',a') ]
critic 的总 loss 计算: 把两个 local 网络的 MSE 误差相加
Loss = MSE(Q_{\theta1},Q_{target}) + MSE(Q_{\theta2},Q_{target})
Delayed:延迟更新策略
在 Actor-Critic 架构中使用 Target 网络,虽然能起到稳定训练的作用,但也可能会引起一个严重的问题。如果一个糟糕的策略被错误的过高估计 (overestimated) ,那么这个错误就会持续累积,导致策略网络无法收敛 diverges。
解决办法就是对 Policy 使用延迟更新的方法,让 Policy 更新的频率比 Critic 更低,这样就可以让 Critic 有足够的时间稳定 Q 函数并纠正错误,避免错误的估计影响策略更新。
TD3 uses a delayed update of the actor network, only updating it every 2 time steps instead of after each time step, resulting in more stable and efficient training。
每个 time step,Critic 都更新一次;而 Actor 则是每隔 N 个 time step 才更新一次
- Noise Regularisation
确定性策略有个缺陷,在更新 Critic 的时候容易产生高方差的 Q value 估计。
This is caused by overfitting to spikes in the value estimate.
TD3 采用了一种正则化方法 ,目标策略平滑(target policy smoothing).
在计算 Q 目标时,给 action 加入一个小的随机噪音,并采用截断 (clipped noise)来避免过多的偏离原始值,经过多个 mini batch 的累加,这些噪声会取得近似的平均。
加入了噪声后,对于那些更 robust 的动作(对于噪声和干扰的抵抗力更强),Critic 就能给出更高的 Q value,从而让 policy 更稳定。
算法流程
TD3 伪代码
(github 相关实践代码看这里⬇️)
https://github.com/Quantum-Cheese/DeepReinforcementLearning_Pytorch/tree/master/DDPGs/TD3