1536 lines
60 KiB
Python
1536 lines
60 KiB
Python
import asyncio
|
||
import json
|
||
|
||
import cv2
|
||
import subprocess
|
||
from threading import Thread, Lock, Event
|
||
import time
|
||
import queue
|
||
import numpy as np
|
||
|
||
import datetime
|
||
import os
|
||
from ultralytics import YOLO # 导入 Ultralytics YOLO 模型
|
||
|
||
from func_aimodellist.func_1x import func_100004, func_100000, func_100021
|
||
from middleware.minio_util import upload_file
|
||
# from Ai_tottle.minio_oss import upload_file
|
||
from mqtt_pub import MQTTClient
|
||
|
||
from ultralytics import solutions
|
||
import torch
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
from yolo.yolo_multi_model import YOLOModel, convert_result_to_multi_result, multi_model_process_frame
|
||
import av
|
||
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Module-wide state shared by the reader / processing / writer threads.
# ---------------------------------------------------------------------------

# AI-processing on/off flag; written under deskLock by setIfAI().
ifAI = dict(status=False)
deskLock = Lock()

# Raw frames waiting for inference (bounded, larger buffer).
frame_queue = queue.Queue(maxsize=60)
# Annotated frames waiting to be written out.
processed_frame_queue = queue.Queue(maxsize=30)

# Cache of track_ids, passed to the event functions to support event decisions.
list_track_id = []

# Set to ask every worker thread to stop.
stop_event = Event()

mqtt_client = None

# MQTT broker address and port (previously used broker: "112.44.103.230").
broker, port = "8.137.54.85", 1883
# Topic used for AI events.
topic = "thing/product/ai/events"

# Most recent DRC message received over MQTT; read-and-cleared by get_local_drc_message().
local_drc_message = None

# Class-index -> display label.
# NOTE(review): several entries look like placeholders — confirm against the model's classes.
CHINESE_LABELS = dict(enumerate(["人", "车", "车", "猫", "购"] + ["猫"] * 6))

fps = 30

# PyAV RTMP output state (container opened with format="flv").
# The stream is created lazily because its width/height are only known once
# the first frame has been read.
stream = container = None
first_frame_received = False
|
||
|
||
|
||
# Chinese labels need a font that contains CJK glyphs. The default font_path
# "config/SIMSUN.TTC" is relative to the working directory; on Linux the font
# file must be provided separately (an alternative is "simhei.ttf").

# Loaded PIL fonts keyed by (font_path, font_size). put_chinese_text() is called
# once per drawn label per frame, so reloading the font file each time is wasteful.
_FONT_CACHE = {}


def _get_font(font_path, font_size):
    """Return a cached PIL font, falling back to PIL's default font if *font_path* cannot be loaded."""
    key = (font_path, font_size)
    font = _FONT_CACHE.get(key)
    if font is None:
        try:
            font = ImageFont.truetype(font_path, font_size)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable; ImportError: PIL built without FreeType.
            font = ImageFont.load_default()
        _FONT_CACHE[key] = font
    return font


def put_chinese_text(img, text, position, font_path="config/SIMSUN.TTC", font_size=20, color=(0, 255, 0)):
    """Draw *text* (which may contain Chinese) onto a BGR image using PIL.

    :param img: BGR image; it is not modified.
    :param text: text to draw.
    :param position: top-left (x, y) of the text, in PIL coordinates.
    :param font_path: TrueType/TTC font file; PIL's default font is used if it cannot be loaded.
    :param font_size: font size passed to PIL.
    :param color: fill colour, interpreted in RGB order because drawing happens on the RGB conversion.
    :return: a new BGR image with the text drawn.
    """
    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    draw.text(position, text, font=_get_font(font_path, font_size), fill=color)
    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
||
|
||
|
||
# Statistics / event evaluation over tracker results.

# Event functions in evaluation order (the order results are appended to box_result).
# Handler kinds:
#   "detect_task" - func_100000 with the caller's func_id as an extra argument
#   "detect"      - func_100000
#   "track"       - func_100004 with a per-function cache in local_func_cache["func_<id>"]
#   "gather"      - func_100021 with the threshold para["N"] (people / vehicle gathering)
_EVENT_PIPELINE = (
    (100000, "detect_task"),
    (100002, "detect"),
    (100004, "track"),
    (100006, "track"),
    (100021, "gather"),
    (100023, "gather"),
    (100031, "detect"),  # bridge-defect classes per original notes
    (100041, "detect"),  # smoke
    (100061, "detect"),  # infrared person / vehicle
    (100081, "detect"),  # infrared PV hot-spot classes
    (100008, "detect"),
    (100052, "detect"),
    (100091, "detect"),
    (100092, "detect"),
)


def _draw_detections(frame, results, chinese_label):
    """Return a copy of *frame* with a box and a "<label> <conf>" tag drawn for each detection."""
    annotated_frame = frame.copy()
    for box, class_id, conf in zip(results.boxes, results.clss, results.confs):
        if isinstance(box, (torch.Tensor, np.ndarray)):
            x1, y1, x2, y2 = map(int, box[:4])
        else:
            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        label = chinese_label.get(class_id, str(class_id))
        text = f"{label} {conf:.2f}"

        # Bounding box.
        cv2.rectangle(annotated_frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

        # Text size comes from cv2.getTextSize (Hershey font); it only approximates
        # the size of the PIL-rendered label drawn below.
        (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Place the label just above the box; below it if that would leave the image.
        text_x = x1
        text_y = y1 - 5
        if text_y < 0:
            text_y = y2 + text_height + 5

        # Black background rectangle for readability, then the label via PIL (CJK support).
        padding = 2
        cv2.rectangle(
            annotated_frame,
            (int(text_x) - padding, int(text_y) - text_height - padding),
            (int(text_x) + text_width + padding, int(text_y) + padding),
            (0, 0, 0),
            -1,
        )
        # put_chinese_text returns a new image and does not modify its input.
        annotated_frame = put_chinese_text(
            annotated_frame,
            text,
            (text_x, text_y - text_height),  # PIL positions text by its top-left corner
            font_size=15,
            color=(0, 255, 0),
        )
    return annotated_frame


def _evaluate_events(results, func_id, list_func_id, local_func_cache, para, model_cls, chinese_label):
    """Run every enabled event function in _EVENT_PIPELINE order and collect the non-None results."""
    box_result = []
    for event_id, kind in _EVENT_PIPELINE:
        if event_id not in list_func_id:
            continue
        if kind == "track":
            cache_key = f"func_{event_id}"
            result, new_cache = func_100004(results, model_cls, chinese_label, event_id, list_track_id,
                                            local_func_cache[cache_key])
            if result is not None:
                local_func_cache[cache_key] = new_cache  # persist the updated track cache
        elif kind == "gather":
            # para["N"] is the gathering threshold supplied through the REST interface.
            result = func_100021(results, model_cls, chinese_label, event_id, para["N"])
        elif kind == "detect_task":
            result = func_100000(results, model_cls, chinese_label, event_id, list_track_id, func_id)
        else:
            result = func_100000(results, model_cls, chinese_label, event_id, list_track_id)
        if result is not None:
            box_result.append(result)
    return box_result


def cal_tricker_results(frame, results, class_names,
                        func_id, local_func_cache, para, model_cls,
                        chinese_label, int_func_id=None):
    """Draw tracker detections on *frame* and evaluate the enabled event functions.

    :param frame: BGR image the detections belong to; it is not modified.
    :param results: object exposing parallel ``boxes`` (x1, y1, x2, y2 per row),
        ``clss`` and ``confs`` sequences (e.g. an ultralytics solutions counter).
    :param class_names: unused; kept for interface compatibility.
    :param func_id: forwarded as the last argument of func_100000 for event 100000.
        When *int_func_id* is omitted it also selects the enabled events: either a
        single id or a list/tuple/set of ids (how the processing loops call this).
    :param local_func_cache: per-function caches; ``"func_100004"`` / ``"func_100006"``
        are read and updated by the tracking events.
    :param para: REST parameters; ``para["N"]`` is read by events 100021 / 100023.
    :param model_cls: model classes, forwarded to the event functions.
    :param chinese_label: class-id -> display label mapping.
    :param int_func_id: single enabled event id; takes precedence over *func_id*.
    :return: ``(annotated_frame, box_result)`` where box_result lists the non-None
        event results in _EVENT_PIPELINE order.
    """
    if int_func_id is not None:
        list_func_id = [int_func_id]
    elif isinstance(func_id, (list, tuple, set)):
        list_func_id = list(func_id)  # copy: the alias below must not grow the caller's list
    else:
        list_func_id = [func_id]

    # 100014 is evaluated through the 100000 handler (hard-coded for now).
    if 100014 in list_func_id:
        list_func_id.append(100000)

    annotated_frame = _draw_detections(frame, results, chinese_label)
    box_result = _evaluate_events(results, func_id, list_func_id, local_func_cache, para, model_cls,
                                  chinese_label)
    return annotated_frame, box_result
|
||
|
||
|
||
# Summarises detection results per class so the front end can display counts.

# Display names for the person/vehicle model's class indices.
_HUMAN_PT_NAMES = {0: "行人", 1: "人", 2: "自行车", 3: "车", 4: "面包车", 5: "货车", 6: "三轮车", 7: "带篷三轮车",
                   8: "巴士", 9: "摩托车"}


def extract_box_details(last_results, model_id):
    """Count detections per class in a YOLO result.

    :param last_results: YOLO result whose ``boxes`` each expose ``cls[0]`` (the class index).
    :param model_id: model id echoed back in the output.
    :return: ``{"model_id": model_id, "box_count": [{"type", "type_name", "count"}, ...]}``
        with one entry per class, in order of first appearance. ``type_name`` is looked up
        in the person/vehicle table for every model (TODO: make it model-specific) and is
        None for unknown indices.
    """
    from collections import Counter

    counts = Counter(int(box.cls[0]) for box in last_results.boxes)
    box_count = [
        {"type": cls_index, "type_name": _HUMAN_PT_NAMES.get(cls_index), "count": count}
        for cls_index, count in counts.items()
    ]
    return {"model_id": model_id, "box_count": box_count}
|
||
|
||
|
||
def setIfAI(pb1):
    """Set the AI-processing flag (``ifAI['status']``) to *pb1* while holding ``deskLock``."""
    # The context manager guarantees the lock is released even if the assignment raises.
    with deskLock:
        ifAI['status'] = pb1
|
||
|
||
|
||
def getIfAI():
    """Return the current AI-processing flag (read without taking ``deskLock``)."""
    current_status = ifAI['status']
    return current_status
|
||
|
||
|
||
def _terminate_ffmpeg_children():
    """Best effort: send SIGTERM to every descendant process whose name contains "ffmpeg".

    Failures are printed rather than raised; this is a last resort after a stop timeout.
    """
    import signal

    try:
        import psutil
    except ImportError as e:
        print(f"无法强制终止子进程(psutil 不可用): {e}")
        return

    try:
        children = psutil.Process(os.getpid()).children(recursive=True)
    except psutil.Error as e:
        print(f"无法枚举子进程: {e}")
        return

    for child in children:
        try:
            child_name = child.name().lower()
            if 'ffmpeg' in child_name:
                print(f"正在终止子进程: {child.pid} ({child_name})")
                child.send_signal(signal.SIGTERM)
        except psutil.Error:
            # The child exited meanwhile or cannot be signalled; nothing more to do for it.
            continue


def stopAIVideo():
    """Ask the running AI video pipeline to stop and wait briefly for it.

    Clears the AI flag, sets ``stop_event`` and then polls ``stop_event`` every
    0.5 s, at most ``max_wait`` (5) times, i.e. about 2.5 s. If it is still set
    afterwards, a timeout warning is printed and ffmpeg child processes are sent SIGTERM.

    NOTE(review): nothing visible in this file clears ``stop_event`` during the wait
    (startBackAIVideo clears it only after this returns), so unless a worker clears
    it on exit the wait always runs to the timeout — confirm against the workers.
    """
    print("正在停止AI视频处理...")
    setIfAI(False)
    stop_event.set()

    wait_count = 0
    max_wait = 5  # number of 0.5 s polls (~2.5 s in total)
    while stop_event.is_set() and wait_count < max_wait:
        time.sleep(0.5)
        wait_count += 1

    if wait_count < max_wait:
        print("AI视频处理已停止")
        return

    print("警告: 停止AI视频处理超时,强制终止")
    _terminate_ffmpeg_children()
|
||
|
||
|
||
def startBackAIVideo(task_id, video_path, push_url, m1, model_cls, chinese_label, list_func_id, confidence, para,
                     mqtt_ip, mqtt_port, mqtt_topic):
    """Start (or restart) the background AI video pipeline.

    If a pipeline is flagged as running it is stopped first, followed by a 1 s
    pause. ``stop_event`` is then cleared and ``startBackAIVideo2`` is launched in
    a daemon thread, with the arguments re-ordered to that function's signature.
    """
    if ifAI['status']:
        stopAIVideo()
        time.sleep(1)
    stop_event.clear()

    worker_args = (
        task_id, video_path, push_url, m1, list_func_id, confidence, para,
        mqtt_ip, mqtt_port, mqtt_topic,
        model_cls, chinese_label,
    )
    # Daemon thread: it must not keep the interpreter alive on shutdown.
    Thread(target=startBackAIVideo2, args=worker_args, daemon=True).start()
|
||
|
||
|
||
def _open_push_container(push_url):
    """Open the FLV (RTMP) output for *push_url* as the shared ``container``.

    A new container has no video stream yet, so ``stream`` is cleared and
    ``first_frame_received`` reset; the next frame read adds a stream sized to
    that frame. Exceptions from ``av.open`` propagate to the caller.
    """
    global first_frame_received, stream, container
    container = av.open(push_url, mode="w", format="flv")
    stream = None
    first_frame_received = False
    print(f"成功打开RTMP推流: {push_url}")


def _add_push_stream(width, height):
    """Add the libx264 video stream (30 fps, yuv420p, zero-latency) to the open ``container``."""
    global first_frame_received, stream
    stream = container.add_stream("libx264", rate=30)
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"
    stream.options = {
        "tune": "zerolatency",
        "preset": "ultrafast",
        "crf": "23",
    }
    first_frame_received = True
    print(f"Initialized stream with resolution: {width}x{height}")


def _close_push_container(flush=True):
    """Close the shared output container (if open) and reset the push state.

    With *flush*, packets still buffered in the encoder are muxed first. Errors
    are printed, not raised: this runs on shutdown or after a broken pipe, when
    the remote end may already be gone.
    """
    global first_frame_received, stream, container
    if container is not None:
        try:
            if flush and stream is not None:
                for packet in stream.encode():  # encode() with no frame drains the encoder
                    container.mux(packet)
        except Exception as e:
            print(f"刷新RTMP编码器时出错: {e}")
        try:
            container.close()
        except Exception as e:
            print(f"关闭RTMP推流时出错: {e}")
    container = None
    stream = None
    first_frame_received = False


def read_frames(cap, push_url, frame_queue):
    """Frame-reader thread: pull frames from *cap*, queue them and re-publish them over RTMP.

    Each successfully read frame is put on *frame_queue* as ``(frame, read_time)``
    and is also encoded and muxed to *push_url*. When the queue is more than 80 %
    full it is emptied to avoid accumulating latency.

    The loop ends when ``stop_event`` is set, after ``max_retries`` broken-pipe
    reconnects, after ``max_connection_errors`` consecutive failed reads or
    ``max_read_wait`` seconds without a good frame (both also set ``stop_event``),
    if the output cannot be opened (sets ``stop_event``), or on any other error.
    On exit the output is flushed and closed and the shared push state reset, so
    the next run opens a fresh container for its own *push_url*.

    :param cap: capture source exposing ``grab()`` and ``read()`` (e.g. ``cv2.VideoCapture``).
    :param push_url: RTMP URL the frames are re-published to.
    :param frame_queue: bounded queue consumed by the processing thread.
    """
    # Read pacing: at most one frame per 1/60 s.
    target_time_per_frame = 1.0 / 60.0
    last_time = time.time()

    # Disconnect detection.
    connection_error_count = 0
    max_connection_errors = 10  # consecutive failed reads tolerated
    last_successful_read = time.time()
    max_read_wait = 60.0  # seconds without a good frame before giving up

    # Discard a few buffered frames so processing starts from fresh ones.
    for _ in range(5):
        cap.grab()

    retry_count = 0
    max_retries = 3

    try:
        while not stop_event.is_set() and retry_count < max_retries:
            try:
                if container is None:
                    try:
                        _open_push_container(push_url)
                    except Exception as e:
                        print(f"无法打开RTMP推流地址: {push_url}, 错误: {e}")
                        stop_event.set()
                        return

                current_time = time.time()
                if current_time - last_successful_read > max_read_wait:
                    print(f"超时警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开")
                    stop_event.set()
                    break

                elapsed_time = current_time - last_time
                if elapsed_time < target_time_per_frame:
                    time.sleep(target_time_per_frame - elapsed_time)
                    continue

                # Consumer is falling behind: skip a frame and drop the backlog.
                if frame_queue.qsize() > frame_queue.maxsize * 0.8:
                    cap.grab()
                    last_time = time.time()
                    while not frame_queue.empty():
                        try:
                            frame_queue.get_nowait()
                        except queue.Empty:
                            break  # another consumer emptied it first
                    continue

                ret, frame = cap.read()
                if not ret:
                    print("拉流错误:无法读取帧")
                    connection_error_count += 1
                    if connection_error_count >= max_connection_errors:
                        print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...")
                        stop_event.set()
                        break
                    time.sleep(0.5)
                    continue

                connection_error_count = 0
                last_successful_read = time.time()
                last_time = time.time()
                frame_queue.put((frame, time.time()))

                # The stream is sized from the first frame of each container.
                if not first_frame_received:
                    height, width = frame.shape[:2]
                    _add_push_stream(width, height)

                av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
                for packet in stream.encode(av_frame):
                    container.mux(packet)

            except av.error.BrokenPipeError:
                print("RTMP 连接断开,尝试重连...")
                retry_count += 1
                # The pipe is gone: close without flushing. The next iteration reopens the
                # container and re-adds the stream (the old stream belongs to the closed one).
                _close_push_container(flush=False)
                time.sleep(2)
            except Exception as e:
                print(f"读取帧时发生错误: {e}")
                break
    finally:
        _close_push_container()
|
||
|
||
|
||
|
||
|
||
|
||
def _pick_resize_scale(qsize, maxsize):
    """Return the inference down-scale factor for the current input backlog (smaller when fuller)."""
    if qsize > maxsize * 0.7:
        return 0.4
    if qsize > maxsize * 0.5:
        return 0.6
    if qsize > maxsize * 0.3:
        return 0.8
    return 1.0


def _scaled(frame, scale):
    """Return *frame* resized by *scale*, or *frame* itself when no down-scaling is needed."""
    return cv2.resize(frame, (0, 0), fx=scale, fy=scale) if scale < 1.0 else frame


def _annotate_tracker_frame(frame, work_frame, counter, class_names, list_func_id, local_func_cache, para,
                            model_cls, chinese_label):
    """Annotate *work_frame* from *counter*'s tracker state and return the result at *frame*'s size.

    The tracker's boxes are in the coordinates of the image the model ran on, so
    drawing happens on a frame of that size and the result is scaled back up.
    """
    annotated_frame, box_result = cal_tricker_results(work_frame, counter, class_names, list_func_id,
                                                      local_func_cache, para, model_cls, chinese_label)
    if annotated_frame.shape[:2] != frame.shape[:2]:
        annotated_frame = cv2.resize(annotated_frame, (frame.shape[1], frame.shape[0]))
    return annotated_frame, box_result


def process_frames_tricker(push_url, frame_queue, processed_frame_queue, model_path, list_func_id, confidence, para,
                           model_cls, chinese_label, use_fp16=False):
    """Processing thread: track/count objects in queued frames and emit annotated frames.

    Frames older than 0.3 s are dropped. Inference input is down-scaled when
    *frame_queue* backs up, and after a heavily backlogged frame the next
    ``max_skip`` frames reuse the last tracker state instead of running the model.
    Every handled frame is put once on *processed_frame_queue* as
    ``(annotated_frame, timestamp, box_result)``. The loop stops when
    ``stop_event`` is set or after 5 consecutive unexpected errors.

    NOTE(review): *confidence* is not passed to the counter, so it has no effect here.
    """
    global container
    # NOTE(review): this thread never writes to the container; it only opens it early
    # so the reader thread can reuse it.
    if container is None:
        container = av.open(push_url, mode="w", format="flv")

    # ObjectCounter draws its region on the frame, so the region is placed off-screen.
    line_points = [(-10, -10), (-10, -10)]
    # show=True makes the thread hang when the pipeline is restarted via the REST API.
    counter = solutions.ObjectCounter(show=False, region=line_points, model=model_path, classes=model_cls)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    try:
        counter.model.to(device)
        if hasattr(counter.model, 'args') and hasattr(counter.model.args, 'batch'):
            counter.model.args.batch = 1
        # Optional half precision (CUDA only).
        if use_fp16 and device.startswith('cuda') and hasattr(counter.model, 'model'):
            try:
                counter.model.model = counter.model.model.half()
                print("✅ 启用半精度 FP16 模式")
            except Exception as half_err:
                print(f"⚠️ 半精度转换失败: {half_err}")
        else:
            print("ℹ️ 半精度模式未启用")
    except Exception as e:
        print(f"⚠️ 模型设备配置警告: {e}")

    class_names = counter.model.names if hasattr(counter.model, 'names') else {i: str(i) for i in range(1000)}

    local_func_cache = {
        "func_100000": None,
        "func_100004": None,  # cached person track_ids
        "func_100006": None,  # cached vehicle track_ids
    }

    frame_count = 0
    process_times = []
    last_counter = None
    last_scale = 1.0  # scale of the frame the tracker state in `counter` was computed on
    skip_counter = 0
    max_skip = 2
    error_count = 0

    while not stop_event.is_set():
        try:
            if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8:
                time.sleep(0.01)
                continue

            frame, timestamp = frame_queue.get(timeout=0.2)
            if time.time() - timestamp > 0.3:
                continue  # too stale to be worth processing
            frame_count += 1

            if skip_counter > 0 and last_counter is not None:
                # Backlog: reuse the last tracker state instead of running the model.
                skip_counter -= 1
                annotated_frame, box_result = _annotate_tracker_frame(
                    frame, _scaled(frame, last_scale), counter, class_names, list_func_id,
                    local_func_cache, para, model_cls, chinese_label)
                processed_frame_queue.put((annotated_frame, timestamp, box_result))
                continue

            process_start = time.time()
            qsize = frame_queue.qsize()
            maxsize = frame_queue.maxsize
            resize_scale = _pick_resize_scale(qsize, maxsize)
            process_frame = _scaled(frame, resize_scale)

            if process_frame is None or process_frame.size == 0:
                print("警告: 输入帧无效")
                continue

            try:
                results = counter(process_frame)
                if results is None:
                    # The frame is dropped.
                    print("警告: 检测结果为None,使用上一帧结果")
                    continue
                last_counter = counter  # tracker results live on the counter object
                last_scale = resize_scale
                annotated_frame, box_result = _annotate_tracker_frame(
                    frame, process_frame, counter, class_names, list_func_id,
                    local_func_cache, para, model_cls, chinese_label)
                if qsize > maxsize * 0.5:
                    skip_counter = max_skip
            except Exception as infer_err:
                print(f"推理错误1: {infer_err}")
                if last_counter is not None:
                    annotated_frame, box_result = _annotate_tracker_frame(
                        frame, _scaled(frame, last_scale), counter, class_names, list_func_id,
                        local_func_cache, para, model_cls, chinese_label)
                else:
                    annotated_frame, box_result = frame.copy(), []

            process_times.append(time.time() - process_start)
            if len(process_times) > 30:
                process_times.pop(0)

            if frame_count % 30 == 0:
                avg_process_time = sum(process_times) / len(process_times)
                model_fps = 1.0 / avg_process_time if avg_process_time > 0 else 0
                print(
                    f"模型处理FPS: {model_fps:.2f}, 平均处理时间: {avg_process_time * 1000:.2f}ms, 队列大小: {qsize}, 缩放比例: {resize_scale:.2f}")

            processed_frame_queue.put((annotated_frame, timestamp, box_result))
            error_count = 0

        except queue.Empty:
            continue
        except Exception as e:
            print(f"处理帧错误: {e}")
            error_count += 1
            if error_count >= 5:
                print(f"连续处理错误达到5次 ,正在停止处理...")
                break
|
||
|
||
|
||
def multi_model_process_frames(frame_queue, processed_frame_queue, model_path, list_func_id, confidence, para,
                               model_cls, chinese_label, use_fp16=False):
    """Processing thread that runs several YOLO models on each frame (experimental).

    Each fresh frame (not older than 0.3 s) is passed to ``multi_model_process_frame``.
    The results stored under ``id(frame)`` are converted with
    ``convert_result_to_multi_result`` and evaluated/drawn by ``cal_tricker_results``.
    ``(annotated_frame, timestamp, box_result)`` is then put on *processed_frame_queue*.
    On an inference error the unannotated frame is emitted with an empty result. The
    loop stops when ``stop_event`` is set or after 5 consecutive unexpected errors.

    NOTE(review): *model_path*, *confidence* and *use_fp16* are currently unused;
    both models load the hard-coded ``'pt/best.pt'``.
    NOTE(review): the model workers started here are never stopped when the loop exits.
    """
    models = [
        YOLOModel('pt/best.pt', cls_map={'pedestrian': '行人', 'people': '人'},
                  allowed_classes=['pedestrian', 'people']),
        YOLOModel('pt/best.pt', cls_map={'car': '车'}, allowed_classes=['car']),
    ]

    # Results keyed by id(frame), filled in by multi_model_process_frame.
    result_dict = {}

    # Start every model's worker.
    for model in models:
        model.start()

    local_func_cache = {
        "func_100000": None,
        "func_100004": None,  # cached person track_ids
        "func_100006": None,  # cached vehicle track_ids
    }

    frame_count = 0
    process_times = []
    error_count = 0

    while not stop_event.is_set():
        try:
            if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8:
                time.sleep(0.01)
                continue

            frame, timestamp = frame_queue.get(timeout=0.2)
            if time.time() - timestamp > 0.3:
                continue  # too stale to be worth processing
            frame_count += 1
            process_start = time.time()

            try:
                # Run all models on the frame.
                multi_model_process_frame(frame, models, result_dict)
                results = result_dict.pop(id(frame), None)
                if results is None:
                    continue  # no merged result for this frame
                multi_result = convert_result_to_multi_result(results)
                annotated_frame, box_result = cal_tricker_results(frame, multi_result, None, list_func_id,
                                                                  local_func_cache, para, model_cls, chinese_label)
            except Exception as infer_err:
                print(f"推理错误1: {infer_err}")
                annotated_frame, box_result = frame.copy(), []

            process_times.append(time.time() - process_start)
            if len(process_times) > 30:
                process_times.pop(0)

            if frame_count % 30 == 0:
                avg_process_time = sum(process_times) / len(process_times)
                model_fps = 1.0 / avg_process_time if avg_process_time > 0 else 0
                print(
                    f"模型处理FPS: {model_fps:.2f}, 平均处理时间: {avg_process_time * 1000:.2f}ms, 队列大小: {frame_queue.qsize()}")

            processed_frame_queue.put((annotated_frame, timestamp, box_result))
            error_count = 0

        except queue.Empty:
            continue
        except Exception as e:
            print(f"处理帧错误: {e}")
            error_count += 1
            if error_count >= 5:
                print(f"连续处理错误达到5次 ,正在停止处理...")
                break
|
||
|
||
|
||
#
|
||
# def process_frames(model_id, frame_queue, processed_frame_queue, ov_model, cls, confidence, use_fp16=False):
|
||
# """处理帧的线程,添加帧率控制,支持半精度FP16"""
|
||
#
|
||
# import torch
|
||
# import time
|
||
# import queue
|
||
# import cv2
|
||
#
|
||
# error_count = 0 # 添加错误计数器
|
||
# max_errors = 5 # 最大容许错误次数
|
||
# frame_count = 0
|
||
# process_times = [] # 用于计算平均处理时间
|
||
#
|
||
# # 设置YOLO模型配置,提高性能
|
||
# ov_model.conf = confidence # 设置置信度阈值
|
||
#
|
||
# # 将模型移到设备(GPU或CPU)
|
||
# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||
# try:
|
||
# ov_model.to(device)
|
||
# # 调整批处理大小为1,减少内存占用
|
||
# if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'):
|
||
# ov_model.args.batch = 1
|
||
# # 启用半精度
|
||
# if use_fp16 and device.startswith('cuda') and hasattr(ov_model, 'model'):
|
||
# try:
|
||
# ov_model.model = ov_model.model.half()
|
||
# print("✅ 启用半精度 FP16 模式")
|
||
# except Exception as half_err:
|
||
# print(f"⚠️ 半精度转换失败: {half_err}")
|
||
# else:
|
||
# print("ℹ️ 半精度模式未启用")
|
||
# except Exception as e:
|
||
# print(f"⚠️ 模型设备配置警告: {e}")
|
||
#
|
||
# # 缓存先前的检测结果,用于提高稳定性
|
||
# last_results = None
|
||
# skip_counter = 0
|
||
# max_skip = 2 # 最多跳过几帧不处理
|
||
#
|
||
# while not stop_event.is_set():
|
||
# try:
|
||
# if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8:
|
||
# time.sleep(0.01)
|
||
# continue
|
||
#
|
||
# frame, timestamp = frame_queue.get(timeout=0.2)
|
||
#
|
||
# if time.time() - timestamp > 0.3:
|
||
# continue
|
||
#
|
||
# frame_count += 1
|
||
#
|
||
# if skip_counter > 0 and last_results is not None:
|
||
# skip_counter -= 1
|
||
# annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
|
||
#
|
||
# box_detail = extract_box_details(last_results, model_id)
|
||
#
|
||
# processed_frame_queue.put((annotated_frame, timestamp, box_detail))
|
||
# continue
|
||
#
|
||
# process_start = time.time()
|
||
#
|
||
# resize_scale = 1.0
|
||
# qsize = frame_queue.qsize()
|
||
# maxsize = frame_queue.maxsize
|
||
#
|
||
# if qsize > maxsize * 0.7:
|
||
# resize_scale = 0.4
|
||
# elif qsize > maxsize * 0.5:
|
||
# resize_scale = 0.6
|
||
# elif qsize > maxsize * 0.3:
|
||
# resize_scale = 0.8
|
||
#
|
||
# if resize_scale < 1.0:
|
||
# process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale)
|
||
# else:
|
||
# process_frame = frame
|
||
#
|
||
# try:
|
||
# results = ov_model(process_frame, classes=cls, show=False)
|
||
# last_results = results[0]
|
||
#
|
||
# if resize_scale < 1.0:
|
||
# annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5),
|
||
# (frame.shape[1], frame.shape[0]))
|
||
# # print("results[0].boxes.cls ", results[0].boxes.cls)
|
||
#
|
||
# else:
|
||
# annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5)
|
||
#
|
||
# box_detail = extract_box_details(last_results, model_id)
|
||
#
|
||
# if qsize > maxsize * 0.5:
|
||
# skip_counter = max_skip
|
||
# except Exception as infer_err:
|
||
# print(f"推理错误: {infer_err}")
|
||
# box_detail = None
|
||
# if last_results is not None:
|
||
# annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
|
||
# else:
|
||
# annotated_frame = frame.copy()
|
||
#
|
||
# process_end = time.time()
|
||
# process_times.append(process_end - process_start)
|
||
# if len(process_times) > 30:
|
||
# process_times.pop(0)
|
||
#
|
||
# if frame_count % 30 == 0:
|
||
# avg_process_time = sum(process_times) / len(process_times)
|
||
# fps = 1.0 / avg_process_time if avg_process_time > 0 else 0
|
||
# print(
|
||
# f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time * 1000:.2f}ms, 队列大小: {qsize}, 缩放比例: {resize_scale:.2f}")
|
||
#
|
||
# processed_frame_queue.put((annotated_frame, timestamp, box_detail))
|
||
# error_count = 0
|
||
# except queue.Empty:
|
||
# continue
|
||
# except Exception as e:
|
||
# error_count += 1
|
||
# print(f"处理帧错误: {e}")
|
||
# if error_count >= max_errors:
|
||
# print(f"连续处理错误达到{max_errors}次,正在停止处理...")
|
||
# stop_event.set()
|
||
# break
|
||
|
||
|
||
# # 读取mqtt的drc消息,并且在本地维持一个最新消息
|
||
# def read_drc_mqtt(mqtt_client):
|
||
# global local_drc_message
|
||
# while not stop_event.is_set():
|
||
# time.sleep(0.05)
|
||
# # 更新全局缓存,如果要同时适配多个飞机,这里要改为队列
|
||
# mess = mqtt_client.get_messages(timeout=0.5)
|
||
# if mess is not None:
|
||
# # local_drc_message = mess
|
||
# local_drc_message = json.loads(mess) # 解析为字典
|
||
|
||
# def read_drc_mqtt(mqtt_client):
|
||
# global local_drc_message
|
||
# while not stop_event.is_set():
|
||
# time.sleep(0.05)
|
||
# try:
|
||
# # 确保 mess 是字符串
|
||
# mess = mqtt_client.get_messages(timeout=0.5)
|
||
# if mess is not None:
|
||
# try:
|
||
# local_drc_message = json.loads(mess) # 确保 mess 是字符串
|
||
# except json.JSONDecodeError as e:
|
||
# print(f"JSON 解析错误: {e}")
|
||
# except Exception as e:
|
||
# print(f"读取 DRC 消息错误: {e}")
|
||
|
||
|
||
|
||
def get_local_drc_message():
    """Consume the cached DRC message.

    Returns the value currently held in the module-level
    ``local_drc_message`` and resets the cache to ``None``, so each stored
    message is handed out at most once. Returns ``None`` when nothing has
    been cached since the previous call.
    """
    global local_drc_message
    latest, local_drc_message = local_drc_message, None
    return latest
|
||
|
||
from asyncio_mqtt import Client
|
||
|
||
async def async_read_drc_mqtt(mqtt_ip, mqtt_port, mqtt_topic):
    """Subscribe to ``mqtt_topic`` and cache the newest DRC attitude message.

    A message is stored in the module-level ``local_drc_message`` when its
    topic matches ``mqtt_topic`` and its JSON payload is an object whose
    ``"data"`` member is an object containing ``"attitude_head"``. Any other
    message is ignored. Payloads that cannot be decoded are reported with
    ``print`` and skipped, so one malformed message does not end the
    subscription.

    Args:
        mqtt_ip: broker host passed to ``Client``.
        mqtt_port: broker port passed to ``Client``.
        mqtt_topic: topic to subscribe to and to match incoming messages against.
    """
    global local_drc_message

    async with Client(mqtt_ip, mqtt_port) as client:
        async with client.messages() as messages:
            await client.subscribe(mqtt_topic)
            async for message in messages:
                if not message.topic.matches(mqtt_topic):
                    continue
                try:
                    data = json.loads(message.payload)
                except (TypeError, ValueError) as e:
                    # ValueError covers JSONDecodeError and UnicodeDecodeError.
                    print(f"Error processing MQTT message: {e}")
                    continue
                # The payload may legitimately be any JSON value (list, number...);
                # only objects can carry the telemetry we want.
                inner = data.get("data") if isinstance(data, dict) else None
                if isinstance(inner, dict) and "attitude_head" in inner:
                    local_drc_message = data
|
||
|
||
def read_drc_mqtt(mqtt_client):
    """Poll ``mqtt_client`` for DRC messages and cache the latest one.

    Runs until ``stop_event`` is set. Every ~50 ms it calls
    ``mqtt_client.get_messages(timeout=0.5)`` and, when a message is
    returned, stores it unchanged in the module-level ``local_drc_message``.
    Only the newest message is kept; supporting several aircraft at once
    would require a queue instead.

    Errors raised by ``get_messages`` are printed and polling continues.
    Letting them escape would end this thread, and the supervising loop in
    ``startBackAIVideo2`` stops the whole pipeline as soon as any worker
    thread exits.
    """
    global local_drc_message

    while not stop_event.is_set():
        time.sleep(0.05)
        try:
            mess = mqtt_client.get_messages(timeout=0.5)
        except Exception as e:
            print(f"读取 DRC 消息错误: {e}")
            continue
        if mess is not None:
            local_drc_message = mess
|
||
|
||
|
||
def _save_and_publish_frame(frame, size, box_detail, mqtt_client, task_id):
    """Snapshot one frame: save it as JPEG, upload it to MinIO, publish an event.

    Args:
        frame: image array accepted by ``cv2.resize``.
        size: ``(width, height)`` the snapshot is resized to.
        box_detail: detection details copied verbatim into the message.
        mqtt_client: object whose ``publish_message`` receives the JSON text.
        task_id: task identifier copied into the message.

    The temporary file ``saveImg/output-<ns>.jpg`` is removed in ``finally``
    so it does not accumulate on disk when ``upload_file`` or
    ``publish_message`` raise. Those exceptions still propagate to the caller.
    """
    resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR)
    date_str = datetime.datetime.now().strftime("%Y%m%d")
    local_path = f"saveImg/output-{time.time_ns()}.jpg"
    try:
        if not cv2.imwrite(local_path, resized_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
            print("错误: 图像保存失败")
            return
        print(f"图像已保存到: {local_path}")

        # Upload exactly once; a second call would store a duplicate object.
        minio_path, file_type = upload_file(local_path, date_str)

        message = {
            "task_id": task_id,
            "minio": {
                "minio_path": minio_path,
                "file_type": file_type,
            },
            "box_detail": box_detail,
            # Latest DRC telemetry cached by read_drc_mqtt (None if none yet).
            "uav_location": local_drc_message,
        }
        mqtt_client.publish_message(json.dumps(message, indent=4, ensure_ascii=False))
    finally:
        if os.path.exists(local_path):
            try:
                os.remove(local_path)
                print(f"文件 {local_path} 已删除")
            except Exception as e:
                print(f"删除失败: {e}")


def write_frames(processed_frame_queue, pipe, size, mqtt_client, task_id):
    """Consume processed frames, pace them and periodically publish snapshots.

    Runs until ``stop_event`` is set. Each item taken from
    ``processed_frame_queue`` is a ``(frame, timestamp, box_detail)`` tuple.

    Behaviour per frame:
      * Drop the frame when it arrives less than half a 30 fps interval after
        the previously accepted frame (once at least two frames were seen).
      * Resize every 10th accepted frame to ``size``, save it, upload it and
        announce it over MQTT (see ``_save_and_publish_frame``).
      * Print a throughput line every 30th accepted frame.

    Args:
        processed_frame_queue: source queue of ``(frame, timestamp, box_detail)``.
        pipe: accepted for interface compatibility but not written to, because
            streaming output through the pipe is currently disabled.
        size: ``(width, height)`` for saved snapshots.
        mqtt_client: client whose ``publish_message`` receives event JSON.
        task_id: identifier copied into each published event.

    After ``max_pipe_errors`` consecutive failures the thread sets
    ``stop_event`` (which stops the other pipeline threads) and returns.
    """
    target_time_per_frame = 1.0 / 30.0  # 30 fps target
    min_frame_interval = target_time_per_frame * 0.5  # faster arrivals are dropped

    pipe_error_count = 0  # consecutive failures
    max_pipe_errors = 3
    frame_count = 0  # accepted (non-dropped) frames
    frames_received = 0  # every frame taken from the queue
    skipped_frames = 0  # drops since the last FPS report
    last_write_time = time.time()
    last_fps_time = time.time()

    # cv2.imwrite returns False (no exception) when the target directory is
    # missing, which would otherwise silently suppress every snapshot.
    try:
        os.makedirs("saveImg", exist_ok=True)
    except OSError as e:
        print(f"无法创建目录 saveImg: {e}")

    while not stop_event.is_set():
        try:
            # Short timeout so stop_event is re-checked frequently.
            frame, _timestamp, box_detail = processed_frame_queue.get(timeout=0.05)
            frames_received += 1

            elapsed = time.time() - last_write_time
            if elapsed < min_frame_interval and frames_received >= 2:
                skipped_frames += 1
                continue

            frame_count += 1
            if frame_count % 10 == 0:
                _save_and_publish_frame(frame, size, box_detail, mqtt_client, task_id)

            if frame_count % 30 == 0:
                now = time.time()
                stream_fps = 30 / (now - last_fps_time)
                print(f"推流FPS: {stream_fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}")
                last_fps_time = now
                skipped_frames = 0

            last_write_time = time.time()
            pipe_error_count = 0  # reset after a successfully handled frame

        except queue.Empty:
            # Idle: advance the pacing clock as if a filler frame had been
            # emitted at the target rate.
            if frames_received >= 2 and time.time() - last_write_time > target_time_per_frame:
                last_write_time = time.time()
            continue

        except Exception as e:
            print(f"写入帧错误: {e}")
            pipe_error_count += 1
            if pipe_error_count >= max_pipe_errors:
                print("FFmpeg管道错误过多,正在终止进程...")
                stop_event.set()  # ask every pipeline thread to stop
                break
|
||
|
||
|
||
#
|
||
# def cls2_find(video_path, m1, cls, confidence):
|
||
# try:
|
||
# ov_model = YOLO(m1)
|
||
#
|
||
# # ---------------------------------MinIO 存储路径(用于后续上传)------------------------------------
|
||
# minio_path = "AIResults"
|
||
#
|
||
# # -------------------------------获取当前日期,用于存储图像目录--------------------------------------
|
||
# date_str = datetime.datetime.now().strftime("%Y%m%d")
|
||
# save_dir = f"{date_str}"
|
||
# if not os.path.exists(save_dir):
|
||
# os.makedirs(save_dir)
|
||
#
|
||
# # 打开视频流
|
||
# cap = cv2.VideoCapture(video_path)
|
||
# if not cap.isOpened():
|
||
# print("Error: Could not open video.")
|
||
# return
|
||
#
|
||
# # 获取视频的帧率 (fps)
|
||
# fps = cap.get(cv2.CAP_PROP_FPS)
|
||
#
|
||
# # -----------------------------根据模型设置类别-------------------------------------------------
|
||
# if m1 == "gdaq.pt": # 仅当使用 gdaq.pt 时,保存类别 2 和 4
|
||
# cls2 = [2, 4]
|
||
# elif m1 == "best.pt": # 仅当使用 best.pt 时,保存类别 0
|
||
# cls2 = [0]
|
||
# else: # 其它模型不保存
|
||
# cls2 = []
|
||
#
|
||
# # ------------------------------------------cls2检测--------------------------------------------
|
||
# skip_frames = int(fps * 10) # 设置跳过帧数为 10 秒
|
||
#
|
||
# while cap.isOpened() and not stop_event.is_set():
|
||
# if skip_frames > 0:
|
||
# skip_frames -= 1 # 逐帧减少
|
||
# cap.grab() # 仅抓取帧,不进行解码
|
||
# continue # 跳过处理
|
||
#
|
||
# ret, frame = cap.read()
|
||
# if not ret:
|
||
# break # 无法读取帧时退出
|
||
#
|
||
# # 目标检测
|
||
# results = ov_model(frame, conf=confidence, classes=cls, show=False)
|
||
#
|
||
# for result in results:
|
||
# for box in result.boxes:
|
||
# cls_index = int(box.cls[0]) # 获取类别索引
|
||
#
|
||
# # 如果检测到的类别在 cls2 里,跳过 10 秒
|
||
# if cls_index in cls2:
|
||
# skip_frames = int(fps * 10) # 设置跳过帧数为 10 秒
|
||
# # upload_and_insert_to_db(frame, ov_model, cls_index, save_dir, minio_path)
|
||
# filename = f"{save_dir}/frame_{int(cap.get(cv2.CAP_PROP_POS_FRAMES))}_cls2.jpg"
|
||
# cv2.imwrite(filename, frame)
|
||
# print(f"保存图像: {filename}")
|
||
#
|
||
# # 绘制检测框
|
||
# x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||
# label = f"{result.names[cls_index]} {box.conf[0]:.2f}"
|
||
# cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
|
||
# cv2.putText(frame, label, (x1, y1 - 10),
|
||
# cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
|
||
# except Exception as e:
|
||
# print(f"cls2_find错误: {e}")
|
||
# finally:
|
||
# if 'cap' in locals() and cap is not None:
|
||
# cap.release()
|
||
|
||
|
||
# 后台计算的主要逻辑,识别图像,进而将事件通过mqtt转发、minio存储
|
||
# 1、识别视频流中的事件,进而将事件的前后帧做存储
|
||
# 2、事件基于mqtt 做转发,mqtt可以包含事件,也可以不包含事件
|
||
def startBackAIVideo2(task_id, video_path, push_url, m1, list_func_id, confidence, para, mqtt_ip, mqtt_port, mqtt_topic,
                      model_cls, chinese_label):
    """Run the background AI pipeline for one video stream until it stops.

    Loads and warms up the YOLO model ``m1`` and opens ``video_path`` with
    OpenCV's FFmpeg backend. It then starts four daemon threads:

      * ``read_frames``: reads frames into ``frame_queue``.
      * ``process_frames_tricker``: runs inference into ``processed_frame_queue``.
      * ``write_frames``: saves snapshots and publishes events via MQTT.
      * ``read_drc_mqtt``: caches the latest DRC telemetry.

    The call blocks until ``getIfAI()`` becomes false, ``stop_event`` is set,
    or any worker thread exits. It then stops the workers (waiting up to 3 s)
    and releases the MQTT clients, the capture and any OpenCV windows.

    Args:
        task_id: identifier attached to every published event.
        video_path: input stream URL.
        push_url: forwarded to ``read_frames`` and ``process_frames_tricker``.
        m1: YOLO weights path.
        list_func_id, para, model_cls, chinese_label: forwarded to
            ``process_frames_tricker``.
        confidence: detection confidence threshold.
        mqtt_ip, mqtt_port, mqtt_topic: broker and topic for DRC messages.

    Errors are printed, not raised.
    """
    global mqtt_client, local_drc_message, frame_queue, processed_frame_queue

    setIfAI(True)

    cap = None
    mqtt_client_drc = None
    read_thread = None
    process_thread = None
    write_thread = None
    read_mqtt_drc_message = None
    ov_model = None

    local_drc_message = None  # discard telemetry cached by a previous run

    try:
        # Environment tweaks intended to speed up YOLO inference.
        os.environ["OMP_NUM_THREADS"] = "4"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            torch.cuda.empty_cache()

        # Preload the YOLO model.
        print("预加载YOLO模型...")
        max_retries = 3
        retry_count = 0
        model_params = {}
        # Probe which constructor keywords this Ultralytics version accepts.
        try:
            import inspect

            test_model = YOLO(m1)
            init_params = inspect.signature(YOLO.__init__).parameters
            if hasattr(test_model, "task") and "task" in init_params:
                model_params["task"] = "detect"
            # Only pass keywords the constructor accepts. ``half`` is a
            # predict-time option in current Ultralytics releases; passing it to
            # YOLO() unconditionally made the first load attempt fail on CUDA hosts.
            if torch.cuda.is_available() and "half" in init_params:
                model_params["half"] = True
            if "verbose" in init_params:
                model_params["verbose"] = False
            del test_model  # the probe model is not reused
            print(f"检测到支持的YOLO参数: {model_params}")
        except Exception as e:
            print(f"参数检测失败,将使用默认参数: {e}")
            model_params = {}

        dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
        # Placeholder class filter; should eventually come from ai_model_list ids.
        cls = [0]
        while retry_count < max_retries:
            try:
                ov_model = YOLO(m1, **model_params)
                for _ in range(3):  # warm-up passes
                    ov_model(dummy_frame, classes=cls, conf=confidence, show=False)
                print("YOLO模型加载成功并预热完成")
                break
            except Exception as e:
                retry_count += 1
                print(f"YOLO模型加载失败 (尝试 {retry_count}/{max_retries}): {e}")
                # Fallback: drop a keyword the installed YOLO version rejects.
                if "got an unexpected keyword argument" in str(e) and model_params:
                    param_name = str(e).split("'")[-2]
                    if param_name in model_params:
                        del model_params[param_name]
                time.sleep(2)

        if ov_model is None:
            raise Exception("无法加载YOLO模型,达到最大重试次数")

        ov_model.conf = confidence

        print(f"正在连接视频流: {video_path}")
        # Both timeouts are open-only properties: they must be supplied with the
        # open call; cap.set() after opening has no effect on them.
        cap = cv2.VideoCapture(
            video_path,
            cv2.CAP_FFMPEG,
            [cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 60000, cv2.CAP_PROP_READ_TIMEOUT_MSEC, 50000],
        )
        if not cap.isOpened():
            print("无法打开视频流,尝试强制释放资源...")
            cap.release()
            time.sleep(2)
            raise Exception(f"无法打开视频流: {video_path}")

        # Fresh queues for this run.
        frame_queue = queue.Queue(maxsize=80)
        processed_frame_queue = queue.Queue(maxsize=40)

        # Event publisher, plus a client that keeps the latest DRC message.
        mqtt_client = MQTTClient(broker, port, topic)
        mqtt_client_drc = MQTTClient(mqtt_ip, mqtt_port, mqtt_topic)

        read_thread = Thread(target=read_frames, args=(cap, push_url, frame_queue))
        process_thread = Thread(target=process_frames_tricker,
                                args=(push_url, frame_queue, processed_frame_queue, ov_model, list_func_id,
                                      confidence, para, model_cls, chinese_label))
        write_thread = Thread(target=write_frames,
                              args=(processed_frame_queue, None, (1280, 720), mqtt_client, task_id))  # pipe=None
        read_mqtt_drc_message = Thread(target=read_drc_mqtt, args=(mqtt_client_drc,))

        workers = [read_thread, process_thread, write_thread, read_mqtt_drc_message]
        for worker in workers:
            worker.daemon = True

        print("开始处理视频流...")
        for worker in workers:
            worker.start()

        # Supervise: stop everything as soon as any worker dies.
        while getIfAI() and not stop_event.is_set():
            if not all(worker.is_alive() for worker in workers):
                print(f"read_thread.is_alive() {read_thread.is_alive()} "
                      f"process_thread.is_alive() {process_thread.is_alive()} "
                      f"write_thread.is_alive() {write_thread.is_alive()} "
                      f"read_mqtt_drc_message.is_alive() {read_mqtt_drc_message.is_alive()} ")
                print("检测到某个线程已停止运行,正在终止所有线程...")
                stop_event.set()
                break
            time.sleep(0.1)

    except Exception as e:
        print(f"发生错误: {e}")

    finally:
        print("正在清理资源...")
        stop_event.set()
        setIfAI(False)

        # Give the workers (which poll stop_event) up to 3 s to exit before the
        # resources they use are torn down.
        threads = [t for t in (read_thread, process_thread, write_thread, read_mqtt_drc_message)
                   if t is not None]
        deadline = time.time() + 3
        while any(t.is_alive() for t in threads) and time.time() < deadline:
            time.sleep(0.1)

        # Close the MQTT clients only after their users have had a chance to stop.
        if mqtt_client is not None:
            try:
                mqtt_client.close()
            except Exception as e:
                print(f"关闭MQTT客户端时出错: {str(e)}")
        if mqtt_client_drc is not None:
            try:
                mqtt_client_drc.close()
            except Exception as e:
                print(f"关闭DRC MQTT客户端时出错: {str(e)}")

        if cap is not None:
            try:
                cap.release()
                print("视频捕获资源已释放")
            except Exception:
                pass

        try:
            cv2.destroyAllWindows()
        except Exception:
            pass

        print("所有资源已清理完毕")
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke run: analyse one fixed RTMP stream with best.pt for a minute.
    sn = "1581F6QAD243C00BP71E"  # device serial; only used by the alternative URLs below
    # Alternative input stream:   f"rtmp://112.44.103.230:1935/live/{sn}"
    # Alternative FFmpeg push URL: f"rtmp://112.44.103.230:1935/live/{sn}ai"
    video_path = "rtmp://112.44.103.230:1935/live/123456"
    mq_ts = time.time_ns()  # nanosecond timestamp, used as a unique message id

    try:
        startBackAIVideo(mq_ts, video_path, "best.pt", [0, 1, 2, 3, 4], 0.4)
        time.sleep(60)
    except KeyboardInterrupt:
        print("程序被用户中断")
        stopAIVideo()
    except Exception as e:
        print(f"程序异常: {e}")
|
||
# import cv2
|
||
#
|
||
# print(cv2.getBuildInformation())
|