Hongbo Mao bio photo

Email

Github

标签平滑(Label Smoothing)是一种正则化技术,常用于深度学习模型的训练中,特别是在分类任务中。它的主要目的是防止模型过拟合,提升模型的泛化能力。标签平滑通过调整目标标签的概率分布,使得模型在训练过程中不会对训练数据产生过度自信。

标签平滑的基本思想

在标准的分类任务中,目标标签通常使用one-hot编码,例如,如果有三个类别 \(C1, C2, C3\),并且实际标签是 \(C2\),则目标标签会表示为 \([0, 1, 0]\)。这意味着模型应当完全确信样本属于\(C2\)类。然而,这种极端的标签可能导致模型过度自信,从而导致过拟合。

标签平滑通过将 one-hot 编码的标签分布调整为一个平滑的概率分布。例如,使用标签平滑的目标标签可能变成 \([0.1, 0.8, 0.1]\),其中每个标签的概率都不再是0或1,而是一个较小的值。这种方法在计算交叉熵损失时,通过降低正确类别的权重并增加错误类别的权重,从而防止模型对某一类别过度自信。

标签平滑的公式

假设有\(N\)个类别,标准 one-hot 编码的标签为 \(y_{true}\),标签平滑后的目标标签为\(y_{smooth}\)。标签平滑的计算公式如下:

\(y_{smooth} = (1 - \epsilon) \cdot y_{true} + \frac{\epsilon}{N}\) 其中:

  • \(epsilon\)是标签平滑的系数,取值范围为\([0, 1]\)。通常,\(\epsilon\)是一个小的正数。
  • \(N\)是类别的数量。
  • \(y_{true}\)是 one-hot 编码的标签。
  • \(\frac{\epsilon}{N}\)表示为每个类别分配的平滑值。

代码示例

以下是使用 PyTorch 实现标签平滑的代码示例:

import torch
import torch.nn.functional as F

IGNORE_ID = -1

def cal_loss(pred, gold, smoothing=0.0):
    """Calculate cross entropy loss, apply label smoothing if needed."""

    if smoothing > 0.0:
        eps = smoothing
        n_class = pred.size(1)

        # Generate one-hot matrix: N x C.
        gold_for_scatter = gold.ne(IGNORE_ID).long() * gold
        one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
        log_prb = F.log_softmax(pred, dim=1)

        non_pad_mask = gold.ne(IGNORE_ID)
        n_word = non_pad_mask.sum().item()
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum() / n_word
    else:
        loss = F.cross_entropy(pred, gold, ignore_index=IGNORE_ID, reduction='mean')

    return loss

# 示例输入
pred = torch.tensor([[2.0, 1.0, 0.1], [1.0, 2.0, 0.1]])
gold = torch.tensor([1, 2])

# 使用标签平滑计算损失
smoothing = 0.1
loss = cal_loss(pred, gold, smoothing)
print(loss)

标签平滑的优点

  1. 防止过拟合:通过平滑目标标签,减少模型对训练数据的过度自信,从而提高模型的泛化能力。
  2. 提高模型的鲁棒性:标签平滑使模型在遇到噪声数据或未知类别时更具鲁棒性。
  3. 改进梯度传播:平滑标签有助于在训练初期改进梯度的传播,特别是在有大量类别的情况下。

总结

标签平滑是一种简单但有效的正则化技术,通过调整目标标签的概率分布,帮助模型在训练过程中防止过拟合,提升泛化能力。通过在计算损失时加入标签平滑,可以使得模型对各类别的预测更加平衡,从而提高模型的鲁棒性。