LabelSmoothing介绍
背景
在多分类训练任务中,输入图片经过神级网络的计算,会得到当前输入图片对应于各个类别的置信度分数,这些分数会被softmax进行归一化处理,最终得到当前输入图片属于每个类别的概率。
对于分类问题,尤其是多类别分类问题中,常常把类别向量做成one-hot vector(独热向量)。one-hot vector 对应的向量便可表示为[0, 1, 0],即对于长度为n 的数组,只有一个元素是1,其余都为0。因此表征我们已知样本属于某一类别的概率是为1的确定事件,属于其他类别的概率则均为0。
one-hot 带来的问题:
对于损失函数,我们需要用预测概率去拟合真实概率,而拟合one-hot的真实概率函数会带来两个问题:
1)无法保证模型的泛化能力,容易造成过拟合;
- 全概率和0概率鼓励所属类别和其他类别之间的差距尽可能加大,而由梯度有界可知,这种情况很难adapt。会造成模型过于相信预测的类别。
标签平滑
label smoothing的提出就是为了解决上述问题。最早是在Inception v2中被提出,是一种正则化的策略。其通过"软化"传统的one-hot类型标签,使得在计算损失值时能够有效抑制过拟合现象。label smoothing相当于减少真实样本标签的类别在计算损失函数时的权重
实际上分了一点概率给其他类(均匀分),让标签没有那么绝对化,留给学习一点泛化的空间。
从而能够提升整体的效果。
pytorch 实现
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
class LabelSmoothSoftmaxCE(nn.Module):
def __init__(self,
lb_pos=0.9,
lb_neg=0.005,
reduction='mean',
lb_ignore=255,
):
super(LabelSmoothSoftmaxCE, self).__init__()
self.lb_pos = lb_pos
self.lb_neg = lb_neg
self.reduction = reduction
self.lb_ignore = lb_ignore
self.log_softmax = nn.LogSoftmax(1)
def forward(self, logits, label):
logs = self.log_softmax(logits)
ignore = label.data.cpu() == self.lb_ignore
n_valid = (ignore == 0).sum()
label = label.clone()
label[ignore] = 0
lb_one_hot = logits.data.clone().zero_().scatter_(1, label.unsqueeze(1), 1)
label = self.lb_pos * lb_one_hot + self.lb_neg * (1-lb_one_hot)
loss = -torch.sum(logs*label, dim=1)
loss[ignore] = 0
if self.reduction == 'mean':
loss = loss.sum() / n_valid
elif self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'none':
loss = loss
return loss
if __name__ == '__main__':
torch.manual_seed(15)
## 标签平滑损失函数计算
criteria = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=5e-3)
net1 = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1),
)
net1.cuda()
net1.train()
net2 = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1),
)
net2.cuda()
net2.train()
with torch.no_grad():
inten = torch.randn(2, 3, 5, 5).cuda()
lbs = torch.randint(0, 3, [2, 5, 5]).cuda()
lbs[1, 3, 4] = 255
lbs[1, 2, 3] = 255
print(lbs)
import torch.nn.functional as F
logits1 = net1(inten)
logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear', align_corners=True)
logits2 = net2(inten)
logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear', align_corners=True)
# loss1 = criteria1(logits1, lbs)
loss = criteria(logits1, lbs)
# print(loss.detach().cpu())
loss.backward()
print(loss)