ai_project_v1/CropLand_CD_module/visualize_pil_segmentation_mask.py

528 lines
22 KiB
Python
Raw Normal View History

#
# import json
# import os
# import time
#
# import cv2
# import rasterio
# import numpy as np
# from pyproj import Transformer, CRS
#
#
# def convert_to_wgs84(x, y, src_crs):
# """将投影坐标转换为WGS84经纬度增加错误处理"""
# try:
# if not (isinstance(x, (int, float)) and isinstance(y, (int, float))):
# print(f"警告: 坐标值无效 (x={x}, y={y})")
# return None, None
#
# transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
# lon, lat = transformer.transform(x, y)
#
# if not np.isfinite(lon) or not np.isfinite(lat):
# print(f"警告: 坐标转换结果无效 (lon={lon}, lat={lat})")
# return None, None
#
# return lon, lat
# except Exception as e:
# print(f"坐标转换异常: {str(e)}")
# return None, None
#
#
# def visualize_pil_segmentation_mask_opencv(mask_path, tif_path, output_path=None, colormap=cv2.COLORMAP_VIRIDIS, save=True):
# """
# 使用OpenCV实现掩码可视化+边界提取避免Matplotlib后端问题
# 返回: (src_crs, json_result_path, vis_output_path, raw_json_path)
# """
# # 初始化返回值
# src_crs = None
# json_result_path = None
# vis_output_path = None
# raw_json_path = None
#
# try:
# # 1. 读取并验证掩码文件
# if not os.path.exists(mask_path):
# raise FileNotFoundError(f"掩码文件不存在: {mask_path}")
#
# mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# if mask is None:
# raise ValueError(f"无法读取掩码文件(可能已损坏): {mask_path}")
#
# # 2. 可视化处理
# mask_vis = cv2.applyColorMap(mask, colormap)
# contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# cv2.drawContours(mask_vis, contours, -1, (0, 0, 255), 1)
#
# # 3. 设置输出路径
# if output_path is None:
# vis_output_path = os.path.splitext(mask_path)[0] + '_vis_cv2.png'
# else:
# vis_output_path = output_path
#
# # 4. 保存可视化结果
# if save:
# if not cv2.imwrite(vis_output_path, mask_vis):
# raise IOError(f"无法保存可视化结果到: {vis_output_path}")
# print(f"可视化结果已保存到: {vis_output_path}")
#
# # 5. 处理轮廓数据
# instances = []
# for i, contour in enumerate(contours):
# if len(contour) < 1:
# print(f"警告: 轮廓 {i} 为空,跳过")
# continue
#
# # 简化轮廓(可选,减少点数)
# epsilon = 0.001 * cv2.arcLength(contour, True)
# approx_contour = cv2.approxPolyDP(contour, epsilon, True)
#
# # 转换为列表格式JSON 兼容)
# contour_list = []
# for point in approx_contour.squeeze():
# if isinstance(point, np.ndarray) and point.size >= 2:
# contour_list.append(point.tolist())
# elif isinstance(point, (list, tuple)) and len(point) >= 2:
# contour_list.append(list(map(int, point[:2]))) # 确保坐标是整数
# else:
# print(f"警告: 无效的轮廓点格式: {point}")
# continue
#
# if len(contour_list)> 3:
# contour_list.append(contour_list[0]) #多加一个点,构成封闭平面
# instances.append({
# "instance_id": i + 1,
# "contour": contour_list,
# "area": int(cv2.contourArea(contour))
# })
#
# if not instances:
# print("警告: 未检测到任何有效轮廓")
# return None, None, vis_output_path, None
#
# # 6. 保存原始JSON
# raw_json_path = os.path.splitext(mask_path)[0] + '.json'
# try:
# with open(raw_json_path, 'w') as f:
# json.dump(instances, f, indent=2)
# print(f"基础轮廓数据已保存到: {raw_json_path}")
# except Exception as e:
# print(f"警告: 无法保存基础JSON文件 - {str(e)}")
# raw_json_path = None
# # start = time.perf_counter()
# # count=0
# # 7. 处理TIFF坐标转换仅当提供有效tif_path时
# if tif_path and isinstance(tif_path, str) and tif_path.lower().endswith(('.tif', '.tiff')):
# try:
# with rasterio.open(tif_path) as src:
# src_crs = src.crs
# if not src_crs:
# print("警告: TIFF文件缺少坐标系信息跳过坐标转换")
# json_result_path = raw_json_path
# return src_crs, json_result_path, vis_output_path, raw_json_path
#
# transform = src.transform
# band_data = src.read(1)
# height, width = band_data.shape
#
# for instance in instances:
# coord = []
# for point in instance["contour"]:
# try:
# col, row = point[:2] # 提取坐标
# if not (0 <= col < width and 0 <= row < height):
# print(f"警告: 坐标({col}, {row})超出图像范围 ({width}x{height})")
# coord.append(None)
# continue
# # count=count+1
# # 像素坐标 → 投影坐标
# x, y = transform * (col, row)
# lon, lat = convert_to_wgs84(x, y, src_crs)
#
# # 获取像素值
# z = float(band_data[int(row), int(col)])
#
# if lon is not None and lat is not None:
# coord.append([lon, lat, z])
# else:
# coord.append([None, None, z])
# except Exception as e:
# print(f"警告: 处理轮廓点 {point} 时出错 - {str(e)}")
# coord.append(None)
#
# instance["coord"] = [[c for c in coord if c is not None]] # 过滤无效坐标
#
# # 保存带坐标的结果
# json_result_path = os.path.join(
# os.path.dirname(raw_json_path),
# "result_" + os.path.basename(raw_json_path)
# )
#
# try:
# with open(json_result_path, 'w') as f:
# json.dump(instances, f, indent=2, default=lambda obj: float(obj) if isinstance(obj, np.generic) else str(obj))
# print(f"完整结果(含坐标)已保存到: {json_result_path}")
# except Exception as e:
# print(f"错误: 无法保存结果JSON文件 - {str(e)}")
# json_result_path = raw_json_path # 回退到原始JSON
#
# except Exception as e:
# print(f"处理TIFF坐标时发生错误: {str(e)}")
# json_result_path = raw_json_path # 回退到原始JSON
# else:
# json_result_path = raw_json_path # 非TIFF文件直接返回原始JSON
# # end = time.perf_counter()
# # print(f"总计 {count} 点,总计耗时: {end - start:.6f} 秒")
#
# return src_crs, json_result_path, vis_output_path, raw_json_path
#
# except Exception as e:
# print(f"可视化处理过程中发生严重错误: {str(e)}")
# return None, None, vis_output_path, None # 确保返回路径变量
import json
import os
import numpy as np
import cv2
import rasterio
from pyproj import Transformer, CRS
from geographiclib.geodesic import Geodesic # 用于精确计算球面多边形面积
from pyproj import Geod
def calculate_polygon_area(coords):
"""计算球面多边形面积(单位:平方米)"""
if len(coords) < 3:
return 0.0
try:
geod = Geod(ellps="WGS84") # WGS84 椭球体
# 注意: pyproj.Geod 要求坐标顺序为 (lon, lat)
lons, lats = zip(*[(lon, lat) for lon, lat in coords])
area, _ = geod.polygon_area_perimeter(lons, lats)
return abs(area) # 返回绝对值(平方米)
except Exception as e:
print(f"面积计算异常: {str(e)}")
return 0.0
def convert_to_wgs84(x, y, src_crs):
"""将投影坐标转换为WGS84经纬度增加错误处理"""
try:
if not (isinstance(x, (int, float)) and isinstance(y, (int, float))):
print(f"警告: 坐标值无效 (x={x}, y={y})")
return None, None
transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(x, y)
if not np.isfinite(lon) or not np.isfinite(lat):
print(f"警告: 坐标转换结果无效 (lon={lon}, lat={lat})")
return None, None
return lon, lat
except Exception as e:
print(f"坐标转换异常: {str(e)}")
return None, None
#
# def visualize_pil_segmentation_mask_opencv(mask_path, tif_path, output_path=None, colormap=cv2.COLORMAP_VIRIDIS, save=True):
# """
# 修改版:过滤面积 < 10㎡ 的轮廓,并更新可视化结果
# """
# src_crs = None
# json_result_path = None
# vis_output_path = None
# raw_json_path = None
#
# try:
# # 1. 读取掩码文件
# if not os.path.exists(mask_path):
# raise FileNotFoundError(f"掩码文件不存在: {mask_path}")
#
# mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# if mask is None:
# raise ValueError(f"无法读取掩码文件(可能已损坏): {mask_path}")
#
# # 2. 初始化可视化图像(后续会更新)
# mask_vis = cv2.applyColorMap(mask, colormap)
# contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# cv2.drawContours(mask_vis, contours, -1, (0, 0, 255), 1)
#
# # 3. 处理轮廓数据
# valid_instances = []
# filtered_contours = [] # 存储过滤后的轮廓(用于可视化)
#
# for i, contour in enumerate(contours):
# if len(contour) < 3: # 至少需要3个点构成多边形
# print(f"警告: 轮廓 {i} 点数不足,跳过")
# continue
#
# # 简化轮廓
# epsilon = 0.001 * cv2.arcLength(contour, True)
# approx_contour = cv2.approxPolyDP(contour, epsilon, True)
#
# # 转换为列表格式
# contour_list = []
# for point in approx_contour.squeeze():
# if isinstance(point, np.ndarray) and point.size >= 2:
# contour_list.append(point.tolist())
# elif isinstance(point, (list, tuple)) and len(point) >= 2:
# contour_list.append(list(map(int, point[:2])))
# else:
# print(f"警告: 无效的轮廓点格式: {point}")
# continue
#
# if len(contour_list) < 3:
# print(f"警告: 轮廓 {i} 简化后点数不足,跳过")
# continue
#
# # 临时存储当前轮廓(后续可能被过滤)
# temp_instance = {
# "instance_id": i + 1,
# "contour": contour_list,
# "area_pixels": int(cv2.contourArea(contour)),
# "coord": None # 后续填充
# }
#
# # 4. 处理TIFF坐标转换如果提供
# if tif_path and tif_path.lower().endswith(('.tif', '.tiff')):
# try:
# with rasterio.open(tif_path) as src:
# src_crs = src.crs
# if not src_crs:
# print("警告: TIFF文件缺少坐标系信息跳过坐标转换")
# valid_instances.append(temp_instance)
# filtered_contours.append(contour)
# continue
#
# transform = src.transform
# height, width = src.read(1).shape
# coords = []
#
# # 修复后的代码
# valid_coords = []
# for point in contour_list:
# col, row = point[:2]
# # 检查坐标是否在图像范围内
# if not (0 <= col < width and 0 <= row < height):
# print(f"警告: 坐标({col}, {row})超出图像范围")
# continue # 跳过这个点但不中断整个轮廓
#
# try:
# # 转换为投影坐标
# x, y = transform * (col, row)
# # 转换为经纬度
# lon, lat = convert_to_wgs84(x, y, src_crs)
# if lon is not None and lat is not None:
# valid_coords.append([lon, lat])
# except Exception as e:
# print(f"坐标处理错误: {str(e)}")
# continue
#
# # 过滤无效坐标
# valid_coords = [c for c in coords if c is not None]
# if len(valid_coords) < 3:
# print(f"警告: 轮廓 {i} 有效坐标不足,跳过")
# continue
#
# # 计算面积(平方米)
# area_m2 = calculate_polygon_area(valid_coords)
# temp_instance["coord"] = valid_coords
# temp_instance["area_m2"] = area_m2
#
# # 过滤面积 < 10㎡ 的轮廓
# if area_m2 >= 10:
# valid_instances.append(temp_instance)
# filtered_contours.append(contour) # 保留轮廓用于可视化
# else:
# print(f"过滤小面积轮廓: ID={i+1}, 面积={area_m2:.2f}㎡")
#
# except Exception as e:
# print(f"处理TIFF坐标时发生错误: {str(e)}")
# valid_instances.append(temp_instance) # 保留原始数据
# filtered_contours.append(contour)
# else:
# # 非TIFF文件直接保留轮廓
# valid_instances.append(temp_instance)
# filtered_contours.append(contour)
#
# # 5. 更新可视化图像(仅绘制保留的轮廓)
# mask_vis_filtered = cv2.applyColorMap(mask, colormap)
# if filtered_contours:
# cv2.drawContours(mask_vis_filtered, filtered_contours, -1, (0, 0, 255), 1)
#
# # 6. 保存可视化结果
# if output_path is None:
# vis_output_path = os.path.splitext(mask_path)[0] + '_vis_cv2.png'
# else:
# vis_output_path = output_path
#
# if save:
# if not cv2.imwrite(vis_output_path, mask_vis_filtered):
# raise IOError(f"无法保存可视化结果到: {vis_output_path}")
# print(f"可视化结果已保存到: {vis_output_path}")
#
# # 7. 保存JSON数据
# raw_json_path = os.path.splitext(mask_path)[0] + '.json'
# with open(raw_json_path, 'w') as f:
# json.dump(valid_instances, f, indent=2)
# print(f"基础轮廓数据已保存到: {raw_json_path}")
#
# # 8. 返回结果src_crs 可能为None
# return src_crs, raw_json_path, vis_output_path, raw_json_path
#
# except Exception as e:
# print(f"可视化处理过程中发生严重错误: {str(e)}")
# return None, None, vis_output_path, None
def visualize_pil_segmentation_mask_opencv(mask_path, tif_path, output_path=None, colormap=cv2.COLORMAP_VIRIDIS, save=True):
"""
修改版过滤面积 < 10 的轮廓并更新可视化结果
"""
src_crs = None
json_result_path = None
vis_output_path = None
raw_json_path = None
try:
# 1. 读取掩码文件(同上)
if not os.path.exists(mask_path):
raise FileNotFoundError(f"掩码文件不存在: {mask_path}")
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"无法读取掩码文件(可能已损坏): {mask_path}")
# 2. 初始化可视化图像
mask_vis = cv2.applyColorMap(mask, colormap)
contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(mask_vis, contours, -1, (0, 0, 255), 1)
# 3. 处理轮廓数据
valid_instances = []
filtered_contours = []
for i, contour in enumerate(contours):
if len(contour) < 3:
print(f"警告: 轮廓 {i} 点数不足,跳过")
continue
# 简化轮廓
epsilon = 0.001 * cv2.arcLength(contour, True)
approx_contour = cv2.approxPolyDP(contour, epsilon, True)
# 转换为列表格式并确保闭合
contour_list = []
for point in approx_contour.squeeze():
if isinstance(point, np.ndarray) and point.size >= 2:
contour_list.append(point.tolist())
elif isinstance(point, (list, tuple)) and len(point) >= 2:
contour_list.append(list(map(float, point[:2]))) # 使用float保持精度
if len(contour_list) < 3:
print(f"警告: 轮廓 {i} 简化后点数不足,跳过")
continue
# 闭合多边形
# if contour_list[0] != contour_list[-1]:
# contour_list.append(contour_list[0])
contour_list.append(contour_list[0])
temp_instance = {
"instance_id": i + 1,
"contour": contour_list,
"area_pixels": int(cv2.contourArea(contour)),
"coord": None
}
# 4. 处理TIFF坐标转换
if tif_path and tif_path.lower().endswith(('.tif', '.tiff')):
try:
with rasterio.open(tif_path) as src:
src_crs = src.crs
if not src_crs:
print("警告: TIFF文件缺少坐标系信息")
valid_instances.append(temp_instance)
filtered_contours.append(contour)
continue
transform = src.transform
height, width = src.read(1).shape
all_coords = []
valid_coords = []
for point in contour_list[:-1]: # 不处理重复的闭合点
col, row = point[:2]
if not (0 <= col < width and 0 <= row < height):
print(f"警告: 坐标({col}, {row})超出图像范围")
continue
try:
x, y = transform * (col, row)
lon, lat = convert_to_wgs84(x, y, src_crs)
if lon is not None and lat is not None:
all_coords.append([lon, lat])
except Exception as e:
print(f"坐标处理错误: {str(e)}")
continue
# 必须至少有3个有效坐标才能构成多边形
if len(all_coords) >= 3:
# 计算面积(平方米)
area_m2 = calculate_polygon_area(all_coords)
# # 闭合多边形
all_coords.append(all_coords[0])
# 适配geojson的三层格式
temp_instance["coord"] = [all_coords]
temp_instance["area_m2"] = area_m2
# 过滤面积 < 10㎡ 的轮廓
if area_m2 >= 10:
valid_instances.append(temp_instance)
filtered_contours.append(contour)
else:
print(f"过滤小面积轮廓: ID={i+1}, 面积={area_m2:.2f}")
else:
print(f"警告: 轮廓 {i} 有效坐标不足3个")
except Exception as e:
print(f"处理TIFF坐标时发生错误: {str(e)}")
valid_instances.append(temp_instance)
filtered_contours.append(contour)
else:
valid_instances.append(temp_instance)
filtered_contours.append(contour)
# 5. 更新可视化图像
mask_vis_filtered = cv2.applyColorMap(mask, colormap)
if filtered_contours:
cv2.drawContours(mask_vis_filtered, filtered_contours, -1, (0, 0, 255), 1)
# ...(剩余的保存和返回逻辑保持不变)
# 6. 保存可视化结果
if output_path is None:
vis_output_path = os.path.splitext(mask_path)[0] + '_vis_cv2.png'
else:
vis_output_path = output_path
if save:
if not cv2.imwrite(vis_output_path, mask_vis_filtered):
raise IOError(f"无法保存可视化结果到: {vis_output_path}")
print(f"可视化结果已保存到: {vis_output_path}")
# 7. 保存JSON数据
raw_json_path = os.path.splitext(mask_path)[0] + '.json'
with open(raw_json_path, 'w') as f:
json.dump(valid_instances, f, indent=2)
print(f"基础轮廓数据已保存到: {raw_json_path}")
# 8. 返回结果src_crs 可能为None
return src_crs, raw_json_path, vis_output_path, raw_json_path
except Exception as e:
print(f"可视化处理过程中发生严重错误: {str(e)}")
return None, None, vis_output_path, None