# 101 lines · 3.7 KiB · Python  (file-viewer header captured with the paste)
import os
|
||
import shutil
|
||
import cv2
|
||
import collections
|
||
from ultralytics import YOLO
|
||
from miniohelp import downFile, upload_file, parse_minio_url # 确保你有这些工具函数
|
||
from minio import Minio
|
||
|
||
def _create_minio_client(minio_info):
    """Build a MinIO client from the connection settings in ``minio_info``.

    Reads the ``"MinIOEndpoint"``, ``"MinIOAccessKey"`` and ``"MinIOSecretKey"``
    keys. A leading ``http://`` / ``https://`` scheme and any trailing ``/``
    are stripped from the endpoint, because ``Minio`` expects a bare
    ``host[:port]`` rather than a URL.
    """
    endpoint = minio_info["MinIOEndpoint"]
    # Strip the scheme only as a prefix (str.replace would also hit it mid-string).
    for scheme in ("http://", "https://"):
        if endpoint.startswith(scheme):
            endpoint = endpoint[len(scheme):]
            break
    endpoint = endpoint.rstrip("/")
    return Minio(
        endpoint=endpoint,
        access_key=minio_info["MinIOAccessKey"],
        secret_key=minio_info["MinIOSecretKey"],
        # NOTE(review): TLS is always off, even when the configured endpoint
        # was an https:// URL — confirm this is intended for the deployment.
        secure=False,
    )


def _parse_class_filter(class_filter):
    """Turn a comma-separated class-id string such as ``"0,2,5"`` into a list of ints.

    Blank entries (e.g. from a trailing comma) are ignored. Returns ``None``
    when ``class_filter`` is empty/``None`` or contains no ids; ``None`` is
    also ``model.predict``'s default for ``classes`` (no class filtering).

    Raises:
        ValueError: if a non-blank entry is not an integer.
    """
    if not class_filter:
        return None
    class_ids = [int(part) for part in class_filter.split(",") if part.strip()]
    return class_ids or None


def _resolve_image_location(img_url):
    """Return ``(bucket_name, object_path)`` for an image reference.

    References starting with ``"http"`` are parsed with ``parse_minio_url``;
    anything else is treated as an object path inside ``"default-bucket"``.
    """
    if img_url.startswith("http"):
        return parse_minio_url(img_url)
    return "default-bucket", img_url


def _count_detections(result, class_ids_filter):
    """Count detected boxes per class id in one YOLO result.

    Returns two parallel lists ``(classes, counts)`` of plain Python ints in
    first-seen order. When ``class_ids_filter`` is ``None`` every detected
    class is counted; otherwise only the listed class ids are counted.
    """
    if result.boxes is None:
        return [], []
    # Convert numpy integers to Python ints up front so the output is JSON-friendly.
    class_counts = collections.Counter(
        int(cls_id) for cls_id in result.boxes.cls.cpu().numpy().astype(int)
    )
    allowed = None if class_ids_filter is None else set(class_ids_filter)
    classes, counts = [], []
    for cls_id, count in class_counts.items():
        if allowed is None or cls_id in allowed:
            classes.append(cls_id)
            counts.append(int(count))
    return classes, counts


def _process_single_image(model, minio_client, item, class_ids_filter,
                          input_folder, output_folder, conf, iou):
    """Download one image, run detection on it, and upload the annotated copy.

    ``item`` must provide ``"id"`` and ``"path"`` keys. The annotated image is
    written as ``<name>_ai<ext>`` and uploaded via ``upload_file`` into the
    same bucket and directory as the source object.

    Any failure while handling the image is printed and yields an empty
    record (``aim`` False, no classes, empty ``minio_path``), so one bad
    image does not abort the batch.
    """
    img_id = item["id"]
    img_url = item["path"]
    img_path = img_url  # fallback for the error message if resolution fails

    try:
        bucket_name, img_path = _resolve_image_location(img_url)

        # Download the original image into the scratch input folder.
        local_input_path = os.path.join(input_folder, os.path.basename(img_path))
        downFile(minio_client, img_path, bucket_name, local_input_path)

        image = cv2.imread(local_input_path)
        if image is None:
            raise ValueError(f"无法读取图像: {local_input_path}")

        result = model.predict(
            image,
            classes=class_ids_filter,
            conf=conf,
            iou=iou,
            show_labels=False,
        )[0]

        detected_classes, detected_numbers = _count_detections(result, class_ids_filter)

        # Save the annotated image next to the inputs, then upload it.
        annotated_image = result.plot(labels=False)
        filename_no_ext, ext = os.path.splitext(os.path.basename(img_path))
        local_output_path = os.path.join(output_folder, f"{filename_no_ext}_ai{ext}")
        if not cv2.imwrite(local_output_path, annotated_image):
            raise ValueError(f"failed to write annotated image: {local_output_path}")

        minio_path = upload_file(
            minio_client, local_output_path, bucket_name, os.path.dirname(img_path)
        )
    except Exception as e:  # best-effort per image: report and continue the batch
        print(f"[错误] 处理失败 - {img_path},错误: {e}")
        detected_classes, detected_numbers, minio_path = [], [], ""

    return {
        "id": img_id,
        "minio_path": minio_path,
        "aim": bool(detected_classes),
        "class": detected_classes,
        "number": detected_numbers,
    }


def process_images(yolo_model, image_list, class_filter, input_folder, output_folder,
                   minio_info, *, conf=0.5, iou=0.111):
    """Run YOLO detection on a batch of MinIO-hosted images and upload annotated copies.

    Args:
        yolo_model: Model reference passed to ``YOLO(...)``.
        image_list: Iterable of dicts, each with ``"id"`` and ``"path"`` keys.
            ``"path"`` is either a full MinIO URL (starting with ``"http"``)
            or an object path in ``"default-bucket"``.
        class_filter: Comma-separated class ids to keep (e.g. ``"0,2"``), or
            empty/``None`` to keep every detected class.
        input_folder: Scratch directory for downloaded originals.
        output_folder: Scratch directory for annotated images.
        minio_info: Dict with ``"MinIOEndpoint"``, ``"MinIOAccessKey"`` and
            ``"MinIOSecretKey"``.
        conf: Confidence threshold passed to ``model.predict``.
        iou: IoU threshold passed to ``model.predict``.

    Both scratch folders are created if missing and are deleted recursively
    (``shutil.rmtree``) when the call finishes — even if it raises — so do not
    point them at directories holding anything else.

    Returns:
        ``{"status": "success", "message": ..., "data": [...]}`` with one
        record per input item. ``status`` is ``"success"`` even when
        individual images failed; a failed image has an empty ``minio_path``.
    """
    minio_client = _create_minio_client(minio_info)

    os.makedirs(input_folder, exist_ok=True)
    os.makedirs(output_folder, exist_ok=True)
    try:
        model = YOLO(yolo_model)
        class_ids_filter = _parse_class_filter(class_filter)
        output_image_list = [
            _process_single_image(
                model, minio_client, item, class_ids_filter,
                input_folder, output_folder, conf, iou,
            )
            for item in image_list
        ]
    finally:
        # Always remove the scratch directories, even when setup fails.
        shutil.rmtree(input_folder, ignore_errors=True)
        shutil.rmtree(output_folder, ignore_errors=True)

    return {
        "status": "success",
        "message": "Detection completed",
        "data": output_image_list,
    }
|