ai_project_v1/middleware/conver_segementation_mask.py

64 lines
2.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
def _colorize_mask(mask):
    """Map a single-channel class-id mask to a 3-channel BGR colour image.

    Class ids 0-6 receive the fixed colours below; any other id is rendered
    black, the same colour used for the background class.
    """
    # A 256-row lookup table lets one fancy-indexing operation replace a
    # per-class loop while keeping out-of-range ids (7-255) black.
    palette = np.zeros((256, 3), dtype=np.uint8)
    palette[:7] = [
        (0, 0, 0),      # class 0: black (background)
        (0, 0, 255),    # class 1: red
        (0, 255, 0),    # class 2: green
        (255, 0, 0),    # class 3: blue
        (255, 255, 0),  # class 4: cyan
        (255, 0, 255),  # class 5: magenta
        (0, 255, 255),  # class 6: yellow
    ]
    return palette[mask]


def apply_mask_to_image(image_path, mask_path, output_path, alpha=0.5):
    """Overlay a 7-class segmentation mask on an image and save the result.

    The image is read with its original channel layout. Grayscale, 3-channel
    and 4-channel inputs are supported; for 4-channel inputs the original
    alpha channel is copied unchanged into the output. Grayscale inputs are
    expanded to BGR so the coloured overlay can be blended, which means the
    saved result is a 3-channel image in that case.

    The blend is applied to every pixel, including background (class 0)
    pixels, which are mixed with black.

    Args:
        image_path: Path of the original image.
        mask_path: Path of the mask image (read as single channel; pixel
            values 0-6 denote the 7 classes).
        output_path: Path the blended image is written to.
        alpha: Weight of the colour overlay (0-1); the image is weighted
            by ``1 - alpha``.

    Raises:
        ValueError: If the image or the mask cannot be read.
        OSError: If OpenCV reports that the result could not be written.
    """
    # Read the source image, keeping its original channel count.
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise ValueError("无法读取原始图像")

    # Read the mask as a single-channel image.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise ValueError("无法读取掩码图像")

    # Bring the mask to the image size. Nearest-neighbour interpolation keeps
    # class ids intact instead of blending neighbouring ids together.
    height, width = image.shape[:2]
    if mask.shape != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    color_mask = _colorize_mask(mask)

    # Normalise the image to a 3-channel BGR view for blending and remember
    # the alpha channel (if any) so it can be restored afterwards.
    source_alpha = None
    if image.ndim == 2:
        # addWeighted needs both inputs to have the same channel count, so a
        # grayscale image must be expanded before it meets the colour mask.
        bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        bgr = image[:, :, :3]
        source_alpha = image[:, :, 3]
    else:
        bgr = image[:, :, :3]

    result = cv2.addWeighted(bgr, 1 - alpha, color_mask, alpha, 0)

    if source_alpha is not None:
        # Re-attach the untouched original alpha channel.
        result = cv2.cvtColor(result, cv2.COLOR_BGR2BGRA)
        result[:, :, 3] = source_alpha

    # imwrite signals failure through its return value; do not report
    # success unless the file was actually written.
    if not cv2.imwrite(output_path, result):
        raise OSError(f"无法保存结果图像: {output_path}")
    print(f"结果已保存到 {output_path}")