记录一些 PyTorch 的细节。
名不符实的损失函数
NLL Loss
nn.NLLLoss
,自称实现的是 Negative Log Likelihood Loss,理论上应该是:
ℓ(x,y)={ℓ1,…,ℓN}, where ℓn=−logxn,yn
但文档上同样也自称了,它实现的其实是:
ℓn=−xn,yn
是的 log 仅仅只存在于函数名里 (╯‵□′)╯︵╧═╧
Cross Entropy Loss
nn.CrossEntropyLoss
,理论上的交叉熵应该是:
L=−i∑Cyilogpi
其中 y 是实际类别,p 是预测类别。但文档上同样也说了,它实现的其实是 nn.LogSoftmax()
+ nn.NLLLoss()
,即:
L=−i∑Cyi⋅LogSoftmax(pi)
所以不能在用 nn.CrossEntropyLoss
前再手动 softmax 一次,不然就是两次 softmax 了,要出大问题。
不明觉厉的优化器实现
Momentum
Momentum 的介绍可以参考这里。首先,PyTorch 的 momentum 实现是:
vt=ρvt−1+∇L(θt)
θt+1=θt−r⋅vt
而不是:
vt=ρvt−1+r⋅∇L(θt)
θt+1=θt−vt
学习率是要跟整个动量相乘,而不是只乘梯度,Polyak’s Momentum 和 Nesterov’s Momentum 都是如此。
然后,Nesterov’s Momentum 的公式是:
vt=ρvt−1+∇L(θt−rρvt−1)
θt+1=θt−r⋅vt
它的思想是,先假设当前参数点 0 按上一次的动量多更新一步到点 1(下图的棕色箭头),然后在更新后的参数 θt′=θt−rρvt−1 上算梯度 ∇L(θt′)(红色箭头),用这个梯度来算这一次的动量 vt(绿色箭头),最后用这个动量来真正的更新当前参数点 0 到点 2。
可以看到,θ 和 ∇L(θ) 是不需要关注的,我们没有必要 0→1→0→2→3,我们可以直接把 θ′ 和 ∇L(θ′) 作为目标,即直接 1→2→3,1→2 是在更新参数,2→3 相当于是每一步都多更新一步,就不用再假设和回退了。
那么令 θt′=θt−rρvt−1,则有:
vt=ρvt−1+∇L(θt′)(1)
θt+1′=θt+1−rρvt=θt−r⋅vt−rρvt=θt′+rρvt−1−r⋅vt−rρvt=θt′−r∇L(θt′)−rρvt=θt′−r⋅(∇L(θt′)+ρvt)(2)
−r⋅∇L(θt′) 是 1→2,−rρvt 是 2→3。
包括 PyTorch 在内的深度学习框架的实现基本就是按照公式 (1) 和 (2) 来的,我把源码复制过来:
python
buf
是 vt,d_p
是 ∇L(θt′),alpha
是动量参数 ρ,lr
是学习率 r。
奇奇怪怪的初始化器
Kaiming Init
PyTorch 中,linear 层和 conv 层的默认 init 是 kaiming init:
Delving Deep into Rectifiers: Surpassing Human-level Performance on ImageNet Classification. Kaiming He, et al. ICCV 2015. [Paper]
但这两个地方都给了一个奇怪的参数:
nn/modules/linear.py
/ nn/modules/conv.py
python
a
这个参数看上去非常的奇怪,因为 PyTorch 在文档里说参数 a
“only used with leaky_relu”。
a
代表的是 Leaky ReLU 函数 x<0 部分的斜率 negative slop。在用 kaiming uniform init 时,PyTorch 会根据这个 negative slop 算一个放缩因子 gain
出来(文档):
gain=1+negative_slop22
然后 kaiming uniform 的边界为:
gain⋅fan_in3
如果用别的激活函数,就不该有 negative slop 这个参数,gain
也会用别的公式计算,但 PyTorch 就是给你搞了一个 a = sqrt(5)
的奇怪的默认值。
这个问题在这两个地方有解释:
PyTorch 的 init 进行过一次重构(pr #9038),重构前 linear 和 conv 的默认 init 是:
python
重构之后才开始用 kaiming init。但为了保证向后兼容,他们希望重构前后的默认 init 的输出是等价的。重构前的均匀分布边界是(self.weight.size(1)
就是 fan_in
(输入节点数量)):
fan_in1
因此为了让重构前后的边界等价:
fan_in1=gain⋅fan_in3=fan_in⋅(1+negative_slop2)6
⇒negative_slop=5
所以这个 a = sqrt(5)
的奇怪的默认值就是这样来的,不是什么推荐值,只是为了保证向后兼容而强行设的而已…
为什么我写了那么大一段话来解释这个无聊的结论…