ai_project_v1/pic_detect.py

300 lines
13 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ctypes
import datetime
import json
import os
import time
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from cv_back_video import extract_box_details, cal_tricker_results
from middleware.minio_util import upload_file, downFile
from middleware.read_pic_metadata import read_dji_exif_to_dict
from mqtt_pub import MQTTClient
from ultralytics import solutions
from yolo.detect.yolo11_det_pic_trt import YoLo11TRT
# Default MQTT publishing target.
# The broker is a public (free) MQTT broker; 1883 is the default MQTT port.
# A previously used broker address was "112.44.103.230".
broker, port = "8.137.54.85", 1883
# Topic that events are published to.
topic = "thing/product/ai/events"
# Shared MQTT client handle; the detection entry points rebind it when they run.
mqtt_client = None
def _pic_configure_device(model, device, use_fp16):
    """Move *model* to *device*, force batch size 1 and optionally switch to FP16.

    This is best-effort tuning: every failure is printed and swallowed so
    that it never aborts the detection task.
    """
    try:
        model.to(device)
        # Use a batch size of 1 to reduce memory usage.
        if hasattr(model, 'args') and hasattr(model.args, 'batch'):
            model.args.batch = 1
        # Half precision only makes sense on a CUDA device.
        if use_fp16 and device.startswith('cuda') and hasattr(model, 'model'):
            try:
                model.model = model.model.half()
                print("✅ 启用半精度 FP16 模式")
            except Exception as half_err:
                print(f"⚠️ 半精度转换失败: {half_err}")
        else:
            print(" 半精度模式未启用")
    except Exception as e:
        print(f"⚠️ 模型设备配置警告: {e}")


def _pic_load_yolo_model(m1, cls, confidence, max_retries=3):
    """Load the YOLO model *m1* and warm it up, retrying on failure.

    Args:
        m1: Model identifier handed to ``YOLO(...)``.
        cls: Class filter passed as ``classes=`` during warm-up.
        confidence: Confidence threshold passed as ``conf=`` during warm-up.
        max_retries: Maximum number of load/warm-up attempts.

    Returns:
        The loaded and warmed-up YOLO model.

    Raises:
        Exception: If no attempt produced a warmed-up model.
    """
    model_params = {}
    # Probe which constructor keyword arguments look usable with the
    # installed YOLO version.
    try:
        import inspect

        test_model = YOLO(m1)
        if hasattr(test_model, "task"):
            model_params["task"] = "detect"
        if torch.cuda.is_available():
            model_params["half"] = True
        if "verbose" in inspect.signature(YOLO.__init__).parameters:
            model_params["verbose"] = False
        print(f"检测到支持的YOLO参数: {model_params}")
    except Exception as e:
        print(f"参数检测失败,将使用默认参数: {e}")
        model_params = {}

    for attempt in range(1, max_retries + 1):
        try:
            model = YOLO(m1, **model_params)
            # Warm up with a few blank 1080p frames.
            dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
            for _ in range(3):
                model(dummy_frame, classes=cls, conf=confidence, show=False)
            print("YOLO模型加载成功并预热完成")
            return model
        except Exception as e:
            print(f"YOLO模型加载失败 (尝试 {attempt}/{max_retries}): {e}")
            error_text = str(e)
            if "got an unexpected keyword argument" in error_text and model_params:
                # The offending name is the last quoted token of the message;
                # guard against messages that contain no quotes at all.
                parts = error_text.split("'")
                param_name = parts[-2] if len(parts) >= 2 else None
                if param_name in model_params:
                    del model_params[param_name]
                    # An unsupported keyword is not a transient failure:
                    # retry immediately without waiting.
                    continue
            if attempt < max_retries:
                time.sleep(2)
    raise Exception("无法加载YOLO模型达到最大重试次数")


def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    task_id, model_func_id_list, para, s3_url_list, m1, cls, chinese_label, confidence,
                    use_fp16=False):
    """Run YOLO detection on a list of pictures and publish each result over MQTT.

    For every URL in *s3_url_list* the picture is downloaded, its metadata is
    read, detection is run through an ``ObjectCounter``, the annotated picture
    is written to ``output_frames/``, both the original and the annotated
    picture are uploaded, and a JSON message is published.

    Args:
        mqtt_pub_ip: Passed to ``MQTTClient`` as the first argument.
        mqtt_pub_port: Passed to ``MQTTClient`` as the second argument.
        mqtt_pub_topic: Passed to ``MQTTClient`` as the third argument.
        task_id: Echoed back in every published message as ``task_id``.
        model_func_id_list: Forwarded to ``cal_tricker_results``; its first
            element is also forwarded separately.
        para: Forwarded to ``cal_tricker_results``.
        s3_url_list: Iterable of picture locations handed to ``downFile``.
        m1: Model identifier handed to ``YOLO(...)``.
        cls: Class filter used for warm-up and for the ``ObjectCounter``.
        chinese_label: Forwarded to ``cal_tricker_results``.
        confidence: Confidence threshold for the model.
        use_fp16: When true and CUDA is available, try to run in half precision.

    Returns:
        None. All exceptions are caught and printed; nothing is raised to the
        caller. The module-level ``mqtt_client`` is rebound as a side effect.
    """
    global mqtt_client
    try:
        # Environment tuning for YOLO.
        os.environ["OMP_NUM_THREADS"] = "4"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            torch.cuda.empty_cache()

        print("预加载YOLO模型...")
        ov_model = _pic_load_yolo_model(m1, cls, confidence)
        ov_model.conf = confidence

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        _pic_configure_device(ov_model, device, use_fp16)

        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)

        # The counter draws its region on the picture by default, so the
        # region is given off-screen coordinates.
        line_points = [(-10, -10), (-10, -10)]
        # NOTE: per the original author, show=True here locks up the thread
        # once the program is restarted through the REST interface.
        counter = solutions.ObjectCounter(show=False, region=line_points, model=ov_model, classes=cls)
        _pic_configure_device(counter.model, device, use_fp16)

        # Class names come from the model when available, otherwise numeric
        # placeholders are used.
        class_names = counter.model.names if hasattr(counter.model, 'names') else {i: str(i) for i in range(1000)}
        local_func_cache = {
            "func_100000": None,
            "func_100004": None,  # per the original comment: cached person track_ids
            "func_100006": None   # per the original comment: cached vehicle track_ids
        }

        # Number of pictures that failed in a row; reset after every success.
        error_count = 0
        try:
            # cv2.imwrite does not create missing directories.
            os.makedirs("output_frames", exist_ok=True)
            for pic_url in s3_url_list:
                image_path = downFile(pic_url)
                print("1")
                frame = cv2.imread(image_path)
                # Picture metadata (the original comment mentions coordinates
                # and flight attitude).
                image_exif_dict = read_dji_exif_to_dict(image_path)
                try:
                    if frame is None:
                        # cv2.imread signals an unreadable file by returning None.
                        raise ValueError(f"无法读取图片: {image_path}")
                    frame_copy = frame.copy()
                    counter(frame)
                    annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names,
                                                                      model_func_id_list,
                                                                      local_func_cache, para, cls, chinese_label,
                                                                      model_func_id_list[0])
                except Exception as e:
                    print(f"处理帧错误: {e}")
                    error_count += 1
                    if error_count >= 5:
                        print("连续处理错误达到5次 ,正在停止处理...")
                        break
                    # There is no valid result for this picture, so nothing
                    # may be saved, uploaded or published for it.
                    continue
                error_count = 0

                # Save the annotated picture.
                local_path = f"output_frames/{os.path.basename(image_path)}"
                success = cv2.imwrite(local_path, annotated_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
                if success:
                    print(f"图像已保存到: {local_path}")
                else:
                    print("错误: 图像保存失败")
                    # Nothing was written, so there is nothing to upload.
                    continue

                date_str = str(datetime.datetime.now().timestamp())
                try:
                    minio_path_before, file_type_before = upload_file(image_path, None)
                    minio_path, file_type = upload_file(local_path, date_str)
                finally:
                    # Remove the temporary annotated file even if an upload fails.
                    if os.path.exists(local_path):
                        os.remove(local_path)

                message = {
                    "task_id": task_id,
                    "minio": {
                        "minio_path": minio_path,
                        "file_type": file_type,
                        "original_path": minio_path_before
                    },
                    "box_detail": box_result,
                    "image_exif_dict": image_exif_dict
                }
                # ensure_ascii=False keeps non-ASCII characters readable.
                json_message = json.dumps(message, indent=4, ensure_ascii=False)
                mqtt_client.publish_message(json_message)
                time.sleep(0.1)
                if os.path.exists(minio_path):
                    os.remove(minio_path)
                if os.path.exists(minio_path_before):
                    os.remove(minio_path_before)
        except Exception as infer_err:
            print(f"推理错误: {infer_err}")
    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")
def _trt_to_jsonable(value):
    """Recursively turn NumPy arrays/scalars inside *value* into plain Python objects.

    ``json.dumps`` cannot serialize ``numpy.ndarray`` or NumPy scalar types;
    values that are already plain Python objects are returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (list, tuple)):
        return [_trt_to_jsonable(item) for item in value]
    if isinstance(value, dict):
        return {key: _trt_to_jsonable(item) for key, item in value.items()}
    return value


def pic_detect_func_trt(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                        task_id, s3_url_list, engine_file_path,
                        PLUGIN_LIBRARY, confidence, categories):
    """Run TensorRT YOLO detection on a list of pictures and publish results over MQTT.

    For every URL in *s3_url_list* the picture is downloaded, inference is
    run through ``YoLo11TRT``, the resulting picture is written to
    ``output_frames/``, both the original and the result picture are
    uploaded, and a JSON message with the detections is published.

    Args:
        mqtt_pub_ip: Passed to ``MQTTClient`` as the first argument.
        mqtt_pub_port: Passed to ``MQTTClient`` as the second argument.
        mqtt_pub_topic: Passed to ``MQTTClient`` as the third argument.
        task_id: Echoed back in every published message as ``task_id``.
        s3_url_list: Sequence of picture locations handed to ``downFile``.
        engine_file_path: Engine file handed to ``YoLo11TRT``.
        PLUGIN_LIBRARY: Shared library loaded with ``ctypes.CDLL`` before the
            engine is created.
        confidence: Forwarded to ``YoLo11TRT.infer``.
        categories: Forwarded to ``YoLo11TRT.infer``.

    Returns:
        None. All exceptions are caught and printed; nothing is raised to the
        caller. The module-level ``mqtt_client`` is rebound as a side effect.
    """
    global mqtt_client
    try:
        print(f"mqtt_client {mqtt_pub_ip} {mqtt_pub_port} {mqtt_pub_topic}")
        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)
        print("mqtt")
        # Load the plugin library before the engine is created.
        ctypes.CDLL(PLUGIN_LIBRARY)
        print(f"PLUGIN_LIBRARY{PLUGIN_LIBRARY}")
        # Create the inference wrapper.
        yolo11_wrapper = YoLo11TRT(engine_file_path)
        print(f"engine_file_path{engine_file_path} {len(s3_url_list)}")
        try:
            # cv2.imwrite does not create missing directories.
            os.makedirs("output_frames", exist_ok=True)
            for pic_url in s3_url_list:
                image_path = downFile(pic_url)
                print("1")
                frame = cv2.imread(image_path)
                print("2")
                if frame is None:
                    # cv2.imread signals an unreadable file by returning None;
                    # skip this picture instead of aborting the whole batch.
                    print(f"错误: 无法读取图片: {image_path}")
                    continue

                (batch_image_raw, result_box_list, result_scores_list,
                 result_classid_list, _infer_time) = yolo11_wrapper.infer(
                    [frame], confidence, categories)
                detected_frame = batch_image_raw[0]

                # Save the result picture.
                local_path = f"output_frames/{os.path.basename(image_path)}"
                success = cv2.imwrite(local_path, detected_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
                if success:
                    print(f"图像已保存到: {local_path}")
                else:
                    print("错误: 图像保存失败")
                    # Nothing was written, so there is nothing to upload.
                    continue

                date_str = str(datetime.datetime.now().timestamp())
                try:
                    minio_path_before, file_type_before = upload_file(image_path, None)
                    minio_path, file_type = upload_file(local_path, date_str)
                finally:
                    # Remove the temporary result file even if an upload fails.
                    if os.path.exists(local_path):
                        os.remove(local_path)

                message = {
                    "task_id": task_id,
                    "minio": {
                        "minio_path": minio_path,
                        "file_type": file_type,
                        "original_path": minio_path_before
                    },
                    # Detections may contain NumPy values, which json.dumps
                    # rejects; convert them to plain lists/numbers first.
                    "box_detail": {
                        "result_box_list": _trt_to_jsonable(result_box_list),
                        "result_scores_list": _trt_to_jsonable(result_scores_list),
                        "result_classid_list": _trt_to_jsonable(result_classid_list)
                    }
                }
                json_message = json.dumps(message, indent=4, ensure_ascii=False)
                mqtt_client.publish_message(json_message)
                time.sleep(0.1)
                if os.path.exists(minio_path):
                    os.remove(minio_path)
                if os.path.exists(minio_path_before):
                    os.remove(minio_path_before)
        except Exception as infer_err:
            print(f"推理错误: {infer_err}")
        finally:
            # Always release the inference wrapper's resources.
            yolo11_wrapper.destroy()
    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")