You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
674 B
23 lines
674 B
1 year ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017).

    Computes ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the softmax
    probability of the target class. The result is optionally rescaled by a
    per-class ``weight`` and reduced exactly like
    :func:`torch.nn.functional.nll_loss`. With ``gamma=0`` it is identical to
    ``F.cross_entropy``.

    Args:
        gamma: Focusing parameter. Larger values down-weight well-classified
            examples more strongly.
        weight: Optional per-class rescaling weights of length ``C``. Any
            tensor-like value accepted by ``torch.as_tensor`` (tensor, list,
            array) works. On every forward call it is cast to the
            log-probabilities' device and dtype.
        ignore_index: Target value that is ignored. It does not contribute to
            the loss or to the ``'mean'`` normalizer.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``. Forwarded
            to ``F.nll_loss``.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        # F.nll_loss only accepts a tensor weight, so convert list/array
        # weights once, up front.
        self.weight = None if weight is None else torch.as_tensor(weight)
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        """Return the focal loss.

        Args:
            input: Unnormalized logits of shape ``[N, C]``.
            target: Class indices of shape ``[N]``.

        Returns:
            A scalar tensor, or a per-sample ``[N]`` tensor when
            ``reduction='none'``.
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        # Scale every class's log-probability by its modulating factor.
        # nll_loss then selects the target column, which yields
        # -(1 - p_t)^gamma * log(p_t).
        logpt = (1 - pt) ** self.gamma * logpt

        weight = self.weight
        if weight is not None:
            # Module.to()/.cuda() does not move a plain attribute, and nll_loss
            # requires weight to match the input's device and dtype.
            weight = weight.to(device=logpt.device, dtype=logpt.dtype)

        return F.nll_loss(
            logpt,
            target,
            weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )
|