You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
21 lines
841 B
21 lines
841 B
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    Computes ``(1 - eps) * NLL(target) + eps / C * sum_c(-log p_c)``, which is
    cross-entropy against a target distribution putting ``1 - eps + eps / C``
    on the true class and ``eps / C`` on every other class. Class scores are
    read from the LAST dimension of ``output``; any leading dimensions are
    treated as independent samples. For 2-D ``(N, C)`` input this matches
    ``F.cross_entropy(output, target, label_smoothing=eps, ...)``.

    Args:
        eps: Smoothing weight; ``0`` reduces to plain cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``, as in
            :func:`torch.nn.functional.nll_loss`.
        ignore_index: Target value whose positions contribute nothing to
            either term of the loss and are excluded from the ``'mean'``
            denominator.
    """

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """Return the label-smoothed cross-entropy of ``output`` vs ``target``.

        Args:
            output: Unnormalised class scores with classes in the last
                dimension, e.g. ``(N, C)`` or ``(N, L, C)``.
            target: Integer class indices shaped like ``output.shape[:-1]``.

        Returns:
            A 0-d tensor for ``'mean'`` / ``'sum'``; a tensor shaped like
            ``target`` for ``'none'`` (ignored positions are exactly 0).

        Raises:
            ValueError: propagated from ``F.nll_loss`` for an unrecognised
                ``reduction``.
        """
        c = output.size(-1)
        # nll_loss expects the class dim at position 1 while this module keeps
        # it last, so collapse every leading dim into a single batch dim.
        log_preds = F.log_softmax(output, dim=-1).reshape(-1, c)
        flat_target = target.reshape(-1).long()

        nll = F.nll_loss(log_preds, flat_target, reduction=self.reduction,
                         ignore_index=self.ignore_index)

        # Per-sample smoothing term -sum_c log p_c. Ignored positions must be
        # dropped here too, otherwise they leak into the loss (and into the
        # 'mean' denominator) even though nll_loss skips them.
        keep = flat_target != self.ignore_index
        smooth = (-log_preds.sum(dim=-1)).masked_fill(~keep, 0.0)

        if self.reduction == 'sum':
            smooth = smooth.sum()
        elif self.reduction == 'mean':
            # Same denominator nll_loss uses: the count of non-ignored targets.
            smooth = smooth.sum() / keep.sum()
        else:
            # 'none' (any other value was already rejected by nll_loss):
            # restore the caller's leading shape.
            nll = nll.reshape(target.shape)
            smooth = smooth.reshape(target.shape)

        return smooth * self.eps / c + (1 - self.eps) * nll