import torch
import torch.nn as nn
import torch.nn.functional as F


def _one_hot(targets, num_classes):
    """Encode a class-index map as a float one-hot tensor with classes on dim 1.

    ``targets`` of shape ``(N, *spatial)`` becomes ``(N, num_classes, *spatial)``,
    matching logits that are normalised over ``dim=1``. Values are cast with
    ``.long()`` first, so float index maps are accepted. ``F.one_hot`` raises
    if any index is negative or >= ``num_classes``.
    """
    # movedim(-1, 1) is identical to permute(0, 3, 1, 2) for (N, H, W) targets,
    # but also works for (N,) and (N, D, H, W) targets.
    return F.one_hot(targets.long(), num_classes).movedim(-1, 1).float()


def _dice(probs, onehot, smooth):
    """Return the soft Dice coefficient pooled over every element of both tensors."""
    probs = probs.reshape(-1)
    onehot = onehot.reshape(-1)
    intersection = (probs * onehot).sum()
    return (2.0 * intersection + smooth) / (probs.sum() + onehot.sum() + smooth)


class DiceLoss(nn.Module):
    """Soft Dice loss on softmax probabilities, pooled over classes and pixels.

    Args:
        num_classes: number of classes used to one-hot encode the targets.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` as a scalar tensor.

        Args:
            inputs: raw logits with classes on dim 1, e.g. ``(N, C, H, W)``.
            targets: class indices shaped like ``inputs`` without dim 1.
            smooth: additive constant keeping the ratio defined for empty masks.
        """
        probs = F.softmax(inputs, dim=1)
        onehot = _one_hot(targets, self.num_classes)
        return 1 - _dice(probs, onehot, smooth)


class DiceBCELoss(nn.Module):
    """Soft Dice loss plus binary cross entropy, both on per-class sigmoids.

    Each class channel is treated as an independent binary problem (sigmoid,
    not softmax).

    Args:
        num_classes: number of classes used to one-hot encode the targets.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return ``(1 - dice) + mean BCE`` as a scalar tensor.

        Args:
            inputs: raw logits with classes on dim 1, e.g. ``(N, C, H, W)``.
            targets: class indices shaped like ``inputs`` without dim 1.
            smooth: additive constant keeping the Dice ratio defined.
        """
        onehot = _one_hot(targets, self.num_classes)
        probs = torch.sigmoid(inputs)
        dice = _dice(probs, onehot, smooth)
        # BCE computed from logits is numerically stabler than applying
        # binary_cross_entropy to already-sigmoided values; the mean over all
        # elements is the same as over the flattened tensors.
        bce = F.binary_cross_entropy_with_logits(inputs, onehot, reduction='mean')
        return 1 - dice + bce


class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss on softmax probabilities, pooled over classes and pixels.

    Args:
        num_classes: number of classes used to one-hot encode the targets.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - iou`` as a scalar tensor.

        Args:
            inputs: raw logits with classes on dim 1, e.g. ``(N, C, H, W)``.
            targets: class indices shaped like ``inputs`` without dim 1.
            smooth: additive constant keeping the ratio defined for empty masks.
        """
        probs = F.softmax(inputs, dim=1).reshape(-1)
        onehot = _one_hot(targets, self.num_classes).reshape(-1)
        intersection = (probs * onehot).sum()
        union = probs.sum() + onehot.sum() - intersection
        iou = (intersection + smooth) / (union + smooth)
        return 1 - iou


class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha * (1 - pt) ** gamma * CE``, averaged per element.

    Args:
        num_classes: number of classes (stored for interface parity; the class
            count used by the computation comes from ``inputs.size(1)``).
        alpha: scalar weighting factor.
        gamma: focusing exponent; larger values down-weight easy examples.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss as a scalar tensor.

        Args:
            inputs: raw logits with classes on dim 1, e.g. ``(N, C, H, W)``.
            targets: class indices shaped like ``inputs`` without dim 1.
        """
        # Per-element cross entropy straight from the logits; softmaxing first
        # (as before) would apply softmax twice.
        ce = F.cross_entropy(inputs, targets.long(), reduction='none')
        # exp(-CE) recovers the predicted probability of the true class.
        pt = torch.exp(-ce)
        focal = self.alpha * (1 - pt) ** self.gamma * ce
        return focal.mean()


class ce_loss(nn.Module):
    """Plain cross-entropy loss wrapper.

    Args:
        num_classes: number of classes (stored for interface parity with the
            other losses; not used in the computation).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets):
        """Return ``F.cross_entropy(inputs, targets)`` (mean reduction).

        Args:
            inputs: raw logits with classes on dim 1.
            targets: class indices shaped like ``inputs`` without dim 1, or
                class probabilities shaped exactly like ``inputs``.
        """
        # Index maps must be integer for cross_entropy; cast them like the
        # other losses do. Probability targets (same rank) pass through.
        if targets.dim() == inputs.dim() - 1:
            targets = targets.long()
        return F.cross_entropy(inputs, targets)


# Registry mapping the public loss-type names to their classes.
_LOSS_CLASSES = {
    'dice': DiceLoss,
    'dice_bce': DiceBCELoss,
    'iou': IoULoss,
    'focal': FocalLoss,
    'ce': ce_loss,
}


def get_loss_function(loss_type, num_classes):
    """Instantiate the loss named ``loss_type`` for ``num_classes`` classes.

    Args:
        loss_type: one of ``'dice'``, ``'dice_bce'``, ``'iou'``, ``'focal'``, ``'ce'``.
        num_classes: forwarded to the loss constructor.

    Raises:
        ValueError: if ``loss_type`` is not a known loss name.
    """
    try:
        loss_cls = _LOSS_CLASSES[loss_type]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown loss type: {loss_type}") from None
    # Every loss constructor requires num_classes; previously only 'ce' got it.
    return loss_cls(num_classes)