什么是 Label Smoothing Cross Entropy?
Label Smoothing 是一种正则化技术,用于改进分类任务中的交叉熵损失函数。传统的交叉熵损失函数假设目标标签是硬性(hard)的,即每个样本只有一个正确的类别标签,并且该类别的概率为 1,其他类别的概率为 0。然而,这种硬性标签可能会导致模型过拟合训练数据,尤其是在训练数据有限或标签可能存在噪声的情况下。
Label Smoothing 的基本思想是对目标标签进行“平滑”处理,将原本硬性的标签分布替换为一个更柔和的分布。这样可以减少模型对单一类别的过度自信,从而提高模型的泛化能力。
Label Smoothing 的工作原理
1. 传统交叉熵损失
在分类问题中,交叉熵损失函数定义如下:
其中:
- C 是类别总数。
- $y_i $是目标标签的 one-hot 编码(硬性标签),即正确类别的值为 1,其他类别的值为 0。
- 是模型预测的第 i 类的概率。
在硬性标签的情况下,模型会努力最大化正确类别的概率 ,而完全忽略其他类别的概率。
2. Label Smoothing 的引入
Label Smoothing 将目标标签从硬性分布转换为软性分布,具体公式如下:
其中:
- ϵ 是平滑参数,通常取值在 [0, 1] 范围内(例如 0.1)。
- 是平滑后的目标标签分布。
- 正确类别的概率被降低为 1−ϵ,而其他类别的概率被提升为。
3. 平滑后的交叉熵损失
使用平滑后的标签分布,交叉熵损失变为:
展开后可以写为:
其中:
- 是模型对正确类别的预测概率。
- 是对其他类别的惩罚项。
通过这种方式,模型不仅需要最大化正确类别的概率,还需要关注其他类别的预测结果,从而避免对单一类别的过度自信。
Label Smoothing 的优点
- 减少过拟合 Label Smoothing 防止模型对训练数据中的硬性标签过于依赖,从而提高了模型的泛化能力。
- 改善模型的校准 使用 Label Smoothing 后,模型的预测概率通常更加接近真实分布,而不是过度集中在某个类别上。
- 缓解标签噪声的影响 如果训练数据中的标签存在噪声,Label Smoothing 可以通过平滑标签分布来降低噪声对模型的影响。
- 增强模型的鲁棒性 在对抗攻击等场景下,Label Smoothing 可以使模型对输入扰动更加鲁棒。
Label Smoothing 的实现
以下是一个基于 PyTorch 的 Label Smoothing 实现示例:
1 | import torch |
- PyTorch 1.10之后CrossEntropyLoss 已经原生支持标签平滑功能
1 | import torch |
Label Smoothing 的注意事项
- 选择合适的平滑参数 ϵ 的值通常在 0.1 左右。如果 ϵ 过大,可能会导致模型对正确类别的学习不足;如果 ϵ 过小,则效果可能不明显。
- 适用于大规模分类任务 Label Smoothing 在类别数量较多的任务中效果更显著,因为平滑后的分布能够更好地反映类别间的关联性。
- 与知识蒸馏结合 Label Smoothing 常与知识蒸馏(Knowledge Distillation)结合使用。通过使用教师模型生成软标签,学生模型可以学习到更加丰富的类别间关系。