300 lines
13 KiB
Python
300 lines
13 KiB
Python
import ctypes
|
||
import datetime
|
||
import json
|
||
import os
|
||
import time
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from ultralytics import YOLO
|
||
|
||
from cv_back_video import extract_box_details, cal_tricker_results
|
||
from middleware.minio_util import upload_file, downFile
|
||
from middleware.read_pic_metadata import read_dji_exif_to_dict
|
||
from mqtt_pub import MQTTClient
|
||
from ultralytics import solutions
|
||
|
||
from yolo.detect.yolo11_det_pic_trt import YoLo11TRT
|
||
|
||
# Shared MQTT client handle; pic_detect_func and pic_detect_func_trt rebind it
# (via `global mqtt_client`) to the client they create.
mqtt_client = None

# MQTT broker address and port.
# A previously used broker address was "112.44.103.230".
# NOTE(review): the original comment labelled this host a free public broker;
# it looks project-specific — confirm. The functions in this file receive the
# broker through their own parameters rather than reading these names.
broker, port = "8.137.54.85", 1883  # 1883 is the default MQTT port

# Topic for AI detection events.
topic = "thing/product/ai/events"
def _detect_yolo_init_params():
    """Build the ``YOLO(...)`` keyword arguments the installed ultralytics accepts.

    Candidates are ``task="detect"``, ``half=True`` (only when CUDA is available)
    and ``verbose=False``. Each one is kept only if ``YOLO.__init__`` declares a
    parameter with that name, so unsupported keywords never reach the constructor.

    Returns:
        dict: constructor kwargs; empty if the signature cannot be inspected.
    """
    import inspect

    try:
        accepted = inspect.signature(YOLO.__init__).parameters
    except (TypeError, ValueError) as e:
        print(f"参数检测失败,将使用默认参数: {e}")
        return {}

    candidates = {"task": "detect"}
    if torch.cuda.is_available():
        candidates["half"] = True
    candidates["verbose"] = False

    model_params = {name: value for name, value in candidates.items() if name in accepted}
    print(f"检测到支持的YOLO参数: {model_params}")
    return model_params


def _load_yolo_with_retry(m1, model_params, cls, confidence, max_retries=3):
    """Construct ``YOLO(m1, **model_params)`` and warm it up, retrying on failure.

    Each attempt runs three warm-up inferences on a blank 1080x1920 BGR frame.
    If the constructor rejects a keyword ("got an unexpected keyword argument
    'x'"), that keyword is removed from *model_params* (mutated in place) before
    the next attempt.

    Returns:
        The most recently constructed model, or None if construction never
        succeeded. As in the original inline code, a model that was constructed
        but failed warm-up on every attempt is still returned.
    """
    model = None
    for attempt in range(1, max_retries + 1):
        try:
            model = YOLO(m1, **model_params)
            dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
            for _ in range(3):
                model(dummy_frame, classes=cls, conf=confidence, show=False)
            print("YOLO模型加载成功并预热完成")
            return model
        except Exception as e:
            print(f"YOLO模型加载失败 (尝试 {attempt}/{max_retries}): {e}")
            error_text = str(e)
            if "got an unexpected keyword argument" in error_text and model_params:
                # The rejected keyword is the last single-quoted token of the message.
                parts = error_text.split("'")
                if len(parts) >= 3:
                    model_params.pop(parts[-2], None)
            if attempt < max_retries:
                time.sleep(2)
    return model


def _configure_model_device(model, device, use_fp16):
    """Best-effort: move *model* to *device*, force batch size 1, optionally use FP16.

    Failures are printed as warnings and swallowed so detection can still run
    with the model's default placement.
    """
    try:
        model.to(device)
        # Batch size 1 reduces memory use for single-image inference.
        if hasattr(model, 'args') and hasattr(model.args, 'batch'):
            model.args.batch = 1
        # Half precision only makes sense on CUDA devices.
        if use_fp16 and device.startswith('cuda') and hasattr(model, 'model'):
            try:
                model.model = model.model.half()
                print("✅ 启用半精度 FP16 模式")
            except Exception as half_err:
                print(f"⚠️ 半精度转换失败: {half_err}")
        else:
            print("ℹ️ 半精度模式未启用")
    except Exception as e:
        print(f"⚠️ 模型设备配置警告: {e}")


def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    task_id, model_func_id_list, para, s3_url_list, m1, cls, chinese_label, confidence,
                    use_fp16=False):
    """Detect objects in a list of images with YOLO + ObjectCounter and publish via MQTT.

    For each URL in *s3_url_list*, the image is processed as follows:

    1. It is downloaded with ``downFile`` and run through an
       ``ultralytics.solutions.ObjectCounter`` built on the YOLO weights *m1*.
    2. It is annotated by ``cal_tricker_results`` and written to ``output_frames/``.
    3. Both the original and the annotated image are uploaded with ``upload_file``.
    4. A JSON message is published with a new ``MQTTClient``, which is also stored
       in the module-level ``mqtt_client``.

    Args:
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: passed to ``MQTTClient``.
        task_id: copied into every published message.
        model_func_id_list: passed to ``cal_tricker_results``. Its first element is
            also passed on its own, so the list must be non-empty.
        para, chinese_label: passed through to ``cal_tricker_results``.
        s3_url_list: image URLs, processed in order.
        m1: model weights passed to ``YOLO``.
        cls: class filter for warm-up inference and the ObjectCounter.
        confidence: confidence threshold used for warm-up inference.
        use_fp16: if True and CUDA is available, cast model weights to half precision.

    Errors are printed, not raised. Images that cannot be read or processed are
    skipped, and processing stops after 5 consecutive per-image failures.
    """
    global mqtt_client

    try:
        # Performance-related environment settings.
        # NOTE(review): torch is imported at module level, so CUDA_VISIBLE_DEVICES
        # set here may not take effect if CUDA was already initialised — confirm.
        os.environ["OMP_NUM_THREADS"] = "4"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            torch.cuda.empty_cache()

        # Preload and warm up the YOLO model.
        print("预加载YOLO模型...")
        model_params = _detect_yolo_init_params()
        ov_model = _load_yolo_with_retry(m1, model_params, cls, confidence)
        if ov_model is None:
            raise Exception("无法加载YOLO模型,达到最大重试次数")

        # NOTE(review): kept from the original; confirm the YOLO wrapper reads `.conf`.
        ov_model.conf = confidence

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        _configure_model_device(ov_model, device, use_fp16)

        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)

        # ObjectCounter draws its counting region on the image by default; a
        # degenerate off-image region presumably keeps it out of the output.
        line_points = [(-10, -10), (-10, -10)]
        # TODO: the class filter should eventually map to ids from ai_model_list.
        # show=True would lock the thread once the program is restarted via the REST API.
        counter = solutions.ObjectCounter(show=False, region=line_points, model=ov_model, classes=cls)
        _configure_model_device(counter.model, device, use_fp16)

        class_names = counter.model.names if hasattr(counter.model, 'names') else {i: str(i) for i in range(1000)}

        local_func_cache = {
            "func_100000": None,
            "func_100004": None,  # cache of person track_ids
            "func_100006": None,  # cache of vehicle track_ids
        }

        max_consecutive_errors = 5
        error_count = 0
        try:
            # Without this directory cv2.imwrite returns False and later steps fail.
            os.makedirs("output_frames", exist_ok=True)

            for pic_url in s3_url_list:
                image_path = downFile(pic_url)
                print("1")
                frame = cv2.imread(image_path)

                try:
                    if frame is None:
                        # cv2.imread returns None (instead of raising) for unreadable files.
                        raise ValueError(f"无法读取图像: {image_path}")
                    frame_copy = frame.copy()
                    counter(frame)
                    annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names,
                                                                      model_func_id_list,
                                                                      local_func_cache, para, cls, chinese_label,
                                                                      model_func_id_list[0])
                except Exception as e:
                    print(f"处理帧错误: {e}")
                    error_count += 1
                    if error_count >= max_consecutive_errors:
                        print(f"连续处理错误达到5次 ,正在停止处理...")
                        break
                    # No results exist for this image; never publish stale/undefined ones.
                    continue
                # Only *consecutive* failures count toward the stop limit.
                error_count = 0

                # Photo metadata, including GPS position and flight attitude.
                image_exif_dict = read_dji_exif_to_dict(image_path)

                # Save the annotated image.
                local_path = f"output_frames/{os.path.basename(image_path)}"
                if not cv2.imwrite(local_path, annotated_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
                    print("错误: 图像保存失败")
                    continue
                print(f"图像已保存到: {local_path}")

                date_str = str(datetime.datetime.now().timestamp())

                minio_path_before, file_type_before = upload_file(image_path, None)
                minio_path, file_type = upload_file(local_path, date_str)
                os.remove(local_path)

                message = {
                    "task_id": task_id,
                    "minio": {
                        "minio_path": minio_path,
                        "file_type": file_type,
                        "original_path": minio_path_before
                    },
                    "box_detail": box_result,
                    "image_exif_dict": image_exif_dict
                }
                # Serialise with double quotes and keep Chinese characters unescaped.
                json_message = json.dumps(message, indent=4, ensure_ascii=False)
                mqtt_client.publish_message(json_message)
                time.sleep(0.1)

                # NOTE(review): upload_file's return values are removed as local files
                # when they exist — confirm that contract. The downloaded image_path
                # is not deleted here unless it equals one of them.
                if os.path.exists(minio_path):
                    os.remove(minio_path)
                if os.path.exists(minio_path_before):
                    os.remove(minio_path_before)

        except Exception as infer_err:
            print(f"推理错误: {infer_err}")

    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")
def _to_jsonable(value):
    """Return *value* with numpy arrays/scalars converted to plain Python objects.

    Lists, tuples and dicts are converted recursively. Tuples become lists, which
    ``json.dumps`` encodes identically. Any other value is returned unchanged, so
    data that is already JSON-serialisable is not altered.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    return value


def pic_detect_func_trt(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                        task_id, s3_url_list, engine_file_path,
                        PLUGIN_LIBRARY, confidence, categories):
    """Detect objects in a list of images with a TensorRT YOLO11 engine and publish via MQTT.

    For each URL in *s3_url_list*, the image is processed as follows:

    1. It is downloaded with ``downFile`` and inferred with ``YoLo11TRT``.
    2. The drawn result frame is written to ``output_frames/``.
    3. Both the original and the result image are uploaded with ``upload_file``.
    4. A JSON message with the boxes, scores and class ids is published on a new
       ``MQTTClient``, which is also stored in the module-level ``mqtt_client``.

    Args:
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: passed to ``MQTTClient``.
        task_id: copied into every published message.
        s3_url_list: image URLs, processed in order.
        engine_file_path: TensorRT engine file passed to ``YoLo11TRT``.
        PLUGIN_LIBRARY: shared library loaded with ``ctypes.CDLL`` before the engine.
        confidence, categories: passed to ``YoLo11TRT.infer``.

    Errors are printed, not raised. Images that cannot be read or saved are
    skipped. The TensorRT wrapper is always destroyed once it has been created.
    """
    global mqtt_client

    try:
        print(f"mqtt_client {mqtt_pub_ip} {mqtt_pub_port} {mqtt_pub_topic}")
        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)
        print("mqtt")

        # Load the custom TensorRT plugin library before deserialising the engine.
        ctypes.CDLL(PLUGIN_LIBRARY)
        print(f"PLUGIN_LIBRARY{PLUGIN_LIBRARY}")

        # Initialise the inference wrapper.
        yolo11_wrapper = YoLo11TRT(engine_file_path)
        print(f"engine_file_path{engine_file_path} {len(s3_url_list)}")

        try:
            # Without this directory cv2.imwrite returns False and later steps fail.
            os.makedirs("output_frames", exist_ok=True)

            for pic_url in s3_url_list:
                image_path = downFile(pic_url)
                print("1")
                frame = cv2.imread(image_path)
                print("2")
                if frame is None:
                    # cv2.imread returns None (instead of raising) for unreadable files.
                    print(f"错误: 无法读取图像 {image_path}")
                    continue

                batch_image_raw, result_box_list, result_scores_list, result_classid_list, _infer_time = \
                    yolo11_wrapper.infer([frame], confidence, categories)
                detected_frame = batch_image_raw[0]

                # Save the result image.
                local_path = f"output_frames/{os.path.basename(image_path)}"
                if not cv2.imwrite(local_path, detected_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
                    print("错误: 图像保存失败")
                    continue
                print(f"图像已保存到: {local_path}")

                date_str = str(datetime.datetime.now().timestamp())

                minio_path_before, file_type_before = upload_file(image_path, None)
                minio_path, file_type = upload_file(local_path, date_str)
                os.remove(local_path)

                message = {
                    "task_id": task_id,
                    "minio": {
                        "minio_path": minio_path,
                        "file_type": file_type,
                        "original_path": minio_path_before
                    },
                    # The inference results may contain numpy arrays, which json.dumps
                    # cannot serialise; convert them (plain lists pass through unchanged).
                    "box_detail": _to_jsonable({
                        "result_box_list": result_box_list,
                        "result_scores_list": result_scores_list,
                        "result_classid_list": result_classid_list
                    })
                }

                # Serialise with double quotes and keep Chinese characters unescaped.
                json_message = json.dumps(message, indent=4, ensure_ascii=False)
                mqtt_client.publish_message(json_message)
                time.sleep(0.1)

                # NOTE(review): upload_file's return values are removed as local files
                # when they exist — confirm that contract. The downloaded image_path
                # is not deleted here unless it equals one of them.
                if os.path.exists(minio_path):
                    os.remove(minio_path)
                if os.path.exists(minio_path_before):
                    os.remove(minio_path_before)

        except Exception as infer_err:
            print(f"推理错误: {infer_err}")

        finally:
            # Release TensorRT resources even if processing failed part-way.
            yolo11_wrapper.destroy()

    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")