import sys

import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Title of the point-picking window (user-facing text, kept verbatim).
# NOTE(review): some OpenCV HighGUI backends cannot render non-ASCII window
# titles — confirm on the target platform.
PROMPT_WINDOW = "请用鼠标左键依次点击4个点形成提示"
# Number of foreground prompt points the user has to click.
NUM_PROMPT_POINTS = 4

points = []  # Points clicked so far, as (x, y) pixel tuples; appended by mouse_callback.


def mouse_callback(event, x, y, flags, param):
    """OpenCV mouse handler: record left-clicks until NUM_PROMPT_POINTS are collected.

    ``flags`` and ``param`` belong to the OpenCV callback signature and are unused.
    Clicks beyond NUM_PROMPT_POINTS are ignored.
    """
    if event == cv2.EVENT_LBUTTONDOWN and len(points) < NUM_PROMPT_POINTS:
        points.append((x, y))
        print(f"选中点 {len(points)}: {(x,y)}")


def draw_points(img, pts):
    """Draw every (x, y) point in ``pts`` onto ``img`` in place as a filled dot of radius 5.

    The colour (0, 0, 255) is red in OpenCV's BGR channel order.
    """
    for p in pts:
        cv2.circle(img, p, 5, (0, 0, 255), -1)


def get_prompt_points(img):
    """Show ``img`` and let the user left-click NUM_PROMPT_POINTS prompt points.

    Returns:
        A new list of (x, y) pixel tuples once enough points have been clicked,
        or None if the user pressed ESC. All OpenCV windows are destroyed
        before returning.
    """
    global points
    # Start from an empty selection; otherwise a second call would return the
    # previous call's points immediately.
    points = []
    cv2.namedWindow(PROMPT_WINDOW)
    cv2.setMouseCallback(PROMPT_WINDOW, mouse_callback)
    try:
        while len(points) < NUM_PROMPT_POINTS:
            # Redraw on a fresh copy every frame so ``img`` itself is never modified.
            disp = img.copy()
            draw_points(disp, points)
            cv2.imshow(PROMPT_WINDOW, disp)
            if (cv2.waitKey(1) & 0xFF) == 27:  # ESC cancels the selection
                points = []
                return None
        # Return a copy so later resets of the global list cannot affect the caller.
        return list(points)
    finally:
        cv2.destroyAllWindows()


def detect_with_multiple_points(frame, predictor, prompt_points, image_format="RGB"):
    """Segment ``frame`` with SAM, using every point in ``prompt_points`` as a foreground prompt.

    Args:
        frame: image to segment, in the channel order given by ``image_format``.
        predictor: a segment_anything ``SamPredictor``.
        prompt_points: sequence of (x, y) pixel coordinates.
        image_format: channel order of ``frame``: "RGB" (the predictor's
            default) or "BGR" (what ``cv2.imread`` returns). SAM converts
            internally when this differs from the model's expected order.

    Returns:
        The single predicted mask as a uint8 array: 1 on the object, 0 elsewhere.
    """
    predictor.set_image(frame, image_format=image_format)
    pts_np = np.asarray(prompt_points, dtype=np.float32)
    labels = np.ones(len(prompt_points), dtype=np.int32)  # label 1 = foreground point
    masks, _, _ = predictor.predict(
        point_coords=pts_np,
        point_labels=labels,
        multimask_output=False,
    )
    return masks[0].astype(np.uint8)


def sam2_damage(image_path, checkpoint_path, output_path):
    """Interactively segment one object in an image and save the masked result.

    Despite the name, this uses the ``segment_anything`` ViT-H model
    (``sam_model_registry["vit_h"]``).

    Steps:
        1. Read the image.
        2. The user clicks the prompt points.
        3. SAM predicts a mask.
        4. Pixels outside the mask are zeroed.
        5. The result is written to ``output_path`` and shown until a key is pressed.

    Exit behaviour:
        - Exits with status 1 if the image cannot be read or the result cannot be saved.
        - Exits with status 0 if the user cancels point selection.
    """
    img = cv2.imread(image_path)
    if img is None:
        print("图片读取失败")
        sys.exit(1)

    prompt_points = get_prompt_points(img)
    if prompt_points is None:
        print("未选取点,退出")
        sys.exit(0)
    print("选取点:", prompt_points)

    # Load the large ViT-H model only after the user has committed to a
    # selection, so cancelling does not pay the loading cost.
    model = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
    predictor = SamPredictor(model)

    # cv2.imread yields BGR; tell SAM so it feeds the model correctly ordered channels.
    mask = detect_with_multiple_points(img, predictor, prompt_points, image_format="BGR")
    # Scale the 0/1 mask to 0/255 for bitwise_and's mask argument; pixels outside the mask become black.
    masked_img = cv2.bitwise_and(img, img, mask=mask * 255)

    # Save before the blocking preview so the result is kept even if the preview is aborted.
    saved = cv2.imwrite(output_path, masked_img)
    if not saved:
        print(f"结果保存失败: {output_path}")

    # Show the result and wait for any key.
    cv2.imshow("分割结果", masked_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    if not saved:
        sys.exit(1)


if __name__ == "__main__":
    sam2_damage("DJI_20250805103954_1962_V.jpeg", "sam_vit_h_4b8939.pth", "output.jpg")