import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceLoss(nn.Module):
    """Soft Dice loss for multi-class segmentation.

    Expects raw logits of shape (batch, num_classes, H, W) and integer
    class-index targets of shape (batch, H, W) — assumed from the
    softmax/one_hot usage; confirm against callers. Returns the scalar
    ``1 - dice`` computed over all classes and pixels jointly.
    """

    def __init__(self, num_classes):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        # Turn logits into per-class probabilities.
        probs = F.softmax(inputs, dim=1)
        # One-hot encode labels, then move the class axis to the channel slot.
        encoded = F.one_hot(targets.long(), self.num_classes).permute(0, 3, 1, 2).float()
        flat_probs = probs.reshape(-1)
        flat_labels = encoded.reshape(-1)
        overlap = (flat_probs * flat_labels).sum()
        # The smoothing term keeps the ratio finite when both masks are empty.
        dice_score = (2. * overlap + smooth) / (flat_probs.sum() + flat_labels.sum() + smooth)
        return 1 - dice_score
class DiceBCELoss(nn.Module):
    """Combined soft Dice + binary cross-entropy loss.

    Expects raw logits of shape (batch, num_classes, H, W) and integer
    class-index targets of shape (batch, H, W). Each class channel is
    treated as an independent binary prediction (sigmoid, not softmax —
    NOTE(review): this differs from DiceLoss/IoULoss; confirm it is
    intentional). Returns the scalar ``1 - dice + bce``.
    """

    def __init__(self, num_classes):
        super(DiceBCELoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        # FIX: F.sigmoid is a deprecated alias; torch.sigmoid is the
        # supported spelling and is numerically identical.
        inputs = torch.sigmoid(inputs)
        # One-hot encode labels, class axis moved to the channel slot.
        targets = F.one_hot(targets.long(), self.num_classes).permute(0, 3, 1, 2).float()
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        # Smoothing keeps the ratio finite when both masks are empty.
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        bce = F.binary_cross_entropy(inputs, targets, reduction='mean')
        return 1 - dice + bce
class IoULoss(nn.Module):
    """Jaccard (intersection-over-union) loss for multi-class segmentation.

    Expects raw logits of shape (batch, num_classes, H, W) and integer
    class-index targets of shape (batch, H, W) — assumed from the
    softmax/one_hot usage; confirm against callers. Returns the scalar
    ``1 - iou`` computed over all classes and pixels jointly.
    """

    def __init__(self, num_classes):
        super(IoULoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        probs = F.softmax(inputs, dim=1)
        encoded = F.one_hot(targets.long(), self.num_classes).permute(0, 3, 1, 2).float()
        p = probs.reshape(-1)
        t = encoded.reshape(-1)
        overlap = (p * t).sum()
        # Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
        union = p.sum() + t.sum() - overlap
        # Smoothing keeps the ratio finite when both masks are empty.
        jaccard = (overlap + smooth) / (union + smooth)
        return 1 - jaccard
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017).

    Expects raw logits of shape (batch, num_classes, H, W) and integer
    class-index targets of shape (batch, H, W). Down-weights easy,
    well-classified pixels by the factor ``(1 - pt) ** gamma`` so
    training focuses on hard examples. Returns the mean focal loss.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes  # kept for interface parity with the other losses
        self.alpha = alpha  # global weighting factor on the loss
        self.gamma = gamma  # focusing exponent; gamma=0 reduces to alpha * CE

    def forward(self, inputs, targets):
        # BUG FIX: the previous implementation softmaxed the logits,
        # flattened both tensors to 1-D and passed *probabilities* into
        # F.cross_entropy, which expects raw logits with class-index
        # targets — the flattened call log-softmaxed over the entire
        # vector and produced meaningless values. Compute the per-pixel
        # cross-entropy directly from the logits instead.
        ce_loss = F.cross_entropy(inputs, targets.long(), reduction='none')
        # pt = exp(-CE) is the predicted probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()
class ce_loss(nn.Module):
    """Thin wrapper around F.cross_entropy with the factory's interface.

    Expects raw logits of shape (batch, num_classes, H, W) and integer
    class-index targets of shape (batch, H, W); returns the mean
    cross-entropy. (Name kept lowercase for backward compatibility with
    existing callers, although PascalCase would match the other losses.)
    """

    def __init__(self, num_classes):
        super(ce_loss, self).__init__()
        # num_classes is unused by F.cross_entropy but kept so the
        # constructor signature matches the other loss classes.
        self.num_classes = num_classes
        # FIX: removed a duplicated second call to super().__init__().

    def forward(self, inputs, targets):
        return F.cross_entropy(inputs, targets)
def get_loss_function(loss_type, num_classes):
    """Return a configured loss module for the given loss name.

    Args:
        loss_type: one of 'dice', 'dice_bce', 'iou', 'focal', 'ce'.
        num_classes: number of segmentation classes, forwarded to the
            loss constructor.

    Raises:
        ValueError: if ``loss_type`` is not a known loss name.
    """
    # BUG FIX: every loss constructor requires num_classes; previously
    # only the 'ce' branch passed it, so the other four branches raised
    # TypeError at construction time.
    if loss_type == 'dice':
        return DiceLoss(num_classes)
    elif loss_type == 'dice_bce':
        return DiceBCELoss(num_classes)
    elif loss_type == 'iou':
        return IoULoss(num_classes)
    elif loss_type == 'focal':
        return FocalLoss(num_classes)
    elif loss_type == 'ce':
        return ce_loss(num_classes)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")