在读《Dive into Deep Learning》关于 Weight Decay 的章节时,我有很多疑问,思索查询后整理成了这篇博文。希望能帮助到有同样疑问的人。
原文在这里:Weight Decay
先摆出一个重要的判断:越复杂的 model 越容易 overfit。
一个 model,或者说 function,有几个影响复杂度的属性:
- features 的数量:即 X 的列数。但这个是 data set 自身的属性,不由 function 决定(当然,我们可以在 function 之外,也就是在数据收集阶段改变它。但这里只聚焦 function)。
- weights 的取值:其中每个 weight 的大小会决定 model 对相应 feature 的敏感程度 —— 越大的 weight 会让 model 对相应的 feature 更敏感,因为两者的计算是乘法。越小的 weight 会让 model 对相应的 feature 越迟钝。
- bias 的取值:它的影响极小。
- 以上三者组成的多项式:主要是次数(degree)和项数的选择。即便仅有一个 feature 也可以组成很多种多项式,比如 $wx^2 + wx + b$,或 $wx^{15} + b$。
到此可以发现,如果从 1 和 3 着手,很难应对 overfit 。于是重点考虑 2 和 4。但 4 很快就被排除了,因为很容易导致变化波动特别大 —— 稍微调整一下次数或多项式的项数,都会让 function 的 output,也就是 y_predict,变化剧烈。
因此,只剩下 3,也就是从 weights 着手。有两个殊途同归的考虑角度:
- 如前所述,如果 weights 中的部份值很大,那么 model 会对那部份对应的 feature 特别敏感,这种敏感会让 model 更容易 overfit。那么我们就让 weights 的值更均衡,避免一部份值很大,另一部份很小。
- 越复杂的 model 越容易 overfit。优化的前提是可衡量,衡量 model 复杂度的方法之一是 norm。其中 L2 norm 正好能识别上一条所描述的 weights 值不均衡的现象。
于是,就在 loss 函数里加上对 L2 norm 的惩罚。
总结:缓解 overfitting 的关键是降低 model 的复杂度。影响复杂度的几个关键属性中,最易着手的、最易控制力度的是 weights。这是 Weight Decay 之所以能缓解 overfitting 的大致原理。