138 lines
4.1 KiB
Python
138 lines
4.1 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
def calculate_iou(pred, target, num_classes):
|
|
"""计算IoU (Intersection over Union)
|
|
|
|
Args:
|
|
pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits或 (B, H, W) 的类别索引
|
|
target (torch.Tensor): 真实标签,形状为 (B, H, W)
|
|
num_classes (int): 类别数量
|
|
|
|
Returns:
|
|
dict: 每个类别的IoU和平均IoU
|
|
"""
|
|
if pred.dim() == 4: # (B, C, H, W) logits
|
|
pred = torch.argmax(pred, dim=1) # (B, H, W)
|
|
|
|
# 将张量移动到CPU并转换为NumPy数组
|
|
pred = pred.detach().cpu().numpy()
|
|
target = target.detach().cpu().numpy()
|
|
|
|
# 初始化结果
|
|
iou_per_class = np.zeros(num_classes)
|
|
|
|
# 计算每个类别的IoU
|
|
for cls in range(num_classes):
|
|
pred_cls = (pred == cls).astype(np.int8)
|
|
target_cls = (target == cls).astype(np.int8)
|
|
|
|
intersection = np.sum(pred_cls & target_cls)
|
|
union = np.sum(pred_cls | target_cls)
|
|
|
|
# 避免除零错误
|
|
if union == 0:
|
|
iou_per_class[cls] = 0.0
|
|
else:
|
|
iou_per_class[cls] = intersection / union
|
|
|
|
# 计算平均IoU (mIoU)
|
|
miou = np.mean(iou_per_class)
|
|
|
|
# 构建结果字典
|
|
result = {'mIoU': miou}
|
|
for cls in range(num_classes):
|
|
result[f'IoU_class_{cls}'] = iou_per_class[cls]
|
|
|
|
return result
|
|
|
|
def calculate_dice(pred, target, num_classes):
|
|
"""计算Dice系数
|
|
|
|
Args:
|
|
pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits或 (B, H, W) 的类别索引
|
|
target (torch.Tensor): 真实标签,形状为 (B, H, W)
|
|
num_classes (int): 类别数量
|
|
|
|
Returns:
|
|
dict: 每个类别的Dice系数和平均Dice系数
|
|
"""
|
|
if pred.dim() == 4: # (B, C, H, W) logits
|
|
pred = torch.argmax(pred, dim=1) # (B, H, W)
|
|
|
|
# 将张量移动到CPU并转换为NumPy数组
|
|
pred = pred.detach().cpu().numpy()
|
|
target = target.detach().cpu().numpy()
|
|
|
|
# 初始化结果
|
|
dice_per_class = np.zeros(num_classes)
|
|
|
|
# 计算每个类别的Dice系数
|
|
for cls in range(num_classes):
|
|
pred_cls = (pred == cls).astype(np.int8)
|
|
target_cls = (target == cls).astype(np.int8)
|
|
|
|
intersection = 2.0 * np.sum(pred_cls & target_cls)
|
|
sum_areas = np.sum(pred_cls) + np.sum(target_cls)
|
|
|
|
# 避免除零错误
|
|
if sum_areas == 0:
|
|
dice_per_class[cls] = 0.0
|
|
else:
|
|
dice_per_class[cls] = intersection / sum_areas
|
|
|
|
# 计算平均Dice系数
|
|
mean_dice = np.mean(dice_per_class)
|
|
|
|
# 构建结果字典
|
|
result = {'mean_dice': mean_dice}
|
|
for cls in range(num_classes):
|
|
result[f'dice_class_{cls}'] = dice_per_class[cls]
|
|
|
|
return result
|
|
|
|
def calculate_pixel_accuracy(pred, target):
|
|
"""计算像素准确率
|
|
|
|
Args:
|
|
pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits或 (B, H, W) 的类别索引
|
|
target (torch.Tensor): 真实标签,形状为 (B, H, W)
|
|
|
|
Returns:
|
|
float: 像素准确率
|
|
"""
|
|
if pred.dim() == 4: # (B, C, H, W) logits
|
|
pred = torch.argmax(pred, dim=1) # (B, H, W)
|
|
|
|
# 计算准确率
|
|
correct = (pred == target).float().sum()
|
|
total = torch.numel(pred)
|
|
|
|
return (correct / total).item()
|
|
|
|
def calculate_metrics(pred, target, num_classes):
|
|
"""计算多种评估指标
|
|
|
|
Args:
|
|
pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits
|
|
target (torch.Tensor): 真实标签,形状为 (B, H, W)
|
|
num_classes (int): 类别数量
|
|
|
|
Returns:
|
|
dict: 包含多种评估指标的字典
|
|
"""
|
|
# 计算IoU
|
|
iou_metrics = calculate_iou(pred, target, num_classes)
|
|
|
|
# 计算Dice系数
|
|
dice_metrics = calculate_dice(pred, target, num_classes)
|
|
|
|
# 计算像素准确率
|
|
pixel_acc = calculate_pixel_accuracy(pred, target)
|
|
|
|
# 合并所有指标
|
|
metrics = {'pixel_acc': pixel_acc}
|
|
metrics.update(iou_metrics)
|
|
metrics.update(dice_metrics)
|
|
|
|
return metrics |