72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
import cv2
|
|
import numpy as np
|
|
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
points = list()  # (x, y) clicks recorded by mouse_callback and read by get_prompt_points
|
|
|
|
def mouse_callback(event, x, y, flags, param):
    """OpenCV mouse handler that records up to four left-click positions.

    On each left-button press, appends ``(x, y)`` to the module-level
    ``points`` list, provided fewer than four points are stored. Each accepted
    point is printed to stdout. ``flags`` and ``param`` are required by the
    OpenCV callback signature and are not used.
    """
    # Ignore anything that is not a left click, and clicks beyond the fourth.
    if event != cv2.EVENT_LBUTTONDOWN or len(points) >= 4:
        return
    # Mutating the shared list in place; no rebinding, so no `global` needed.
    points.append((x, y))
    print(f"选中点 {len(points)}: {(x, y)}")
|
|
|
|
def draw_points(img, pts):
    """Draw every point in *pts* onto *img* in place as a filled dot.

    Each dot has radius 5 and colour (0, 0, 255), which is red in OpenCV's
    BGR channel order. Returns nothing; *img* is modified directly.
    """
    radius = 5
    red_bgr = (0, 0, 255)
    filled = -1  # negative thickness tells cv2.circle to fill the circle
    for center in pts:
        cv2.circle(img, center, radius, red_bgr, filled)
|
|
|
|
def get_prompt_points(img):
    """Let the user click four prompt points on *img* in an OpenCV window.

    Clicks are collected by ``mouse_callback`` into the module-level
    ``points`` list. This loop redraws them on a fresh copy of *img* until
    four have been chosen.

    Returns:
        A new list of four ``(x, y)`` tuples. Returns ``None`` if the user
        presses ESC or closes the window before four points are chosen.
    """
    global points
    window = "请用鼠标左键依次点击4个点形成提示"

    # Start from a clean slate. Otherwise clicks left over from an earlier
    # call would make this call return immediately with stale points.
    points = []

    cv2.namedWindow(window)
    cv2.setMouseCallback(window, mouse_callback)

    while True:
        disp = img.copy()  # redraw on an untouched copy each frame
        draw_points(disp, points)
        cv2.imshow(window, disp)

        key = cv2.waitKey(1) & 0xFF
        if key == 27:  # ESC cancels the selection
            points = []
            break

        # Closing the window via its title bar produces no key press. Without
        # this check, imshow() would recreate a window that has no mouse
        # callback, and the loop could never finish.
        if cv2.getWindowProperty(window, cv2.WND_PROP_VISIBLE) < 1:
            points = []
            break

        if len(points) == 4:
            cv2.destroyAllWindows()
            # Hand back a copy so later resets of the global list cannot
            # affect the caller's result.
            return list(points)

    cv2.destroyAllWindows()
    return None
|
|
|
|
def detect_with_multiple_points(frame, predictor, prompt_points, image_format="BGR"):
    """Segment *frame* with SAM, treating every prompt point as a foreground click.

    Args:
        frame: image array passed to ``predictor.set_image``. SAM expects an
            HWC uint8 image. The default channel order is BGR, which is what
            ``cv2.imread`` returns.
        predictor: a ``SamPredictor``, or any object with compatible
            ``set_image`` / ``predict`` methods.
        prompt_points: sequence of ``(x, y)`` pixel coordinates.
        image_format: channel order of *frame*, either ``"BGR"`` or ``"RGB"``.

    Returns:
        The first predicted mask as a ``uint8`` array: 1 inside the segmented
        region and 0 elsewhere.

    Raises:
        ValueError: if *prompt_points* is empty.
    """
    if len(prompt_points) == 0:
        raise ValueError("prompt_points must contain at least one (x, y) point")

    # SamPredictor.set_image assumes RGB unless told otherwise. OpenCV images
    # are BGR, so without this the model would see swapped red/blue channels.
    predictor.set_image(frame, image_format=image_format)

    coords = np.asarray(prompt_points, dtype=float).reshape(-1, 2)
    labels = np.ones(len(coords), dtype=np.int32)  # label 1 = foreground point
    masks, _, _ = predictor.predict(
        point_coords=coords, point_labels=labels, multimask_output=False
    )
    return masks[0].astype(np.uint8)
|
|
|
|
def sam2_damage(image_path, checkpoint_path, output_path, model_type="vit_h", device=None):
    """Interactively segment an image with SAM and save the masked result.

    Steps:
        1. Read *image_path* with OpenCV.
        2. Build a SAM predictor from *checkpoint_path*.
        3. Let the user click the prompt points.
        4. Show the masked image until a key is pressed.
        5. Write the masked image to *output_path*.

    Args:
        image_path: path of the input image.
        checkpoint_path: path of a SAM checkpoint that matches *model_type*.
        output_path: destination of the masked image. Its extension selects the
            encoder used by ``cv2.imwrite``.
        model_type: key into ``sam_model_registry``. Defaults to ``"vit_h"``,
            the previously hard-coded value.
        device: optional torch device (e.g. ``"cuda"``) to move the model to.
            ``None`` skips the ``.to()`` call, matching the original behaviour.

    Raises:
        SystemExit: status 1 if the image cannot be read or the result cannot
            be written; status 0 if point selection is cancelled.
    """
    img = cv2.imread(image_path)
    if img is None:
        print("图片读取失败")
        # `raise SystemExit` rather than the site-provided exit(), which is
        # not guaranteed to exist (e.g. under `python -S` or frozen builds).
        raise SystemExit(1)

    model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    if device is not None:
        model.to(device=device)
    predictor = SamPredictor(model)

    prompt_points = get_prompt_points(img)
    if prompt_points is None:
        print("未选取点,退出")
        raise SystemExit(0)

    print("选取点:", prompt_points)

    mask = detect_with_multiple_points(img, predictor, prompt_points)
    # Keep only the pixels inside the predicted mask.
    masked_img = cv2.bitwise_and(img, img, mask=mask * 255)

    # Show the result and wait for any key before saving.
    cv2.imshow("分割结果", masked_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    # cv2.imwrite reports failure (unsupported extension, missing directory,
    # ...) only through its return value, so check it rather than silently
    # losing the output.
    if not cv2.imwrite(output_path, masked_img):
        print(f"结果保存失败: {output_path}")
        raise SystemExit(1)
|
|
|
|
if __name__ == "__main__":
    import argparse

    # All arguments are optional; the defaults are the previously hard-coded
    # paths, so running the script with no arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Interactively segment an image with SAM using clicked prompt points."
    )
    parser.add_argument(
        "image_path", nargs="?", default="DJI_20250805103954_1962_V.jpeg",
        help="input image path",
    )
    parser.add_argument(
        "checkpoint_path", nargs="?", default="sam_vit_h_4b8939.pth",
        help="SAM ViT-H checkpoint path",
    )
    parser.add_argument(
        "output_path", nargs="?", default="output.jpg",
        help="where to write the masked image",
    )
    args = parser.parse_args()
    sam2_damage(args.image_path, args.checkpoint_path, args.output_path)