前言
上一篇文章学习了策略梯度,它是同策略,采样学习过一次的数据在执行梯度上升之后就不能再用了,需要重新采样,所以它在采样上花费了大量的时间。而近端策略优化解决了这个问题,它是策略梯度的变种,也是异策略,它用另外一个策略和演员同环境交互,让原来的策略去学习另外一个策略,这样可以多次使用另外一个策略采样到的数据,可以多次执行梯度上升。
关键词
- 同策略和异策略:如果要学习的智能体和与环境交互的智能体是相同的,就称之为同策略,反之是异策略。
- 重要性采样:假如我们要采样的概率分布P(x)不方便直接采样,可以从另一个已知的分布Q(x)中采样,这个分布通常是容易采样的,并且尽可能与P(x)相似。然后通过对抽样值应用权重来调整相应函数值的贡献,以更准确地估计目标分布P(x)下的期望。
- 重要性权重:用来修正两个分布的差异,一般记作q(x)p(x)。
同策略变异策略
我们可以使用重要性采样把同策略换成异策略。
重要性采样是使用另一种分布,逼近所求分布的一种方法,算是一种期望修正的方法。
∇Rθ¯=Eτpθ(τ)[R(τ)∇logpθ(τ)]
假设分布p不能做积分,但是可以部分采样。对于随机变量x,设f(x)是期望函数,对其做平均值,可以近似得到期望。
Exp[f(x)]≈N1i=1∑Nf(xi)
假设分布p无法采样,要从另一个分布q采样,因为期望Exp[f(x)]就是∫f(x)p(x)dx可以做如下变换:
∫f(x)p(x)dx=∫f(x)q(x)p(x)q(x)dx=Exq[f(x)q(x)p(x)]
Exp[f(x)]=Exq[f(x)q(x)p(x)]
同理用参数θ′代替参数θ。
∇Rθ¯=Eτpθ′(τ)[pθ′(τ)pθ(τ)R(τ)∇logpθ(τ)]
带入策略梯度公式。
E(st,at)πθ′[pθ′(st,at)pθ(st,at)Aθ(st,at)∇logpθ(atn∣stn)]
根据条件概率公式。
pθ(st,sa)=pθ(at∣st)pθ(st)
E(st,at)πθ′[pθ′(at∣st)pθ′(st)pθ(at∣st)pθ(st)Aθ(st,at)∇logpθ(atn∣stn)]
我们假设pθ(st)和pθ′(st)相同,有两个原因,一是因为状态往往与采取的动作是没有太大关系的,比如我们玩拳皇,无论采取哪个动作,我们看到的游戏画面都是差不多的,二是因为这个数值很难计算,所以干脆无视这个问题。而pθ(at,st)很好算,因为参数θ是一个神经网络,输入状态,它会输出每个动作的概率。
E(st,at)πθ′[pθ′(at∣st)pθ(at∣st)Aθ(st,at)∇logpθ(atn∣stn)]
根据log求导公式,反推目标函数。
f(x)∇f(x)=∇logf(x)
Jθ′(θ)=E(st,at)πθ′[pθ′(at∣st)pθ(at∣st)Aθ′(st,at)]
重要性采样存在一个问题,如果两个分布相差太多,结果就会不好。因为期望相同,但是方差不同,而且分布相差越多,方差差距越大。
我们把f(x)和f(x)q(x)p(x)带入方差公式。
Varxp[f(x)]=Exp[f(x)2]−(Exp[f(x)])2
Varxq[f(x)q(x)p(x)]=Exq[(f(x)q(x)p(x))2]−(Exq[f(x)q(x)p(x)])2=∫(f(x)q(x)p(x))2q(x)dx−(Exp[f(x)])2=∫f(x)2q(x)p(x)p(x)dx−(Exp[f(x)])2=Exp[f(x)2q(x)p(x)]−(Exp[f(x)])2
两个分布的方差公式的差别只在第一项,如果q(x)p(x)差距很大,方差就会很大。如果采样次数不够多,结果差别就会很大。
近端策略优化
近端策略优化解决了重要性采样的问题,它引入了一个约束,使两个分布差距不会太大。这个约束是θ和θ′输出的动作的KL散度(KL divergence),它用于衡量两者的相似程度。所以PPO包含了两个优化目标,一个是目标函数Jθ′(θ),另一个是约束。
JPPOθ′(θ)=Jθ′(θ)−βKL(θ,θ′)Jθ′(θ)=E(st,at)πθ′[pθ′(at∣st)pθ(at∣st)Aθ′(st,at)]
KL散度指的是行为上的距离,并不是参数上的距离,因为参数的变化和行为的变化不一致,可能参数变了一点,输出的动作变了很多。
近端策略优化惩罚
公式同上,其中的权重系数β是一个自适应系数,会动态调整。我们会设置KL散度的最大值和最小值,如果KL散度太大,说明动作差距太大,需要增强惩罚力度,增大β,如果KL散度太小,说明动作差距很小,为了避免两个参数一样,需要减小惩罚力度,减小β。
近端策略优化剪裁
剪裁的目标函数里没有KL散度,其要最大化的目标函数为
JPPO2θk≈(st,at)∑min(pθk(at∣st)pθ(at∣st)Aθk(st,at),clip(pθk(at∣st)pθ(at∣st),1−ε,1+ε)Aθk(st,at))
剪裁(clip)函数是指,括号中有三项,如果第一项小于第二项,那就输出1−ε,第一项如果大于第三项,那就输出1+ε。ε是一个超参数,是我们要调整的。
如果A>0,表示某个状态-动作对是好的,那么我们要增大pθ(at∣st),假如它和pθk(at∣st)的比值超过1+ε,会被限制在1+ε。
如果A<0,表示某个状态-动作对不好,那么我们要减小pθ(at∣st),假如它和pθk(at∣st)的比值低于1−ε,会被限制在1+ε。
所以剪裁公式保证了pθ(at∣st)和pθk(at∣st)不会差距太大。而且在代码实现上剪裁比惩罚要容易的多。
代码实现
PPO (opens new window)