I've recently been trying to reproduce PPO, and when computing the loss I'm not sure how exactly the entropy loss should be implemented, so I'm asking for help here.
My code is at: https://github.com/LeonShangguan/ppo_car_racing

Right now it runs without the entropy loss and reaches a reward of around 850 on CarRacing, but there are some problems:

  1. During training, the reward suddenly drops after climbing to its peak, and then starts climbing again.
  2. It fails on other environments; even something as simple as Pendulum swing-up won't train.
    I suspect the entropy loss is the problem, and I'd really appreciate some help.

    LeonShangguan

    Looking at the stable-baselines3 source code, the key part is

    if entropy is not None:
        entropy_loss = -th.mean(entropy)
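
    For context, this entropy_loss is later weighted into the total objective together with the policy and value losses. Below is a minimal sketch of that combination (ent_coef and vf_coef are named after SB3's hyperparameters; the loss values are placeholders, not SB3's actual training code):

        import torch as th

        # Placeholder values, only to illustrate how the terms are combined
        policy_loss = th.tensor(0.3)
        value_loss = th.tensor(1.2)
        entropy = th.tensor([1.5, 1.4, 1.6])   # per-sample entropy of the action distribution
        entropy_loss = -th.mean(entropy)       # as in the snippet above

        ent_coef, vf_coef = 0.01, 0.5          # illustrative coefficients
        # Minimizing -ent_coef * mean(entropy) pushes entropy up, i.e. encourages exploration
        loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss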

    Digging a little further, you can see that the entropy actually returned is

            distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
            return distribution.entropy()

    Note what latent_pi and latent_sde refer to (see the docstring below): the entropy returned is that of the action distribution built from these two latent codes.

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :param latent_sde: Latent code for the gSDE exploration function
        :return: Action distribution
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
        else:
            raise ValueError("Invalid action distribution")
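
    For a continuous-control task like yours (CarRacing, Pendulum), the action space is a Box, so (with gSDE disabled) the DiagGaussianDistribution branch above is taken, and the returned entropy is just the analytical entropy of the diagonal Gaussian that the actions are sampled from. A minimal sketch of computing the same term in your own PPO, outside SB3 (names like policy_mean and log_std are illustrative, not taken from your repo):

        import torch as th
        from torch.distributions import Normal

        batch, act_dim = 64, 3                            # e.g. CarRacing's 3-dim action
        policy_mean = th.zeros(batch, act_dim)            # actor network output (placeholder)
        log_std = th.zeros(act_dim, requires_grad=True)   # state-independent log std, learned alongside the network

        dist = Normal(policy_mean, log_std.exp())         # the same distribution the actions are sampled from
        entropy = dist.entropy().sum(dim=-1)              # sum over action dims -> one value per sample
        entropy_loss = -entropy.mean()                    # this is the term that enters the total loss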

      PonyShan Thanks. I just went through the source as well; my understanding is that this entropy is simply the entropy of the distribution used to sample the actions. That should be right, correct?
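
      A quick way to confirm that reading for the Gaussian case: the entropy depends only on the (log) standard deviation, not on the state-dependent mean, and matches the closed form 0.5 + 0.5*log(2*pi) + log(sigma) per action dimension. A small illustrative check (not from either codebase):

        import math
        import torch as th
        from torch.distributions import Normal

        log_std = th.tensor([0.0, -0.5, 0.3])
        dist = Normal(th.randn(3), log_std.exp())                   # the mean does not affect the entropy
        closed_form = 0.5 + 0.5 * math.log(2 * math.pi) + log_std
        print(th.allclose(dist.entropy(), closed_form))             # True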
