ai_project_v1/CropLand_CD_module/visualize_pil_segmentation_mask.py

528 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# import json
# import os
# import time
#
# import cv2
# import rasterio
# import numpy as np
# from pyproj import Transformer, CRS
#
#
# def convert_to_wgs84(x, y, src_crs):
# """将投影坐标转换为WGS84经纬度增加错误处理"""
# try:
# if not (isinstance(x, (int, float)) and isinstance(y, (int, float))):
# print(f"警告: 坐标值无效 (x={x}, y={y})")
# return None, None
#
# transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
# lon, lat = transformer.transform(x, y)
#
# if not np.isfinite(lon) or not np.isfinite(lat):
# print(f"警告: 坐标转换结果无效 (lon={lon}, lat={lat})")
# return None, None
#
# return lon, lat
# except Exception as e:
# print(f"坐标转换异常: {str(e)}")
# return None, None
#
#
# def visualize_pil_segmentation_mask_opencv(mask_path, tif_path, output_path=None, colormap=cv2.COLORMAP_VIRIDIS, save=True):
# """
# 使用OpenCV实现掩码可视化+边界提取避免Matplotlib后端问题
# 返回: (src_crs, json_result_path, vis_output_path, raw_json_path)
# """
# # 初始化返回值
# src_crs = None
# json_result_path = None
# vis_output_path = None
# raw_json_path = None
#
# try:
# # 1. 读取并验证掩码文件
# if not os.path.exists(mask_path):
# raise FileNotFoundError(f"掩码文件不存在: {mask_path}")
#
# mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# if mask is None:
# raise ValueError(f"无法读取掩码文件(可能已损坏): {mask_path}")
#
# # 2. 可视化处理
# mask_vis = cv2.applyColorMap(mask, colormap)
# contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# cv2.drawContours(mask_vis, contours, -1, (0, 0, 255), 1)
#
# # 3. 设置输出路径
# if output_path is None:
# vis_output_path = os.path.splitext(mask_path)[0] + '_vis_cv2.png'
# else:
# vis_output_path = output_path
#
# # 4. 保存可视化结果
# if save:
# if not cv2.imwrite(vis_output_path, mask_vis):
# raise IOError(f"无法保存可视化结果到: {vis_output_path}")
# print(f"可视化结果已保存到: {vis_output_path}")
#
# # 5. 处理轮廓数据
# instances = []
# for i, contour in enumerate(contours):
# if len(contour) < 1:
# print(f"警告: 轮廓 {i} 为空,跳过")
# continue
#
# # 简化轮廓(可选,减少点数)
# epsilon = 0.001 * cv2.arcLength(contour, True)
# approx_contour = cv2.approxPolyDP(contour, epsilon, True)
#
# # 转换为列表格式JSON 兼容)
# contour_list = []
# for point in approx_contour.squeeze():
# if isinstance(point, np.ndarray) and point.size >= 2:
# contour_list.append(point.tolist())
# elif isinstance(point, (list, tuple)) and len(point) >= 2:
# contour_list.append(list(map(int, point[:2]))) # 确保坐标是整数
# else:
# print(f"警告: 无效的轮廓点格式: {point}")
# continue
#
# if len(contour_list)> 3:
# contour_list.append(contour_list[0]) #多加一个点,构成封闭平面
# instances.append({
# "instance_id": i + 1,
# "contour": contour_list,
# "area": int(cv2.contourArea(contour))
# })
#
# if not instances:
# print("警告: 未检测到任何有效轮廓")
# return None, None, vis_output_path, None
#
# # 6. 保存原始JSON
# raw_json_path = os.path.splitext(mask_path)[0] + '.json'
# try:
# with open(raw_json_path, 'w') as f:
# json.dump(instances, f, indent=2)
# print(f"基础轮廓数据已保存到: {raw_json_path}")
# except Exception as e:
# print(f"警告: 无法保存基础JSON文件 - {str(e)}")
# raw_json_path = None
# # start = time.perf_counter()
# # count=0
# # 7. 处理TIFF坐标转换仅当提供有效tif_path时
# if tif_path and isinstance(tif_path, str) and tif_path.lower().endswith(('.tif', '.tiff')):
# try:
# with rasterio.open(tif_path) as src:
# src_crs = src.crs
# if not src_crs:
# print("警告: TIFF文件缺少坐标系信息跳过坐标转换")
# json_result_path = raw_json_path
# return src_crs, json_result_path, vis_output_path, raw_json_path
#
# transform = src.transform
# band_data = src.read(1)
# height, width = band_data.shape
#
# for instance in instances:
# coord = []
# for point in instance["contour"]:
# try:
# col, row = point[:2] # 提取坐标
# if not (0 <= col < width and 0 <= row < height):
# print(f"警告: 坐标({col}, {row})超出图像范围 ({width}x{height})")
# coord.append(None)
# continue
# # count=count+1
# # 像素坐标 → 投影坐标
# x, y = transform * (col, row)
# lon, lat = convert_to_wgs84(x, y, src_crs)
#
# # 获取像素值
# z = float(band_data[int(row), int(col)])
#
# if lon is not None and lat is not None:
# coord.append([lon, lat, z])
# else:
# coord.append([None, None, z])
# except Exception as e:
# print(f"警告: 处理轮廓点 {point} 时出错 - {str(e)}")
# coord.append(None)
#
# instance["coord"] = [[c for c in coord if c is not None]] # 过滤无效坐标
#
# # 保存带坐标的结果
# json_result_path = os.path.join(
# os.path.dirname(raw_json_path),
# "result_" + os.path.basename(raw_json_path)
# )
#
# try:
# with open(json_result_path, 'w') as f:
# json.dump(instances, f, indent=2, default=lambda obj: float(obj) if isinstance(obj, np.generic) else str(obj))
# print(f"完整结果(含坐标)已保存到: {json_result_path}")
# except Exception as e:
# print(f"错误: 无法保存结果JSON文件 - {str(e)}")
# json_result_path = raw_json_path # 回退到原始JSON
#
# except Exception as e:
# print(f"处理TIFF坐标时发生错误: {str(e)}")
# json_result_path = raw_json_path # 回退到原始JSON
# else:
# json_result_path = raw_json_path # 非TIFF文件直接返回原始JSON
# # end = time.perf_counter()
# # print(f"总计 {count} 点,总计耗时: {end - start:.6f} 秒")
#
# return src_crs, json_result_path, vis_output_path, raw_json_path
#
# except Exception as e:
# print(f"可视化处理过程中发生严重错误: {str(e)}")
# return None, None, vis_output_path, None # 确保返回路径变量
import json
import os
import numpy as np
import cv2
import rasterio
from pyproj import Transformer, CRS
from geographiclib.geodesic import Geodesic # 用于精确计算球面多边形面积
from pyproj import Geod
def calculate_polygon_area(coords):
"""计算球面多边形面积(单位:平方米)"""
if len(coords) < 3:
return 0.0
try:
geod = Geod(ellps="WGS84") # WGS84 椭球体
# 注意: pyproj.Geod 要求坐标顺序为 (lon, lat)
lons, lats = zip(*[(lon, lat) for lon, lat in coords])
area, _ = geod.polygon_area_perimeter(lons, lats)
return abs(area) # 返回绝对值(平方米)
except Exception as e:
print(f"面积计算异常: {str(e)}")
return 0.0
def convert_to_wgs84(x, y, src_crs):
"""将投影坐标转换为WGS84经纬度增加错误处理"""
try:
if not (isinstance(x, (int, float)) and isinstance(y, (int, float))):
print(f"警告: 坐标值无效 (x={x}, y={y})")
return None, None
transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(x, y)
if not np.isfinite(lon) or not np.isfinite(lat):
print(f"警告: 坐标转换结果无效 (lon={lon}, lat={lat})")
return None, None
return lon, lat
except Exception as e:
print(f"坐标转换异常: {str(e)}")
return None, None
#
# def visualize_pil_segmentation_mask_opencv(mask_path, tif_path, output_path=None, colormap=cv2.COLORMAP_VIRIDIS, save=True):
# """
# 修改版:过滤面积 < 10㎡ 的轮廓,并更新可视化结果
# """
# src_crs = None
# json_result_path = None
# vis_output_path = None
# raw_json_path = None
#
# try:
# # 1. 读取掩码文件
# if not os.path.exists(mask_path):
# raise FileNotFoundError(f"掩码文件不存在: {mask_path}")
#
# mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# if mask is None:
# raise ValueError(f"无法读取掩码文件(可能已损坏): {mask_path}")
#
# # 2. 初始化可视化图像(后续会更新)
# mask_vis = cv2.applyColorMap(mask, colormap)
# contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# cv2.drawContours(mask_vis, contours, -1, (0, 0, 255), 1)
#
# # 3. 处理轮廓数据
# valid_instances = []
# filtered_contours = [] # 存储过滤后的轮廓(用于可视化)
#
# for i, contour in enumerate(contours):
# if len(contour) < 3: # 至少需要3个点构成多边形
# print(f"警告: 轮廓 {i} 点数不足,跳过")
# continue
#
# # 简化轮廓
# epsilon = 0.001 * cv2.arcLength(contour, True)
# approx_contour = cv2.approxPolyDP(contour, epsilon, True)
#
# # 转换为列表格式
# contour_list = []
# for point in approx_contour.squeeze():
# if isinstance(point, np.ndarray) and point.size >= 2:
# contour_list.append(point.tolist())
# elif isinstance(point, (list, tuple)) and len(point) >= 2:
# contour_list.append(list(map(int, point[:2])))
# else:
# print(f"警告: 无效的轮廓点格式: {point}")
# continue
#
# if len(contour_list) < 3:
# print(f"警告: 轮廓 {i} 简化后点数不足,跳过")
# continue
#
# # 临时存储当前轮廓(后续可能被过滤)
# temp_instance = {
# "instance_id": i + 1,
# "contour": contour_list,
# "area_pixels": int(cv2.contourArea(contour)),
# "coord": None # 后续填充
# }
#
# # 4. 处理TIFF坐标转换如果提供
# if tif_path and tif_path.lower().endswith(('.tif', '.tiff')):
# try:
# with rasterio.open(tif_path) as src:
# src_crs = src.crs
# if not src_crs:
# print("警告: TIFF文件缺少坐标系信息跳过坐标转换")
# valid_instances.append(temp_instance)
# filtered_contours.append(contour)
# continue
#
# transform = src.transform
# height, width = src.read(1).shape
# coords = []
#
# # 修复后的代码
# valid_coords = []
# for point in contour_list:
# col, row = point[:2]
# # 检查坐标是否在图像范围内
# if not (0 <= col < width and 0 <= row < height):
# print(f"警告: 坐标({col}, {row})超出图像范围")
# continue # 跳过这个点但不中断整个轮廓
#
# try:
# # 转换为投影坐标
# x, y = transform * (col, row)
# # 转换为经纬度
# lon, lat = convert_to_wgs84(x, y, src_crs)
# if lon is not None and lat is not None:
# valid_coords.append([lon, lat])
# except Exception as e:
# print(f"坐标处理错误: {str(e)}")
# continue
#
# # 过滤无效坐标
# valid_coords = [c for c in coords if c is not None]
# if len(valid_coords) < 3:
# print(f"警告: 轮廓 {i} 有效坐标不足,跳过")
# continue
#
# # 计算面积(平方米)
# area_m2 = calculate_polygon_area(valid_coords)
# temp_instance["coord"] = valid_coords
# temp_instance["area_m2"] = area_m2
#
# # 过滤面积 < 10㎡ 的轮廓
# if area_m2 >= 10:
# valid_instances.append(temp_instance)
# filtered_contours.append(contour) # 保留轮廓用于可视化
# else:
# print(f"过滤小面积轮廓: ID={i+1}, 面积={area_m2:.2f}㎡")
#
# except Exception as e:
# print(f"处理TIFF坐标时发生错误: {str(e)}")
# valid_instances.append(temp_instance) # 保留原始数据
# filtered_contours.append(contour)
# else:
# # 非TIFF文件直接保留轮廓
# valid_instances.append(temp_instance)
# filtered_contours.append(contour)
#
# # 5. 更新可视化图像(仅绘制保留的轮廓)
# mask_vis_filtered = cv2.applyColorMap(mask, colormap)
# if filtered_contours:
# cv2.drawContours(mask_vis_filtered, filtered_contours, -1, (0, 0, 255), 1)
#
# # 6. 保存可视化结果
# if output_path is None:
# vis_output_path = os.path.splitext(mask_path)[0] + '_vis_cv2.png'
# else:
# vis_output_path = output_path
#
# if save:
# if not cv2.imwrite(vis_output_path, mask_vis_filtered):
# raise IOError(f"无法保存可视化结果到: {vis_output_path}")
# print(f"可视化结果已保存到: {vis_output_path}")
#
# # 7. 保存JSON数据
# raw_json_path = os.path.splitext(mask_path)[0] + '.json'
# with open(raw_json_path, 'w') as f:
# json.dump(valid_instances, f, indent=2)
# print(f"基础轮廓数据已保存到: {raw_json_path}")
#
# # 8. 返回结果src_crs 可能为None
# return src_crs, raw_json_path, vis_output_path, raw_json_path
#
# except Exception as e:
# print(f"可视化处理过程中发生严重错误: {str(e)}")
# return None, None, vis_output_path, None
def visualize_pil_segmentation_mask_opencv(mask_path, tif_path, output_path=None, colormap=cv2.COLORMAP_VIRIDIS, save=True):
"""
修改版:过滤面积 < 10㎡ 的轮廓,并更新可视化结果
"""
src_crs = None
json_result_path = None
vis_output_path = None
raw_json_path = None
try:
# 1. 读取掩码文件(同上)
if not os.path.exists(mask_path):
raise FileNotFoundError(f"掩码文件不存在: {mask_path}")
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"无法读取掩码文件(可能已损坏): {mask_path}")
# 2. 初始化可视化图像
mask_vis = cv2.applyColorMap(mask, colormap)
contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(mask_vis, contours, -1, (0, 0, 255), 1)
# 3. 处理轮廓数据
valid_instances = []
filtered_contours = []
for i, contour in enumerate(contours):
if len(contour) < 3:
print(f"警告: 轮廓 {i} 点数不足,跳过")
continue
# 简化轮廓
epsilon = 0.001 * cv2.arcLength(contour, True)
approx_contour = cv2.approxPolyDP(contour, epsilon, True)
# 转换为列表格式并确保闭合
contour_list = []
for point in approx_contour.squeeze():
if isinstance(point, np.ndarray) and point.size >= 2:
contour_list.append(point.tolist())
elif isinstance(point, (list, tuple)) and len(point) >= 2:
contour_list.append(list(map(float, point[:2]))) # 使用float保持精度
if len(contour_list) < 3:
print(f"警告: 轮廓 {i} 简化后点数不足,跳过")
continue
# 闭合多边形
# if contour_list[0] != contour_list[-1]:
# contour_list.append(contour_list[0])
contour_list.append(contour_list[0])
temp_instance = {
"instance_id": i + 1,
"contour": contour_list,
"area_pixels": int(cv2.contourArea(contour)),
"coord": None
}
# 4. 处理TIFF坐标转换
if tif_path and tif_path.lower().endswith(('.tif', '.tiff')):
try:
with rasterio.open(tif_path) as src:
src_crs = src.crs
if not src_crs:
print("警告: TIFF文件缺少坐标系信息")
valid_instances.append(temp_instance)
filtered_contours.append(contour)
continue
transform = src.transform
height, width = src.read(1).shape
all_coords = []
valid_coords = []
for point in contour_list[:-1]: # 不处理重复的闭合点
col, row = point[:2]
if not (0 <= col < width and 0 <= row < height):
print(f"警告: 坐标({col}, {row})超出图像范围")
continue
try:
x, y = transform * (col, row)
lon, lat = convert_to_wgs84(x, y, src_crs)
if lon is not None and lat is not None:
all_coords.append([lon, lat])
except Exception as e:
print(f"坐标处理错误: {str(e)}")
continue
# 必须至少有3个有效坐标才能构成多边形
if len(all_coords) >= 3:
# 计算面积(平方米)
area_m2 = calculate_polygon_area(all_coords)
# # 闭合多边形
all_coords.append(all_coords[0])
# 适配geojson的三层格式
temp_instance["coord"] = [all_coords]
temp_instance["area_m2"] = area_m2
# 过滤面积 < 10㎡ 的轮廓
if area_m2 >= 10:
valid_instances.append(temp_instance)
filtered_contours.append(contour)
else:
print(f"过滤小面积轮廓: ID={i+1}, 面积={area_m2:.2f}")
else:
print(f"警告: 轮廓 {i} 有效坐标不足3个")
except Exception as e:
print(f"处理TIFF坐标时发生错误: {str(e)}")
valid_instances.append(temp_instance)
filtered_contours.append(contour)
else:
valid_instances.append(temp_instance)
filtered_contours.append(contour)
# 5. 更新可视化图像
mask_vis_filtered = cv2.applyColorMap(mask, colormap)
if filtered_contours:
cv2.drawContours(mask_vis_filtered, filtered_contours, -1, (0, 0, 255), 1)
# ...(剩余的保存和返回逻辑保持不变)
# 6. 保存可视化结果
if output_path is None:
vis_output_path = os.path.splitext(mask_path)[0] + '_vis_cv2.png'
else:
vis_output_path = output_path
if save:
if not cv2.imwrite(vis_output_path, mask_vis_filtered):
raise IOError(f"无法保存可视化结果到: {vis_output_path}")
print(f"可视化结果已保存到: {vis_output_path}")
# 7. 保存JSON数据
raw_json_path = os.path.splitext(mask_path)[0] + '.json'
with open(raw_json_path, 'w') as f:
json.dump(valid_instances, f, indent=2)
print(f"基础轮廓数据已保存到: {raw_json_path}")
# 8. 返回结果src_crs 可能为None
return src_crs, raw_json_path, vis_output_path, raw_json_path
except Exception as e:
print(f"可视化处理过程中发生严重错误: {str(e)}")
return None, None, vis_output_path, None