import torch import numpy as np def calculate_iou(pred, target, num_classes): """计算IoU (Intersection over Union) Args: pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits或 (B, H, W) 的类别索引 target (torch.Tensor): 真实标签,形状为 (B, H, W) num_classes (int): 类别数量 Returns: dict: 每个类别的IoU和平均IoU """ if pred.dim() == 4: # (B, C, H, W) logits pred = torch.argmax(pred, dim=1) # (B, H, W) # 将张量移动到CPU并转换为NumPy数组 pred = pred.detach().cpu().numpy() target = target.detach().cpu().numpy() # 初始化结果 iou_per_class = np.zeros(num_classes) # 计算每个类别的IoU for cls in range(num_classes): pred_cls = (pred == cls).astype(np.int8) target_cls = (target == cls).astype(np.int8) intersection = np.sum(pred_cls & target_cls) union = np.sum(pred_cls | target_cls) # 避免除零错误 if union == 0: iou_per_class[cls] = 0.0 else: iou_per_class[cls] = intersection / union # 计算平均IoU (mIoU) miou = np.mean(iou_per_class) # 构建结果字典 result = {'mIoU': miou} for cls in range(num_classes): result[f'IoU_class_{cls}'] = iou_per_class[cls] return result def calculate_dice(pred, target, num_classes): """计算Dice系数 Args: pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits或 (B, H, W) 的类别索引 target (torch.Tensor): 真实标签,形状为 (B, H, W) num_classes (int): 类别数量 Returns: dict: 每个类别的Dice系数和平均Dice系数 """ if pred.dim() == 4: # (B, C, H, W) logits pred = torch.argmax(pred, dim=1) # (B, H, W) # 将张量移动到CPU并转换为NumPy数组 pred = pred.detach().cpu().numpy() target = target.detach().cpu().numpy() # 初始化结果 dice_per_class = np.zeros(num_classes) # 计算每个类别的Dice系数 for cls in range(num_classes): pred_cls = (pred == cls).astype(np.int8) target_cls = (target == cls).astype(np.int8) intersection = 2.0 * np.sum(pred_cls & target_cls) sum_areas = np.sum(pred_cls) + np.sum(target_cls) # 避免除零错误 if sum_areas == 0: dice_per_class[cls] = 0.0 else: dice_per_class[cls] = intersection / sum_areas # 计算平均Dice系数 mean_dice = np.mean(dice_per_class) # 构建结果字典 result = {'mean_dice': mean_dice} for cls in range(num_classes): result[f'dice_class_{cls}'] = dice_per_class[cls] return result def calculate_pixel_accuracy(pred, target): """计算像素准确率 Args: pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits或 (B, H, W) 的类别索引 target (torch.Tensor): 真实标签,形状为 (B, H, W) Returns: float: 像素准确率 """ if pred.dim() == 4: # (B, C, H, W) logits pred = torch.argmax(pred, dim=1) # (B, H, W) # 计算准确率 correct = (pred == target).float().sum() total = torch.numel(pred) return (correct / total).item() def calculate_metrics(pred, target, num_classes): """计算多种评估指标 Args: pred (torch.Tensor): 预测结果,形状为 (B, C, H, W) 的logits target (torch.Tensor): 真实标签,形状为 (B, H, W) num_classes (int): 类别数量 Returns: dict: 包含多种评估指标的字典 """ # 计算IoU iou_metrics = calculate_iou(pred, target, num_classes) # 计算Dice系数 dice_metrics = calculate_dice(pred, target, num_classes) # 计算像素准确率 pixel_acc = calculate_pixel_accuracy(pred, target) # 合并所有指标 metrics = {'pixel_acc': pixel_acc} metrics.update(iou_metrics) metrics.update(dice_metrics) return metrics