I am doing an NER task with BERT, but I am running into a class imbalance problem, so I want to use focal loss to address it. However, as a beginner, I can't figure out what is wrong with my code shown below. Here pred is the output of BERT, with shape (batch_size, max_seq_length, num_tags), and target has shape (batch_size, max_seq_length). When I run the code, an error is raised by the line logpt = torch.gather(log_softmax, dim=-1, index=target).
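For reference, the failing call can be reproduced in isolation. The sketch below uses made-up sizes and random tensors rather than real BERT output, and it assumes the error comes from torch.gather requiring index to have the same number of dimensions as its input. Under that assumption, adding a trailing dimension to target makes the call go through:

```python
import torch

# Made-up sizes for illustration only; real values come from the model and data.
batch_size, max_seq_length, num_tags = 2, 4, 3

log_softmax = torch.log_softmax(
    torch.randn(batch_size, max_seq_length, num_tags), dim=-1
)
target = torch.randint(0, num_tags, (batch_size, max_seq_length))

# index is 2-D while the input is 3-D, so gather rejects it.
try:
    torch.gather(log_softmax, dim=-1, index=target)
    gather_failed = False
except RuntimeError as err:
    gather_failed = True
    print(err)

# Adding a trailing dimension gives index the same rank as the input.
logpt = torch.gather(log_softmax, dim=-1, index=target.unsqueeze(-1))
print(logpt.shape)  # torch.Size([2, 4, 1])
```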
Here is my focal loss code:
    class FocalLoss(nn.Module):
        def __init__(self, weight: list, gamma=2.0, reduction="mean"):
            super().__init__()
            self.gamma = gamma
            self.alpha = torch.tensor(weight)
            self.reduction = reduction

        def forward(self, pred, target):
            alpha = self.alpha[target]
            log_softmax = torch.log_softmax(pred, dim=-1)
            logpt = torch.gather(log_softmax, dim=-1, index=target)
            logpt = logpt.view(-1)
            ce_loss = -logpt
            pt = torch.exp(logpt)
            focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
            if self.reduction == "mean":
                return torch.mean(focal_loss)
            if self.reduction == "sum":
                return torch.sum(focal_loss)
            return focal_loss
So my question is: how should I modify this code so that it works with these shapes?
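
One possible direction, offered as a minimal sketch rather than a definitive fix: flatten pred to (N, num_tags) and target to (N,) so that alpha, logpt, and the loss all share one shape, and give the gather index a trailing dimension. Two details here are assumptions that are not in the code above: the ignore_index argument (set to -100 on the guess that padding and sub-word positions are labelled that way, as is common in BERT token-classification setups) and storing alpha with register_buffer so it follows the module across devices.

```python
import torch
from torch import nn


class FocalLoss(nn.Module):
    def __init__(self, weight, gamma=2.0, reduction="mean", ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        # Assumption: positions labelled ignore_index (e.g. padding) are skipped.
        self.ignore_index = ignore_index
        # A buffer moves with the module on .to(device); a plain attribute would not.
        self.register_buffer("alpha", torch.tensor(weight, dtype=torch.float))

    def forward(self, pred, target):
        # Flatten (batch_size, max_seq_length, num_tags) -> (N, num_tags)
        # and (batch_size, max_seq_length) -> (N,).
        pred = pred.reshape(-1, pred.size(-1))
        target = target.reshape(-1)

        # Drop ignored positions before indexing alpha or gathering.
        keep = target != self.ignore_index
        pred, target = pred[keep], target[keep]

        log_softmax = torch.log_softmax(pred, dim=-1)
        # gather needs an index of the same rank as its input: (N, 1).
        logpt = log_softmax.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        pt = torch.exp(logpt)
        alpha = self.alpha[target]  # per-token class weight, shape (N,)
        focal_loss = -alpha * (1 - pt) ** self.gamma * logpt

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss


# Usage with made-up sizes: batch_size=2, max_seq_length=5, num_tags=3.
torch.manual_seed(0)
pred = torch.randn(2, 5, 3)
target = torch.randint(0, 3, (2, 5))
target[0, -1] = -100  # pretend the last token of the first sequence is padding
loss = FocalLoss(weight=[1.0, 2.0, 2.0])(pred, target)
print(loss)
```

With gamma=0 and all weights equal to 1, this sketch reduces to ordinary cross-entropy over the non-ignored tokens, which is a convenient sanity check. Note that "mean" here averages over kept tokens without normalising by the alpha weights, which differs from how the weighted mean in torch.nn.functional.cross_entropy is computed.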