# 2025-07-10 09:41:26 +08:00
#
# 88 lines
# 3.2 KiB
# Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceLoss(nn.Module):
    """Soft Dice loss for multi-class segmentation.

    Logits are converted to per-class probabilities with a softmax over dim 1.
    Integer targets are one-hot encoded into the same ``(N, C, *spatial)``
    layout. A single Dice coefficient is then computed over every element of
    the batch, pooling all classes and positions together.

    Args:
        num_classes: Number of classes; used as the one-hot depth.
    """

    def __init__(self, num_classes):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` as a scalar tensor.

        Args:
            inputs: Raw logits shaped ``(N, C, *spatial)`` with ``C == num_classes``.
            targets: Class indices shaped ``(N, *spatial)``; cast to ``long``.
            smooth: Additive smoothing applied to numerator and denominator.

        Returns:
            Scalar tensor ``1 - dice``.
        """
        probs = F.softmax(inputs, dim=1)
        # one_hot appends the class axis last. movedim(-1, 1) moves it to dim 1
        # for any number of spatial dims; in 2-D this equals permute(0, 3, 1, 2).
        one_hot = F.one_hot(targets.long(), self.num_classes).movedim(-1, 1).float()
        probs = probs.reshape(-1)
        one_hot = one_hot.reshape(-1)
        intersection = (probs * one_hot).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + one_hot.sum() + smooth)
        return 1 - dice


class DiceBCELoss(nn.Module):
    """Sum of a sigmoid-based soft Dice loss and binary cross-entropy.

    Each class channel is treated as an independent binary problem: sigmoid
    (not softmax) is applied per element. Integer targets are one-hot encoded
    to ``(N, C, *spatial)``.

    Args:
        num_classes: Number of classes; used as the one-hot depth.
    """

    def __init__(self, num_classes):
        super(DiceBCELoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return ``(1 - dice) + mean BCE`` as a scalar tensor.

        Args:
            inputs: Raw logits shaped ``(N, C, *spatial)`` with ``C == num_classes``.
            targets: Class indices shaped ``(N, *spatial)``; cast to ``long``.
            smooth: Additive smoothing applied to the Dice numerator and denominator.

        Returns:
            Scalar tensor.
        """
        # movedim(-1, 1) moves the one-hot class axis to dim 1 for any spatial
        # rank; in 2-D this equals permute(0, 3, 1, 2).
        one_hot = F.one_hot(targets.long(), self.num_classes).movedim(-1, 1).float()

        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = one_hot.reshape(-1)
        intersection = (probs * flat_targets).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + flat_targets.sum() + smooth)

        # Compute BCE from the logits, not from sigmoid outputs. The fused
        # log-sum-exp stays exact where sigmoid saturates to 0/1, whereas
        # binary_cross_entropy on probabilities clamps log() at -100.
        bce = F.binary_cross_entropy_with_logits(
            inputs.reshape(-1), flat_targets, reduction='mean'
        )
        return 1 - dice + bce


class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss for multi-class segmentation.

    Logits are converted to per-class probabilities with a softmax over dim 1.
    Integer targets are one-hot encoded into ``(N, C, *spatial)``. One IoU
    value is computed over all elements of the batch pooled together.

    Args:
        num_classes: Number of classes; used as the one-hot depth.
    """

    def __init__(self, num_classes):
        super(IoULoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - iou`` as a scalar tensor.

        Args:
            inputs: Raw logits shaped ``(N, C, *spatial)`` with ``C == num_classes``.
            targets: Class indices shaped ``(N, *spatial)``; cast to ``long``.
            smooth: Additive smoothing applied to numerator and denominator.

        Returns:
            Scalar tensor ``1 - iou``.
        """
        probs = F.softmax(inputs, dim=1)
        # movedim(-1, 1) moves the one-hot class axis to dim 1 for any spatial
        # rank; in 2-D this equals permute(0, 3, 1, 2).
        one_hot = F.one_hot(targets.long(), self.num_classes).movedim(-1, 1).float()
        probs = probs.reshape(-1)
        one_hot = one_hot.reshape(-1)
        intersection = (probs * one_hot).sum()
        union = probs.sum() + one_hot.sum() - intersection
        iou = (intersection + smooth) / (union + smooth)
        return 1 - iou


class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha * (1 - p_t) ** gamma * CE``, averaged.

    ``p_t`` is the softmax probability the model assigns to the true class at
    each position. The ``(1 - p_t) ** gamma`` factor down-weights positions
    that are already classified confidently.

    Args:
        num_classes: Kept for interface compatibility. The class count is
            taken from dim 1 of the logits.
        alpha: Scalar weight applied uniformly to every element.
        gamma: Focusing exponent; ``gamma=0`` reduces to ``alpha * CE``.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss over all positions.

        Args:
            inputs: Raw logits shaped ``(N, C, *spatial)``.
            targets: Class indices shaped ``(N, *spatial)``; cast to ``long``.

        Returns:
            Scalar tensor.
        """
        # Per-position CE from the raw logits with class-index targets.
        # Softmaxing first and flattening would instead make cross_entropy
        # treat the whole batch as one unbatched sample.
        ce = F.cross_entropy(inputs, targets.long(), reduction='none')
        pt = torch.exp(-ce)  # probability of the true class at each position
        focal = self.alpha * (1 - pt) ** self.gamma * ce
        return focal.mean()


class ce_loss(nn.Module):
    """``nn.Module`` wrapper around ``F.cross_entropy`` with mean reduction.

    Args:
        num_classes: Stored for interface parity with the other losses; the
            class count is taken from dim 1 of the logits.
    """

    def __init__(self, num_classes):
        # nn.Module.__init__ must run exactly once. A second call resets the
        # module's internal parameter/buffer/submodule registries.
        super(ce_loss, self).__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets):
        """Return the mean cross-entropy between ``inputs`` and ``targets``.

        Args:
            inputs: Raw logits shaped ``(N, C, *spatial)``.
            targets: Either class indices shaped ``(N, *spatial)`` (any dtype;
                cast to ``long``), or class probabilities with the same shape
                as ``inputs`` (passed through unchanged).

        Returns:
            Scalar tensor.
        """
        # Index targets have one dim fewer than the logits. Cast them to long,
        # as the sibling losses do, so float masks are accepted. Same-rank
        # targets are probability targets and must stay floating point.
        if targets.dim() == inputs.dim() - 1:
            targets = targets.long()
        return F.cross_entropy(inputs, targets)


def get_loss_function(loss_type, num_classes):
    """Construct the loss module named by ``loss_type``.

    Args:
        loss_type: One of ``'dice'``, ``'dice_bce'``, ``'iou'``, ``'focal'``
            or ``'ce'``.
        num_classes: Passed to the selected loss class's constructor.

    Returns:
        A newly constructed loss ``nn.Module``.

    Raises:
        ValueError: If ``loss_type`` is not one of the names above.
    """
    # Every loss class takes num_classes as a required constructor argument,
    # so it must be forwarded for all of them, not just 'ce'.
    losses = {
        'dice': DiceLoss,
        'dice_bce': DiceBCELoss,
        'iou': IoULoss,
        'focal': FocalLoss,
        'ce': ce_loss,
    }
    if loss_type not in losses:
        raise ValueError(f"Unknown loss type: {loss_type}")
    return losses[loss_type](num_classes)