389 lines
15 KiB
Python
Raw Normal View History

import os
import cv2
import numpy as np
import rasterio
import geopandas as gpd
from shapely.geometry import Polygon
from rasterio.features import shapes
from scipy.ndimage import binary_opening, binary_closing, binary_fill_holes, label
from tqdm import tqdm
# class PredictionVisualizer:
# def __init__(self, image_path, pred_mask_path):
# self.image_path = image_path
# self.pred_mask_path = pred_mask_path
# # 检查输入图像的地理参考信息
# with rasterio.open(image_path) as src:
# self.has_georeference = src.transform.is_identity == False
# if not self.has_georeference:
# print("警告:输入图像没有地理参考信息,将使用像素坐标系统。")
# self.transform = src.transform
# self.crs = src.crs
# self.width = src.width
# self.height = src.height
# def post_process_mask(self, mask, min_area=100, kernel_size=(5,5)):
# mask = mask.astype(np.uint8)
# kernel = np.ones(kernel_size)
# mask = binary_closing(mask, structure=kernel)
# mask = binary_opening(mask, structure=np.ones((3, 3)))
# mask = binary_fill_holes(mask)
# labeled_array, num_features = label(mask)
# for i in range(1, num_features + 1):
# component = (labeled_array == i)
# if np.sum(component) < min_area:
# mask[component] = 0
# return mask.astype(np.uint8)
# def draw_contours_on_image(self, save_path=None, color=(0, 0, 255), thickness=2, approx_method=cv2.CHAIN_APPROX_TC89_KCOS):
# with rasterio.open(self.image_path) as src:
# image = src.read([1, 2, 3]).transpose(1, 2, 0)
# image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# pred_mask = cv2.imread(self.pred_mask_path, cv2.IMREAD_GRAYSCALE)
# _, binary = cv2.threshold(pred_mask, 127, 255, cv2.THRESH_BINARY)
# processed_mask = self.post_process_mask(binary, min_area=10)
# contours, _ = cv2.findContours(processed_mask, cv2.RETR_EXTERNAL, approx_method)
# simplified_contours = []
# for cnt in contours:
# if len(cnt) >= 3:
# epsilon = 0.001 * cv2.arcLength(cnt, True) if self.has_georeference else 0.002 * cv2.arcLength(cnt, True)
# approx = cv2.approxPolyDP(cnt, epsilon, True)
# if len(approx) >= 3:
# simplified_contours.append(approx.astype(np.int32))
# result = image.copy()
# cv2.drawContours(result, simplified_contours, -1, color, thickness)
# if save_path:
# cv2.imwrite(save_path, result)
# return result, simplified_contours
# def export_contours_to_shapefile(self, contours, output_path):
# if not self.has_georeference:
# print("当前图像没有地理参考信息,无法导出地理坐标。")
# return None
# polygons = []
# for cnt in contours:
# coords = []
# for point in cnt[:, 0, :]:
# x_pixel, y_pixel = point[0], point[1]
# x_geo, y_geo = rasterio.transform.xy(self.transform, y_pixel, x_pixel)
# coords.append((x_geo, y_geo))
# if coords[0] != coords[-1]:
# coords.append(coords[0])
# poly = Polygon(coords)
# if poly.is_valid:
# polygons.append(poly)
# if not polygons:
# print(f"未生成有效地理轮廓:{output_path}")
# return None
# gdf = gpd.GeoDataFrame({'geometry': polygons}, crs=self.crs)
# gdf.to_file(output_path)
# print(f"✅ 已保存地理红框到 {output_path}")
# return output_path
# def process_all_images(image1_dir, image2_dir, pred_mask_dir, save_dir, shp_dir):
# os.makedirs(save_dir, exist_ok=True)
# os.makedirs(shp_dir, exist_ok=True)
# mask_files = sorted([f for f in os.listdir(pred_mask_dir) if f.endswith('.png')])
# for mask_file in tqdm(mask_files, desc="处理影像中"):
# base_name = os.path.splitext(mask_file)[0]
# image1_path = os.path.join(image1_dir, base_name + '.png')
# image2_path = os.path.join(image2_dir, base_name + '.png')
# mask_path = os.path.join(pred_mask_dir, mask_file)
# save_path_img1 = os.path.join(save_dir, f'{base_name}_image1_red.png')
# save_path_img2 = os.path.join(save_dir, f'{base_name}_image2_red.png')
# shp_path = os.path.join(shp_dir, f'{base_name}.shp')
# if not os.path.exists(image1_path) or not os.path.exists(image2_path):
# print(f"[跳过] 缺失影像:{base_name}")
# continue
# # 1. 生成轮廓 + 导出shp基于image1
# visualizer = PredictionVisualizer(image1_path, mask_path)
# _, contours = visualizer.draw_contours_on_image(save_path=save_path_img1)
# visualizer.export_contours_to_shapefile(contours, shp_path)
# # 2. image2 也绘制红框(但不导出地理坐标)
# visualizer2 = PredictionVisualizer(image2_path, mask_path)
# visualizer2.draw_contours_on_image(save_path=save_path_img2)
# # === 设置路径并运行 ===
# if __name__ == '__main__':
# image1_dir = r'/media/data0/HL/2025-7-6/val/image1'
# image2_dir = r'/media/data0/HL/2025-7-6/val/image2'
# pred_mask_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save'
# save_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_red'
# shp_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_shp'
# process_all_images(image1_dir, image2_dir, pred_mask_dir, save_dir, shp_dir)
class PredictionVisualizer:
    """Draw simplified outlines of a binary prediction mask on top of an image.

    Args:
        image_path: Raster opened with rasterio. Bands 1-3 are read as RGB;
            rasters with fewer than three bands are drawn from band 1 in
            grayscale.
        pred_mask_path: Mask image read with ``cv2.imread`` in grayscale mode;
            pixels with value > 127 are treated as foreground.

    Attributes:
        transform, crs, width, height: Copied from the raster at ``image_path``.
        has_georeference: True when the raster has a CRS and a non-identity
            affine transform.
    """

    def __init__(self, image_path, pred_mask_path):
        self.image_path = image_path
        self.pred_mask_path = pred_mask_path
        with rasterio.open(image_path) as src:
            self.transform = src.transform
            self.crs = src.crs
            self.width = src.width
            self.height = src.height
        self.has_georeference = (
            self.crs is not None and self.transform != rasterio.Affine.identity()
        )

    def post_process_mask(self, mask, min_area=100, kernel_size=(5, 5)):
        """Clean a binary mask with morphology and drop small components.

        Steps: closing with an all-ones ``kernel_size`` element, opening with a
        3x3 element, hole filling, then removal of connected components
        (scipy ``label`` default connectivity) with fewer than ``min_area``
        pixels.

        Args:
            mask: 2-D array; any value that is non-zero after casting to
                uint8 is foreground.
            min_area: Components with fewer pixels than this are removed.
            kernel_size: (rows, cols) shape of the closing structuring element.

        Returns:
            uint8 array of the same shape containing 0/1.
        """
        mask = mask.astype(np.uint8)
        kh, kw = kernel_size
        # binary_closing erodes against an implicit zero border, which would
        # shave regions touching the image edge. Padding first makes the
        # closing behave as on an unbounded plane; the result is cropped back.
        padded = np.pad(mask, ((kh, kh), (kw, kw)), mode='constant')
        closed = binary_closing(padded, structure=np.ones(kernel_size))
        mask = closed[kh:kh + mask.shape[0], kw:kw + mask.shape[1]]
        mask = binary_opening(mask, structure=np.ones((3, 3)))
        mask = binary_fill_holes(mask)
        labeled_array, num_features = label(mask)
        if num_features:
            # Pixel count per label in one pass; label 0 is the background.
            sizes = np.bincount(labeled_array.ravel())
            too_small = sizes < min_area
            too_small[0] = False
            mask[too_small[labeled_array]] = False
        return mask.astype(np.uint8)

    def draw_contours_on_image(self, save_path=None, color=(0, 0, 255), thickness=2,
                               approx_method=cv2.CHAIN_APPROX_TC89_KCOS):
        """Extract outlines from the mask and draw them on the image.

        Args:
            save_path: If truthy, the annotated image is written here with
                ``cv2.imwrite``.
            color: Line colour in OpenCV BGR order (default red).
            thickness: Line thickness in pixels.
            approx_method: Contour approximation flag for ``cv2.findContours``.

        Returns:
            (annotated BGR image, list of simplified int32 contours with at
            least three vertices each).

        Raises:
            FileNotFoundError: If the mask cannot be read by OpenCV.
        """
        with rasterio.open(self.image_path) as src:
            if src.count >= 3:
                image = src.read([1, 2, 3]).transpose(1, 2, 0)
            else:
                # Fewer than three bands: replicate band 1 into a 3-channel
                # image so a coloured outline can be drawn on it.
                image = np.repeat(src.read(1)[:, :, np.newaxis], 3, axis=2)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        pred_mask = cv2.imread(self.pred_mask_path, cv2.IMREAD_GRAYSCALE)
        if pred_mask is None:
            raise FileNotFoundError(f"无法读取掩膜图像: {self.pred_mask_path}")
        _, binary = cv2.threshold(pred_mask, 127, 255, cv2.THRESH_BINARY)
        processed_mask = self.post_process_mask(binary, min_area=10)
        contours, _ = cv2.findContours(processed_mask, cv2.RETR_EXTERNAL, approx_method)

        # Douglas-Peucker tolerance as a fraction of each contour's perimeter:
        # finer for georeferenced rasters than for plain pixel-space images.
        eps_ratio = 0.001 if self.has_georeference else 0.002
        simplified_contours = []
        for cnt in contours:
            if len(cnt) < 3:
                continue
            approx = cv2.approxPolyDP(cnt, eps_ratio * cv2.arcLength(cnt, True), True)
            if len(approx) >= 3:
                simplified_contours.append(approx.astype(np.int32))

        # cvtColor already returned a fresh array, so draw on it directly.
        cv2.drawContours(image, simplified_contours, -1, color, thickness)
        if save_path:
            cv2.imwrite(save_path, image)
        return image, simplified_contours
class ChangeVisualizer:
    """Batch-draw predicted change outlines on both images of each pair.

    For every mask file in ``pred_mask_dir``, the images with the same base
    name are looked up in ``image1_dir`` and ``image2_dir``. The mask outlines
    are drawn on each image and saved to ``save_dir`` as
    ``<base_name>_image1_red.png`` and ``<base_name>_image2_red.png``.
    """

    # Recognised extensions, in lookup priority order (lower-case form).
    IMAGE_EXTENSIONS = ('.png', '.tif', '.jpg', '.jpeg', '.tiff')

    def __init__(self, image1_dir, image2_dir, pred_mask_dir, save_dir):
        self.image1_dir = image1_dir
        self.image2_dir = image2_dir
        self.pred_mask_dir = pred_mask_dir
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    @staticmethod
    def find_corresponding_file(base_dir, base_name):
        """Return ``base_dir/base_name<ext>`` for the first existing image file.

        All lower-case extensions from ``IMAGE_EXTENSIONS`` are tried first,
        then their upper-case forms. This lets names such as ``scene.TIF`` be
        found on case-sensitive file systems. Returns ``None`` if nothing
        matches.
        """
        exts = ChangeVisualizer.IMAGE_EXTENSIONS
        for ext in (*exts, *(e.upper() for e in exts)):
            path = os.path.join(base_dir, base_name + ext)
            if os.path.isfile(path):
                return path
        return None

    def run(self):
        """Process every mask in ``pred_mask_dir``.

        Pairs with a missing image are skipped with a message.
        """
        mask_files = sorted(
            name for name in os.listdir(self.pred_mask_dir)
            if os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
            and os.path.isfile(os.path.join(self.pred_mask_dir, name))
        )
        for mask_file in tqdm(mask_files, desc="处理影像中"):
            base_name = os.path.splitext(mask_file)[0]
            image1_path = self.find_corresponding_file(self.image1_dir, base_name)
            image2_path = self.find_corresponding_file(self.image2_dir, base_name)
            if not image1_path or not image2_path:
                print(f"[跳过] 缺失影像:{base_name}")
                continue
            mask_path = os.path.join(self.pred_mask_dir, mask_file)
            jobs = (
                (image1_path, os.path.join(self.save_dir, f'{base_name}_image1_red.png')),
                (image2_path, os.path.join(self.save_dir, f'{base_name}_image2_red.png')),
            )
            # One pass per image. Each PredictionVisualizer extracts the
            # outlines itself, and the simplification tolerance depends on that
            # image's georeferencing, so no separate contour-only pass is needed.
            for img_path, out_path in jobs:
                PredictionVisualizer(img_path, mask_path).draw_contours_on_image(save_path=out_path)
# # 主函数调用方式
# if __name__ == '__main__':
# image1_dir = r'/media/data0/HL/2025-7-6/val/image1'
# image2_dir = r'/media/data0/HL/2025-7-6/val/image2'
# pred_mask_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save'
# save_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_red'
# visualizer = ChangeVisualizer(image1_dir, image2_dir, pred_mask_dir, save_dir)
# visualizer.run()
# class PredictionVisualizer:
# def __init__(self, image_path, pred_mask_path):
# self.image_path = image_path
# self.pred_mask_path = pred_mask_path
# with rasterio.open(image_path) as src:
# self.transform = src.transform
# self.crs = src.crs
# self.width = src.width
# self.height = src.height
# self.has_georeference = self.crs is not None and self.transform != rasterio.Affine.identity()
# if not self.has_georeference:
# print(f"警告:输入图像 {image_path} 没有地理参考信息,将使用像素坐标系统。")
# def post_process_mask(self, mask, min_area=100, kernel_size=(5, 5)):
# mask = mask.astype(np.uint8)
# kernel = np.ones(kernel_size)
# mask = binary_closing(mask, structure=kernel)
# mask = binary_opening(mask, structure=np.ones((3, 3)))
# mask = binary_fill_holes(mask)
# labeled_array, num_features = label(mask)
# for i in range(1, num_features + 1):
# component = (labeled_array == i)
# if np.sum(component) < min_area:
# mask[component] = 0
# return mask.astype(np.uint8)
# def draw_contours_on_image(self, save_path=None, color=(0, 0, 255), thickness=2, approx_method=cv2.CHAIN_APPROX_TC89_KCOS):
# with rasterio.open(self.image_path) as src:
# image = src.read([1, 2, 3]).transpose(1, 2, 0)
# image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# pred_mask = cv2.imread(self.pred_mask_path, cv2.IMREAD_GRAYSCALE)
# if pred_mask is None:
# raise FileNotFoundError(f"无法读取掩膜图像: {self.pred_mask_path}")
# _, binary = cv2.threshold(pred_mask, 127, 255, cv2.THRESH_BINARY)
# processed_mask = self.post_process_mask(binary, min_area=10)
# contours, _ = cv2.findContours(processed_mask, cv2.RETR_EXTERNAL, approx_method)
# simplified_contours = []
# for cnt in contours:
# if len(cnt) >= 3:
# epsilon = 0.001 * cv2.arcLength(cnt, True) if self.has_georeference else 0.002 * cv2.arcLength(cnt, True)
# approx = cv2.approxPolyDP(cnt, epsilon, True)
# if len(approx) >= 3:
# simplified_contours.append(approx.astype(np.int32))
# result = image.copy()
# cv2.drawContours(result, simplified_contours, -1, color, thickness)
# if save_path:
# cv2.imwrite(save_path, result)
# return result, simplified_contours
# def find_corresponding_file(base_dir, base_name):
# """在目录下查找与 base_name 同名(不含扩展名)的图像文件(不限于 .png"""
# for ext in ['.png', '.tif', '.jpg', '.jpeg']:
# path = os.path.join(base_dir, base_name + ext)
# if os.path.exists(path):
# return path
# return None
# def process_all_images(image1_dir, image2_dir, pred_mask_dir, save_dir):
# os.makedirs(save_dir, exist_ok=True)
# mask_files = sorted([f for f in os.listdir(pred_mask_dir) if os.path.splitext(f)[1].lower() in ['.png', '.tif', '.jpg', '.jpeg']])
# for mask_file in tqdm(mask_files, desc="处理影像中"):
# base_name = os.path.splitext(mask_file)[0]
# image1_path = find_corresponding_file(image1_dir, base_name)
# image2_path = find_corresponding_file(image2_dir, base_name)
# mask_path = os.path.join(pred_mask_dir, mask_file)
# save_path_img1 = os.path.join(save_dir, f'{base_name}_image1_red.png')
# save_path_img2 = os.path.join(save_dir, f'{base_name}_image2_red.png')
# if not image1_path or not image2_path:
# print(f"[跳过] 缺失影像:{base_name}")
# continue
# # 先获取轮廓
# visualizer = PredictionVisualizer(image1_path, mask_path)
# _, contours = visualizer.draw_contours_on_image(save_path=None)
# for img_path, out_path in [(image1_path, save_path_img1), (image2_path, save_path_img2)]:
# visualizer_i = PredictionVisualizer(img_path, mask_path)
# visualizer_i.draw_contours_on_image(save_path=out_path)
# # === 设置路径并运行 ===
# if __name__ == '__main__':
# image1_dir = r'/media/data0/HL/2025-7-6/val/image1'
# image2_dir = r'/media/data0/HL/2025-7-6/val/image2'
# pred_mask_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save'
# save_dir = r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_red'
# process_all_images(image1_dir, image2_dir, pred_mask_dir, save_dir)