174 lines
5.8 KiB
Python
174 lines
5.8 KiB
Python
import os.path
|
||
|
||
import numpy as np
|
||
import cv2
|
||
|
||
# 根据输出结果输出json问价还有边界
|
||
def segment_instance_and_contour(seg_map, class_names):
|
||
"""
|
||
参数:
|
||
seg_map: HxW 的语义分割图像,像素值表示类别编号
|
||
class_names: 类别编号到类别名的映射,如 {0:'Background', 1:'Farmland', 2:'Forest'}
|
||
返回:
|
||
实例列表,每个包含id、类别名、轮廓点
|
||
"""
|
||
instance_id = 1
|
||
results = []
|
||
|
||
for cls_id in np.unique(seg_map):
|
||
if cls_id == 0: # 忽略背景
|
||
continue
|
||
mask = (seg_map == cls_id).astype(np.uint8)
|
||
|
||
# 找连通域(不同地块)
|
||
num_labels, labels = cv2.connectedComponents(mask)
|
||
|
||
for i in range(1, num_labels): # 第0是背景
|
||
instance_mask = (labels == i).astype(np.uint8)
|
||
|
||
# 找轮廓
|
||
contours, _ = cv2.findContours(instance_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||
for contour in contours:
|
||
contour_list = contour.squeeze().tolist()
|
||
if len(contour_list) < 3:
|
||
continue # 过滤掉太小的碎片
|
||
results.append({
|
||
'id': instance_id,
|
||
'class': class_names.get(cls_id, f"Class{cls_id}"),
|
||
'contour': contour_list
|
||
})
|
||
instance_id += 1
|
||
|
||
return results
|
||
|
||
def draw_instances_with_ids(seg_map, results, save_path="instance_visual.png"):
|
||
"""
|
||
在原始图或灰度图上绘制每个地块的轮廓及其编号。
|
||
|
||
参数:
|
||
seg_map: 原始图像或分割图(HxW 或 HxWx3)
|
||
results: 上一步输出的地块信息列表
|
||
save_path: 保存路径
|
||
"""
|
||
if len(seg_map.shape) == 2: # 若为灰度图,转为BGR
|
||
vis_img = cv2.cvtColor(seg_map, cv2.COLOR_GRAY2BGR)
|
||
else:
|
||
vis_img = seg_map.copy()
|
||
# h, w = vis_img.shape[:2]
|
||
for item in results:
|
||
contour = np.array(item['contour']).reshape(-1, 1, 2).astype(np.int32)
|
||
|
||
# 随机颜色(根据类别也可指定颜色)
|
||
color = tuple(np.random.randint(0, 255, size=3).tolist())
|
||
cv2.drawContours(vis_img, [contour], -1, color, 2)#线的宽度
|
||
|
||
# 计算重心位置以便放编号
|
||
M = cv2.moments(contour)
|
||
if M['m00'] != 0:
|
||
cx = int(M['m10'] / M['m00'])
|
||
cy = int(M['m01'] / M['m00'])
|
||
cv2.putText(vis_img, str(item['id']), (cx, cy),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)#序号的大小与线宽度
|
||
|
||
# # 替代 moments 重心方式,使用外接矩形中心
|
||
# x, y, w, h = cv2.boundingRect(contour)
|
||
# cx = x + w // 2
|
||
# cy = y + h // 2
|
||
# cv2.putText(vis_img, str(item['id']), (cx, cy),
|
||
# cv2.FONT_HERSHEY_SIMPLEX, 3, color, 2, cv2.LINE_AA)
|
||
|
||
# # 创建该块的 instance mask,用于质心计算
|
||
# instance_mask = np.zeros((h, w), dtype=np.uint8)
|
||
# cv2.drawContours(instance_mask, [contour], -1, 1, -1)
|
||
#
|
||
# # 计算质心(内部像素点的几何中心)
|
||
# ys, xs = np.where(instance_mask > 0)
|
||
# if len(xs) > 0 and len(ys) > 0:
|
||
# cx = int(np.mean(xs))
|
||
# cy = int(np.mean(ys))
|
||
# cv2.putText(vis_img, str(item['id']), (cx, cy),
|
||
# cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)
|
||
|
||
cv2.imwrite(save_path, vis_img)
|
||
print(f"✅ 可视化图保存至: {save_path}")
|
||
|
||
def color_map_to_label(seg_img, palette):
|
||
"""
|
||
将彩色语义图映射为类别ID图
|
||
|
||
参数:
|
||
seg_img: HxWx3 彩色图像(例如读取的png)
|
||
palette: 类别调色板,如 [[0,0,0], [139,69,19], ...]
|
||
|
||
返回:
|
||
label_map: HxW 图像,每个像素为类别ID(int)
|
||
"""
|
||
h, w, _ = seg_img.shape
|
||
label_map = np.zeros((h, w), dtype=np.uint8)
|
||
|
||
# 将调色板转换为字典
|
||
color2label = {tuple(color): idx for idx, color in enumerate(palette)}
|
||
|
||
# 对每个像素映射
|
||
for color, label in color2label.items():
|
||
mask = np.all(seg_img == color, axis=-1)
|
||
label_map[mask] = label
|
||
|
||
return label_map
|
||
|
||
# 定义类别调色板和类别名
|
||
PALETTE = [
|
||
[0, 0, 0], # background
|
||
[139, 69, 19], # barren
|
||
[0, 255, 0], # forest
|
||
[255, 255, 0], # farmland
|
||
[0, 0, 255], # water
|
||
[128, 128, 128], # road
|
||
[0, 255, 255] # building
|
||
]
|
||
CLASS_NAMES = {
|
||
0: 'background',
|
||
1: 'barren',
|
||
2: 'forest',
|
||
3: 'farmland',
|
||
4: 'water',
|
||
5: 'road',
|
||
6: 'building'
|
||
}
|
||
|
||
|
||
def draw_json_boundary(pic_url):
|
||
|
||
|
||
# 读取彩色分割图
|
||
seg_color = cv2.imread(pic_url) # 注意是 BGR 格式
|
||
# seg_color = cv2.imread(r"D:\project\UAV_model\prediction_results\unetformer_UAV_6000X4000\patch_0028.png") # 注意是 BGR 格式
|
||
# seg_color = cv2.imread(r"J:\uhr\pytools\numer_coordinate\patch_0047_rgb.png") # 注意是 BGR 格式
|
||
seg_color = cv2.cvtColor(seg_color, cv2.COLOR_BGR2RGB)
|
||
|
||
# 转为类别索引图
|
||
label_map = color_map_to_label(seg_color, PALETTE)
|
||
|
||
# 继续后续处理
|
||
results = segment_instance_and_contour(label_map, CLASS_NAMES)
|
||
dir_name=os.path.dirname(pic_url)
|
||
final_vis_png_name=os.path.basename(pic_url)+"final_vis.png"
|
||
final_vis_png_path=os.path.join(dir_name,final_vis_png_name) #输出大本地的只有边缘线条
|
||
|
||
instance_results_json_name=os.path.basename(pic_url)+"instance_results.json"
|
||
instance_results_json_path=os.path.join(dir_name,instance_results_json_name) # 输出到本地的,分割的边缘的json文件
|
||
draw_instances_with_ids(label_map, results, save_path=final_vis_png_path)
|
||
|
||
|
||
|
||
# 保存为 JSON(可选)
|
||
import json
|
||
with open(instance_results_json_path, "w") as f:
|
||
json.dump(results, f, indent=2)
|
||
|
||
return final_vis_png_path,instance_results_json_path
|
||
|
||
|
||
|
||
|