300 lines
13 KiB
Python
300 lines
13 KiB
Python
|
|
import ctypes
|
|||
|
|
import datetime
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
from cv_back_video import extract_box_details, cal_tricker_results
|
|||
|
|
from middleware.minio_util import upload_file, downFile
|
|||
|
|
from middleware.read_pic_metadata import read_dji_exif_to_dict
|
|||
|
|
from mqtt_pub import MQTTClient
|
|||
|
|
from ultralytics import solutions
|
|||
|
|
|
|||
|
|
from yolo.detect.yolo11_det_pic_trt import YoLo11TRT
|
|||
|
|
|
|||
|
|
# Process-wide MQTT client handle; the detection functions below rebind it via `global`.
mqtt_client = None

# Default MQTT broker address, port and topic.
# NOTE(review): the detection functions in the code visible here take their broker
# address/port/topic as arguments, so these module defaults look unused — confirm.
# An alternative broker address used previously: 112.44.103.230.
broker, port = "8.137.54.85", 1883  # 1883 is the standard (unencrypted) MQTT port
topic = "thing/product/ai/events"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _yolo_init_kwargs():
    """Return the optional ``YOLO(...)`` constructor kwargs the installed ultralytics accepts.

    Only keys that are named parameters of ``YOLO.__init__`` are kept, so options the
    constructor does not take (e.g. ``half``) are never passed and never waste a load
    attempt. This avoids loading the whole model once just to probe it.

    Returns:
        dict: possibly empty keyword arguments for ``YOLO(m1, **kwargs)``.
    """
    import inspect

    wanted = {"task": "detect"}
    if torch.cuda.is_available():
        wanted["half"] = True
    wanted["verbose"] = False
    try:
        accepted = inspect.signature(YOLO.__init__).parameters
    except (TypeError, ValueError) as e:
        print(f"参数检测失败,将使用默认参数: {e}")
        return {}
    params = {name: value for name, value in wanted.items() if name in accepted}
    print(f"检测到支持的YOLO参数: {params}")
    return params


def _drop_rejected_kwarg(err, params):
    """Remove from *params* the keyword named in an "unexpected keyword argument 'x'" error.

    Mutates *params* in place. Does nothing when the message has another shape, so a
    differently formatted error can never raise from inside an ``except`` handler.
    """
    msg = str(err)
    if "got an unexpected keyword argument" not in msg:
        return
    parts = msg.split("'")
    # The rejected name is the last single-quoted token of the message.
    if len(parts) >= 3:
        params.pop(parts[-2], None)


def _load_warmed_yolo(m1, cls, confidence, max_retries=3, retry_delay=2.0):
    """Load the YOLO model *m1* and run three warm-up inferences on a blank frame.

    Retries up to *max_retries* times, sleeping *retry_delay* seconds between attempts
    and dropping any constructor kwarg the installed ultralytics rejects.

    Returns:
        The YOLO model, returned only when BOTH loading and warm-up succeeded.

    Raises:
        RuntimeError: if every attempt failed.
    """
    model_params = _yolo_init_kwargs()
    # Blank 1080x1920 3-channel image, used only to trigger lazy initialisation.
    dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
    for attempt in range(1, max_retries + 1):
        try:
            model = YOLO(m1, **model_params)
            for _ in range(3):
                model(dummy_frame, classes=cls, conf=confidence, show=False)
        except Exception as e:
            print(f"YOLO模型加载失败 (尝试 {attempt}/{max_retries}): {e}")
            _drop_rejected_kwarg(e, model_params)
            if attempt < max_retries:
                time.sleep(retry_delay)
        else:
            print("YOLO模型加载成功并预热完成")
            return model
    raise RuntimeError("无法加载YOLO模型,达到最大重试次数")


def _configure_yolo_device(model, device, use_fp16):
    """Best-effort: move *model* to *device*, force batch size 1, optionally cast to FP16.

    Any failure is printed as a warning and swallowed, so a configuration problem
    never stops detection.
    """
    try:
        model.to(device)
        # Batch size 1 keeps memory use low for single-image inference.
        args = getattr(model, "args", None)
        if args is not None and hasattr(args, "batch"):
            args.batch = 1
        if use_fp16 and device.startswith("cuda") and hasattr(model, "model"):
            try:
                model.model = model.model.half()
                print("✅ 启用半精度 FP16 模式")
            except Exception as half_err:
                print(f"⚠️ 半精度转换失败: {half_err}")
        else:
            print("ℹ️ 半精度模式未启用")
    except Exception as e:
        print(f"⚠️ 模型设备配置警告: {e}")


def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    task_id, model_func_id_list, para, s3_url_list, m1, cls, chinese_label, confidence,
                    use_fp16=False):
    """Run YOLO + ObjectCounter detection over a list of images and publish results via MQTT.

    For every URL in *s3_url_list*:
      1. download the image with ``downFile``;
      2. read its DJI EXIF metadata;
      3. run it through an ultralytics ``solutions.ObjectCounter`` and ``cal_tricker_results``;
      4. save the annotated image under ``output_frames/``;
      5. upload the original and the annotated image;
      6. publish a JSON message through an ``MQTTClient``.

    Args:
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: passed to ``MQTTClient``.
        task_id: echoed into every published message.
        model_func_id_list: forwarded to ``cal_tricker_results``; its first element is
            also passed separately.
        para, chinese_label: forwarded to ``cal_tricker_results``.
        s3_url_list: iterable of image URLs accepted by ``downFile``.
        m1: model path/identifier given to ``YOLO``.
        cls: class ids to detect, forwarded to YOLO and to ObjectCounter.
        confidence: detection confidence threshold.
        use_fp16: cast the model to half precision when running on CUDA.

    Errors are printed, never raised. A failing image is skipped and counted; processing
    stops after 5 consecutive failures (the counter resets after each successful image).
    """
    global mqtt_client

    try:
        # Environment tuning for inference.
        # NOTE(review): torch is imported at module load, so CUDA_VISIBLE_DEVICES only
        # takes effect if CUDA has not been initialised in this process yet — confirm.
        os.environ["OMP_NUM_THREADS"] = "4"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        cuda_ok = torch.cuda.is_available()
        print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {cuda_ok}")
        if cuda_ok:
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            torch.cuda.empty_cache()

        print("预加载YOLO模型...")
        ov_model = _load_warmed_yolo(m1, cls, confidence)
        ov_model.conf = confidence

        device = "cuda:0" if cuda_ok else "cpu"
        _configure_yolo_device(ov_model, device, use_fp16)

        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)

        # ObjectCounter draws its counting region on the image by default; placing the
        # region off-image keeps it out of the output.
        # TODO: the detected classes should eventually be tied to the ai_model_list ids.
        # show must stay False: with show=True, restarting the program through the REST
        # interface locked up the thread.
        line_points = [(-10, -10), (-10, -10)]
        counter = solutions.ObjectCounter(show=False, region=line_points, model=ov_model, classes=cls)

        counter_model = getattr(counter, "model", None)
        # Only configure the counter's model if it is not the object configured above.
        if counter_model is not None and counter_model is not ov_model:
            _configure_yolo_device(counter_model, device, use_fp16)

        if counter_model is not None and hasattr(counter_model, "names"):
            class_names = counter_model.names
        else:
            class_names = {i: str(i) for i in range(1000)}

        # State carried across images by cal_tricker_results:
        # func_100004 caches person track_ids, func_100006 caches vehicle track_ids.
        local_func_cache = {
            "func_100000": None,
            "func_100004": None,
            "func_100006": None,
        }

        os.makedirs("output_frames", exist_ok=True)
        max_consecutive_errors = 5
        error_count = 0  # consecutive failures; reset after every successful image

        for pic_url in s3_url_list:
            try:
                image_path = downFile(pic_url)
                print("1")
                frame = cv2.imread(image_path)
                if frame is None:
                    raise ValueError(f"cv2.imread failed to read {image_path}")

                # Photo metadata, including GPS position and flight attitude.
                image_exif_dict = read_dji_exif_to_dict(image_path)

                # Keep an untouched copy for annotation; presumably the counter may draw
                # on `frame` — TODO confirm.
                frame_copy = frame.copy()
                # Return value unused; presumably updates the counter state that
                # cal_tricker_results reads — TODO confirm.
                counter(frame)
                annotated_frame, box_result = cal_tricker_results(
                    frame_copy, counter, class_names, model_func_id_list,
                    local_func_cache, para, cls, chinese_label, model_func_id_list[0])

                local_path = f"output_frames/{os.path.basename(image_path)}"
                if not cv2.imwrite(local_path, annotated_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
                    print("错误: 图像保存失败")
                    raise OSError(f"cv2.imwrite failed for {local_path}")
                print(f"图像已保存到: {local_path}")

                try:
                    date_str = str(datetime.datetime.now().timestamp())
                    minio_path_before, _ = upload_file(image_path, None)
                    minio_path, file_type = upload_file(local_path, date_str)
                finally:
                    # Do not leave annotated frames on disk, even when an upload fails.
                    if os.path.exists(local_path):
                        os.remove(local_path)

                message = {
                    "task_id": task_id,
                    "minio": {
                        "minio_path": minio_path,
                        "file_type": file_type,
                        "original_path": minio_path_before,
                    },
                    "box_detail": box_result,
                    "image_exif_dict": image_exif_dict,
                }
                # Double-quoted JSON, keeping Chinese characters unescaped.
                json_message = json.dumps(message, indent=4, ensure_ascii=False)
                mqtt_client.publish_message(json_message)
                time.sleep(0.1)

                # Remove local copies, if any, at the paths returned by upload_file.
                for leftover in (minio_path, minio_path_before):
                    if leftover and os.path.exists(leftover):
                        os.remove(leftover)
            except Exception as e:
                error_count += 1
                print(f"处理帧错误: {e}")
                if error_count >= max_consecutive_errors:
                    print("连续处理错误达到5次 ,正在停止处理...")
                    break
            else:
                error_count = 0
    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def pic_detect_func_trt(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                        task_id, s3_url_list, engine_file_path,
                        PLUGIN_LIBRARY, confidence, categories):
    """Run TensorRT YOLO11 detection over a list of images and publish results via MQTT.

    Setup: loads *PLUGIN_LIBRARY* with ``ctypes`` and builds a ``YoLo11TRT`` inferencer
    from *engine_file_path*.

    For every URL in *s3_url_list*:
      1. download the image;
      2. run inference;
      3. save the annotated frame under ``output_frames/``;
      4. upload the original and the annotated image;
      5. publish a JSON message carrying the raw boxes, scores and class ids.

    EXIF metadata is not read or published here.

    Args:
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: passed to ``MQTTClient``.
        task_id: echoed into every published message.
        s3_url_list: iterable of image URLs accepted by ``downFile``.
        engine_file_path: serialized engine given to ``YoLo11TRT``.
        PLUGIN_LIBRARY: shared library loaded before the inferencer is created.
        confidence, categories: forwarded to ``YoLo11TRT.infer``.

    Errors are printed, never raised. A failing image is skipped and counted; processing
    stops after 5 consecutive failures. The inferencer is always destroyed on exit.
    """
    global mqtt_client

    def _json_default(obj):
        # NOTE(review): infer() results may be numpy arrays or numpy scalars — confirm.
        # json.dumps cannot serialise those natively; convert them to plain Python.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    try:
        print(f"mqtt_client {mqtt_pub_ip} {mqtt_pub_port} {mqtt_pub_topic}")
        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)
        print("mqtt")

        # Load the plugin library before building the inferencer; presumably it registers
        # custom layers the engine needs — confirm.
        ctypes.CDLL(PLUGIN_LIBRARY)
        print(f"PLUGIN_LIBRARY{PLUGIN_LIBRARY}")
        yolo11_wrapper = YoLo11TRT(engine_file_path)
        print(f"engine_file_path{engine_file_path} {len(s3_url_list)}")

        try:
            os.makedirs("output_frames", exist_ok=True)
            max_consecutive_errors = 5
            error_count = 0  # consecutive failures; reset after every successful image

            for pic_url in s3_url_list:
                try:
                    image_path = downFile(pic_url)
                    print("1")
                    frame = cv2.imread(image_path)
                    print("2")
                    if frame is None:
                        raise ValueError(f"cv2.imread failed to read {image_path}")

                    (batch_image_raw, result_box_list, result_scores_list,
                     result_classid_list, _infer_time) = yolo11_wrapper.infer(
                        [frame], confidence, categories)
                    detected_frame = batch_image_raw[0]

                    local_path = f"output_frames/{os.path.basename(image_path)}"
                    if not cv2.imwrite(local_path, detected_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
                        print("错误: 图像保存失败")
                        raise OSError(f"cv2.imwrite failed for {local_path}")
                    print(f"图像已保存到: {local_path}")

                    try:
                        date_str = str(datetime.datetime.now().timestamp())
                        minio_path_before, _ = upload_file(image_path, None)
                        minio_path, file_type = upload_file(local_path, date_str)
                    finally:
                        # Do not leave annotated frames on disk, even when an upload fails.
                        if os.path.exists(local_path):
                            os.remove(local_path)

                    message = {
                        "task_id": task_id,
                        "minio": {
                            "minio_path": minio_path,
                            "file_type": file_type,
                            "original_path": minio_path_before,
                        },
                        "box_detail": {
                            "result_box_list": result_box_list,
                            "result_scores_list": result_scores_list,
                            "result_classid_list": result_classid_list,
                        },
                    }
                    json_message = json.dumps(message, indent=4, ensure_ascii=False,
                                              default=_json_default)
                    mqtt_client.publish_message(json_message)
                    time.sleep(0.1)

                    # Remove local copies, if any, at the paths returned by upload_file.
                    for leftover in (minio_path, minio_path_before):
                        if leftover and os.path.exists(leftover):
                            os.remove(leftover)
                except Exception as infer_err:
                    error_count += 1
                    print(f"推理错误: {infer_err}")
                    if error_count >= max_consecutive_errors:
                        print("连续处理错误达到5次 ,正在停止处理...")
                        break
                else:
                    error_count = 0
        finally:
            yolo11_wrapper.destroy()
    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")
|