增减数量超限报警功能
This commit is contained in:
parent
be99472837
commit
0b283f6b8c
@ -46,6 +46,7 @@ class ModelData:
|
|||||||
so_path: str
|
so_path: str
|
||||||
repeat_dis: float
|
repeat_dis: float
|
||||||
repeat_time: float
|
repeat_time: float
|
||||||
|
high_count_warn: float
|
||||||
func_description: Optional[str]
|
func_description: Optional[str]
|
||||||
filter_indices: List[int]
|
filter_indices: List[int]
|
||||||
class_indices: List[int]
|
class_indices: List[int]
|
||||||
@ -250,6 +251,7 @@ class ModelConfigDAO:
|
|||||||
aml.py_func,
|
aml.py_func,
|
||||||
aml.repeat_dis,
|
aml.repeat_dis,
|
||||||
aml.repeat_time,
|
aml.repeat_time,
|
||||||
|
aml.high_count_warn,
|
||||||
am.scope,
|
am.scope,
|
||||||
am.yolo_version,
|
am.yolo_version,
|
||||||
am.PATH,
|
am.PATH,
|
||||||
@ -572,6 +574,7 @@ WHERE
|
|||||||
filter_indices=filter_indices,
|
filter_indices=filter_indices,
|
||||||
repeat_dis=repeat_dis,
|
repeat_dis=repeat_dis,
|
||||||
repeat_time=row.get('repeat_time'),
|
repeat_time=row.get('repeat_time'),
|
||||||
|
high_count_warn=row.get('high_count_warn'),
|
||||||
class_indices=row['cls_index'],
|
class_indices=row['cls_index'],
|
||||||
conf=conf,
|
conf=conf,
|
||||||
classes=classes,
|
classes=classes,
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from dataclasses import dataclass
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from sympy import false
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -2365,7 +2366,7 @@ async def process_frames(detector: MultiYoloTrtDetectorTrackId, cancel_flag: asy
|
|||||||
time_pr_start = time.time_ns()
|
time_pr_start = time.time_ns()
|
||||||
detections, detections_list, model_para = await detector.predict(frame)
|
detections, detections_list, model_para = await detector.predict(frame)
|
||||||
time_pr_end = time.time_ns()
|
time_pr_end = time.time_ns()
|
||||||
print(f"time_pr_starttime_pr_start {(time_pr_end - time_pr_start) / 1000000}")
|
# print(f"time_pr_starttime_pr_start {(time_pr_end - time_pr_start) / 1000000}")
|
||||||
predict_state = True
|
predict_state = True
|
||||||
if detections:
|
if detections:
|
||||||
print("检测到任何目标")
|
print("检测到任何目标")
|
||||||
@ -2635,8 +2636,8 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps:
|
|||||||
# 如果开起侵限功能,就只显示侵限内的框
|
# 如果开起侵限功能,就只显示侵限内的框
|
||||||
point_x = (x1 + x2) / 2
|
point_x = (x1 + x2) / 2
|
||||||
point_y = (y1 + y2) / 2
|
point_y = (y1 + y2) / 2
|
||||||
print(f"class_name--{class_name}")
|
# print(f"class_name--{class_name}")
|
||||||
print(f"model_class_names: {model_para[0]['model_class_names']}")
|
# print(f"model_class_names: {model_para[0]['model_class_names']}")
|
||||||
|
|
||||||
if class_name not in model_para[0]["model_class_names"]:
|
if class_name not in model_para[0]["model_class_names"]:
|
||||||
continue
|
continue
|
||||||
@ -3298,7 +3299,7 @@ invade_cache_lock = Lock() # 用于保护共享变量的锁
|
|||||||
async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
|
async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
|
||||||
cv_frame_queue: asyncio.Queue,
|
cv_frame_queue: asyncio.Queue,
|
||||||
event_queue: asyncio.Queue = None,
|
event_queue: asyncio.Queue = None,
|
||||||
device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
|
device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1,high_count_warn: float = -1):
|
||||||
global stats
|
global stats
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
|
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
|
||||||
@ -3384,6 +3385,14 @@ async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, ca
|
|||||||
print(f"target_pointtarget_point {len(target_point)}")
|
print(f"target_pointtarget_point {len(target_point)}")
|
||||||
count_item = 0
|
count_item = 0
|
||||||
des_location_result=[]
|
des_location_result=[]
|
||||||
|
|
||||||
|
high_count_warn_status=False
|
||||||
|
high_count_warn_num=0
|
||||||
|
|
||||||
|
if target_point is not None and 0 < high_count_warn < len(target_point):# 触发计数报警
|
||||||
|
high_count_warn_num=len(target_point)
|
||||||
|
high_count_warn_status=True
|
||||||
|
|
||||||
for item in target_point:
|
for item in target_point:
|
||||||
# # 跳过无效的track_id
|
# # 跳过无效的track_id
|
||||||
# 检查是否应该上报该track_id
|
# 检查是否应该上报该track_id
|
||||||
@ -3560,6 +3569,10 @@ async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, ca
|
|||||||
"latitude": cam_latitude
|
"latitude": cam_latitude
|
||||||
},
|
},
|
||||||
"count_message":count_message,
|
"count_message":count_message,
|
||||||
|
"high_count_warn":{
|
||||||
|
"high_count_warn_status":high_count_warn_status,
|
||||||
|
"high_count_warn_num":high_count_warn_num
|
||||||
|
},
|
||||||
"des_location":des_location_result
|
"des_location":des_location_result
|
||||||
}
|
}
|
||||||
await event_queue.put({
|
await event_queue.put({
|
||||||
@ -3988,7 +4001,7 @@ async def start_rtmp_processing(video_url: str, task_id: str, model_configs: Lis
|
|||||||
mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
|
mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
|
||||||
output_rtmp_url: str,
|
output_rtmp_url: str,
|
||||||
invade_enable: bool, invade_file: str, camera_para_url: str,
|
invade_enable: bool, invade_file: str, camera_para_url: str,
|
||||||
device_height: float, repeat_dis: float, repeat_time: float):
|
device_height: float, repeat_dis: float, repeat_time: float,high_count_warn: float):
|
||||||
# 初始化资源
|
# 初始化资源
|
||||||
# await initialize_resources()
|
# await initialize_resources()
|
||||||
logger.info(f"拉流地址{video_url}")
|
logger.info(f"拉流地址{video_url}")
|
||||||
@ -4153,7 +4166,7 @@ async def start_rtmp_processing(video_url: str, task_id: str, model_configs: Lis
|
|||||||
upload_task = asyncio.create_task(
|
upload_task = asyncio.create_task(
|
||||||
send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_pub_topic,
|
send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_pub_topic,
|
||||||
cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis,
|
cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis,
|
||||||
repeat_time),
|
repeat_time,high_count_warn),
|
||||||
name=f"send_frame_to_s3_mq_{_}"
|
name=f"send_frame_to_s3_mq_{_}"
|
||||||
)
|
)
|
||||||
upload_tasks.append(upload_task)
|
upload_tasks.append(upload_task)
|
||||||
|
|||||||
@ -470,13 +470,13 @@ class YoLo11TRT(object):
|
|||||||
# 重塑为二维数组
|
# 重塑为二维数组
|
||||||
pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
|
pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
|
||||||
|
|
||||||
# 调试:打印输出信息
|
# # 调试:打印输出信息
|
||||||
if num > 0:
|
# if num > 0:
|
||||||
logger.info(f"检测到 {num} 个目标,每个目标 {num_values_per_detection} 个值")
|
# logger.info(f"检测到 {num} 个目标,每个目标 {num_values_per_detection} 个值")
|
||||||
logger.info(f"第一个检测框原始值: center_x={pred[0, 0]:.2f}, center_y={pred[0, 1]:.2f}, "
|
# logger.info(f"第一个检测框原始值: center_x={pred[0, 0]:.2f}, center_y={pred[0, 1]:.2f}, "
|
||||||
f"width={pred[0, 2]:.2f}, height={pred[0, 3]:.2f}, conf={pred[0, 4]:.4f}")
|
# f"width={pred[0, 2]:.2f}, height={pred[0, 3]:.2f}, conf={pred[0, 4]:.4f}")
|
||||||
logger.info(f"原始图片尺寸: {origin_w}x{origin_h}")
|
# logger.info(f"原始图片尺寸: {origin_w}x{origin_h}")
|
||||||
logger.info(f"缩放比例: r_w={self.input_w / origin_w:.4f}, r_h={self.input_h / origin_h:.4f}")
|
# logger.info(f"缩放比例: r_w={self.input_w / origin_w:.4f}, r_h={self.input_h / origin_h:.4f}")
|
||||||
|
|
||||||
# 执行NMS
|
# 执行NMS
|
||||||
boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
|
boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
|
||||||
@ -521,20 +521,20 @@ class YoLo11TRT(object):
|
|||||||
return np.array([])
|
return np.array([])
|
||||||
|
|
||||||
# 调试:打印转换前的坐标
|
# 调试:打印转换前的坐标
|
||||||
if len(boxes) > 0:
|
# if len(boxes) > 0:
|
||||||
logger.info(f"NMS前第一个框: center_x={boxes[0, 0]:.2f}, center_y={boxes[0, 1]:.2f}, "
|
# logger.info(f"NMS前第一个框: center_x={boxes[0, 0]:.2f}, center_y={boxes[0, 1]:.2f}, "
|
||||||
f"width={boxes[0, 2]:.2f}, height={boxes[0, 3]:.2f}, conf={boxes[0, 4]:.4f}")
|
# f"width={boxes[0, 2]:.2f}, height={boxes[0, 3]:.2f}, conf={boxes[0, 4]:.4f}")
|
||||||
|
|
||||||
# 关键修复:使用正确的xywh2xyxy函数
|
# 关键修复:使用正确的xywh2xyxy函数
|
||||||
boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
|
boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
|
||||||
|
|
||||||
# 调试:打印转换后的坐标
|
# # 调试:打印转换后的坐标
|
||||||
if len(boxes) > 0:
|
# if len(boxes) > 0:
|
||||||
logger.info(f"坐标转换后第一个框: x1={boxes[0, 0]:.2f}, y1={boxes[0, 1]:.2f}, "
|
# logger.info(f"坐标转换后第一个框: x1={boxes[0, 0]:.2f}, y1={boxes[0, 1]:.2f}, "
|
||||||
f"x2={boxes[0, 2]:.2f}, y2={boxes[0, 3]:.2f}")
|
# f"x2={boxes[0, 2]:.2f}, y2={boxes[0, 3]:.2f}")
|
||||||
logger.info(f"框尺寸: {boxes[0, 2] - boxes[0, 0]:.1f}x{boxes[0, 3] - boxes[0, 1]:.1f}")
|
# logger.info(f"框尺寸: {boxes[0, 2] - boxes[0, 0]:.1f}x{boxes[0, 3] - boxes[0, 1]:.1f}")
|
||||||
logger.info(
|
# logger.info(
|
||||||
f"框占图片比例: 宽度={100 * (boxes[0, 2] - boxes[0, 0]) / origin_w:.1f}%, 高度={100 * (boxes[0, 3] - boxes[0, 1]) / origin_h:.1f}%")
|
# f"框占图片比例: 宽度={100 * (boxes[0, 2] - boxes[0, 0]) / origin_w:.1f}%, 高度={100 * (boxes[0, 3] - boxes[0, 1]) / origin_h:.1f}%")
|
||||||
|
|
||||||
# 裁剪坐标到图像边界内
|
# 裁剪坐标到图像边界内
|
||||||
boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
|
boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
|
||||||
|
|||||||
@ -664,6 +664,8 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio
|
|||||||
print("去重的距离:", config.repeat_dis)
|
print("去重的距离:", config.repeat_dis)
|
||||||
repeat_dis = config.repeat_dis
|
repeat_dis = config.repeat_dis
|
||||||
repeat_time = config.repeat_time
|
repeat_time = config.repeat_time
|
||||||
|
high_count_warn = config.high_count_warn
|
||||||
|
print(f"config.high_count_warn {config.high_count_warn}")
|
||||||
model_configs.append(
|
model_configs.append(
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -760,7 +762,7 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio
|
|||||||
mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
|
mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
|
||||||
push_url,
|
push_url,
|
||||||
invade_enable, invade_file, camera_para_url,
|
invade_enable, invade_file, camera_para_url,
|
||||||
device_height, repeat_dis, repeat_time
|
device_height, repeat_dis, repeat_time,high_count_warn
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理流程异常: {e}")
|
logger.error(f"处理流程异常: {e}")
|
||||||
@ -1846,5 +1848,5 @@ if __name__ == "__main__":
|
|||||||
print("正在安装psutil库...")
|
print("正在安装psutil库...")
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
||||||
|
|
||||||
app.run(host="0.0.0.0", port=12315, debug=False, access_log=True)
|
app.run(host="192.168.110.103", port=12315, debug=False, access_log=True)
|
||||||
# app.run(host="0.0.0.0", workers=3, port=12315)
|
# app.run(host="0.0.0.0", workers=3, port=12315)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user