commit 45fde807d881db9836c866e16f3c50a6dfcfbe26 Author: yooooger <761181201@qq.com> Date: Wed Jul 9 15:34:23 2025 +0800 20250709 diff --git a/ai2/.codemap/main-panel.json b/ai2/.codemap/main-panel.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/ai2/.codemap/main-panel.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/ai2/GDCL.pt b/ai2/GDCL.pt new file mode 100644 index 0000000..85ca803 Binary files /dev/null and b/ai2/GDCL.pt differ diff --git a/ai2/HWRC.pt b/ai2/HWRC.pt new file mode 100644 index 0000000..c24c4e1 Binary files /dev/null and b/ai2/HWRC.pt differ diff --git a/ai2/__pycache__/ai_open_api.cpython-312.pyc b/ai2/__pycache__/ai_open_api.cpython-312.pyc new file mode 100644 index 0000000..3cb2d1b Binary files /dev/null and b/ai2/__pycache__/ai_open_api.cpython-312.pyc differ diff --git a/ai2/__pycache__/color.cpython-312.pyc b/ai2/__pycache__/color.cpython-312.pyc new file mode 100644 index 0000000..b3cd80d Binary files /dev/null and b/ai2/__pycache__/color.cpython-312.pyc differ diff --git a/ai2/__pycache__/cv_video.cpython-311.pyc b/ai2/__pycache__/cv_video.cpython-311.pyc new file mode 100644 index 0000000..6f5c7ce Binary files /dev/null and b/ai2/__pycache__/cv_video.cpython-311.pyc differ diff --git a/ai2/__pycache__/cv_video.cpython-312.pyc b/ai2/__pycache__/cv_video.cpython-312.pyc new file mode 100644 index 0000000..d2ebc7a Binary files /dev/null and b/ai2/__pycache__/cv_video.cpython-312.pyc differ diff --git a/ai2/__pycache__/minio_oss.cpython-311.pyc b/ai2/__pycache__/minio_oss.cpython-311.pyc new file mode 100644 index 0000000..2beca73 Binary files /dev/null and b/ai2/__pycache__/minio_oss.cpython-311.pyc differ diff --git a/ai2/__pycache__/minio_oss.cpython-312.pyc b/ai2/__pycache__/minio_oss.cpython-312.pyc new file mode 100644 index 0000000..289a9c1 Binary files /dev/null and b/ai2/__pycache__/minio_oss.cpython-312.pyc differ diff --git a/ai2/__pycache__/pgadmin_helper.cpython-311.pyc b/ai2/__pycache__/pgadmin_helper.cpython-311.pyc new file mode 100644 index 0000000..05e12dd Binary files /dev/null and b/ai2/__pycache__/pgadmin_helper.cpython-311.pyc differ diff --git a/ai2/__pycache__/pgadmin_helper.cpython-312.pyc b/ai2/__pycache__/pgadmin_helper.cpython-312.pyc new file mode 100644 index 0000000..98a2eb5 Binary files /dev/null and b/ai2/__pycache__/pgadmin_helper.cpython-312.pyc differ diff --git a/ai2/__pycache__/yolo_findaim.cpython-311.pyc b/ai2/__pycache__/yolo_findaim.cpython-311.pyc new file mode 100644 index 0000000..6bc4fef Binary files /dev/null and b/ai2/__pycache__/yolo_findaim.cpython-311.pyc differ diff --git a/ai2/__pycache__/yolo_findaim.cpython-312.pyc b/ai2/__pycache__/yolo_findaim.cpython-312.pyc new file mode 100644 index 0000000..690b5db Binary files /dev/null and b/ai2/__pycache__/yolo_findaim.cpython-312.pyc differ diff --git a/ai2/best.pt b/ai2/best.pt new file mode 100644 index 0000000..79f35ba Binary files /dev/null and b/ai2/best.pt differ diff --git a/ai2/cv_video copy.py b/ai2/cv_video copy.py new file mode 100644 index 0000000..442b488 --- /dev/null +++ b/ai2/cv_video copy.py @@ -0,0 +1,593 @@ +import cv2 +import subprocess +from threading import Thread, Lock, Event +import time +import queue +import numpy as np +import datetime +import os +from ultralytics import YOLO # 导入 Ultralytics YOLO 模型 + +# 全局变量 +ifAI = {'status': False} +deskLock = Lock() +frame_queue = queue.Queue(maxsize=60) # 增加帧缓冲队列大小 +processed_frame_queue = queue.Queue(maxsize=30) # 处理后的帧队列 +stop_event = Event() + +def setIfAI(pb1): + deskLock.acquire() + ifAI['status'] = pb1 + deskLock.release() + +def getIfAI(): + return ifAI['status'] + +def stopAIVideo(): + print("正在停止AI视频处理...") + setIfAI(False) + stop_event.set() + + # 等待足够长的时间确保资源释放 + wait_count = 0 + max_wait = 5 # 减少最大等待时间到5秒 + + while stop_event.is_set() and wait_count < max_wait: + time.sleep(0.5) + wait_count += 1 + + if wait_count >= max_wait: + print("警告: 停止AI视频处理超时,强制终止") + # 不使用_thread._interrupt_main(),改用其他方式强制终止 + try: + # 尝试终止可能运行的进程 + import os + import signal + import psutil + + # 查找并终止可能的FFmpeg进程 + current_process = psutil.Process(os.getpid()) + for child in current_process.children(recursive=True): + try: + child_name = child.name().lower() + if 'ffmpeg' in child_name: + print(f"正在终止子进程: {child.pid} ({child_name})") + child.send_signal(signal.SIGTERM) + except: + pass + except: + pass + else: + print("AI视频处理已停止") + +def startAIVideo(video_path, output_url, m1, cls, confidence): + if ifAI['status']: + stopAIVideo() + time.sleep(1) + stop_event.clear() + thread = Thread(target=startAIVideo2, + args=(video_path, output_url, m1, cls, confidence)) + # cls2_thread = Thread(target=cls2_find, args=(video_path,m1, cls, confidence)) + # cls2_thread.daemon = True # 守护线程,主程序退出时线程也会退出 + thread.daemon = True # 守护线程,主程序退出时线程也会退出 + + + thread.start() + # cls2_thread.start() + +def read_frames(cap, frame_queue): + """优化的帧读取线程""" + frame_count = 0 + last_time = time.time() + last_fps_time = time.time() + # 减小目标帧间隔时间,提高读取帧率 + target_time_per_frame = 1.0 / 60.0 # 目标帧间隔时间(提高到60fps) + + # 添加连接断开检测 + connection_error_count = 0 + max_connection_errors = 10 # 最多允许连续10次连接错误 + last_successful_read = time.time() + max_read_wait = 30.0 # 30秒无法读取则认为连接断开 + + # 预先丢弃几帧,确保从新帧开始处理 + for _ in range(5): + cap.grab() + + while not stop_event.is_set(): + current_time = time.time() + elapsed_time = current_time - last_time + + # 检查是否长时间无法读取帧 + if current_time - last_successful_read > max_read_wait: + print(f"警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开") + stop_event.set() + break + + # 帧率控制,但更积极地读取 + if elapsed_time < target_time_per_frame: + time.sleep(target_time_per_frame - elapsed_time) + continue + + # 当队列快满时,跳过一些帧以避免延迟累积 + if frame_queue.qsize() > frame_queue.maxsize * 0.8: + # 跳过一些帧 + cap.grab() + last_time = time.time() + continue + + ret, frame = cap.read() + if not ret: + print("拉流错误:无法读取帧") + connection_error_count += 1 + if connection_error_count >= max_connection_errors: + print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...") + stop_event.set() + break + time.sleep(0.5) # 短暂等待后重试 + continue + + # 成功读取了帧,重置错误计数 + connection_error_count = 0 + last_successful_read = time.time() + + frame_count += 1 + if frame_count % 60 == 0: # 每60帧计算一次FPS + current_fps_time = time.time() + fps = 60 / (current_fps_time - last_fps_time) + print(f"拉流FPS: {fps:.2f}") + last_fps_time = current_fps_time + + last_time = time.time() + frame_queue.put((frame, time.time())) # 添加时间戳 + +def process_frames(frame_queue, processed_frame_queue, ov_model, cls, confidence): + """处理帧的线程,添加帧率控制""" + error_count = 0 # 添加错误计数器 + max_errors = 5 # 最大容许错误次数 + frame_count = 0 + last_process_time = time.time() + process_times = [] # 用于计算平均处理时间 + + # 设置YOLO模型配置,提高性能 + ov_model.conf = confidence # 设置置信度阈值 + + # 优化推理性能 + try: + # 导入torch库 + import torch + # 尝试启用ONNX Runtime加速 + ov_model.to('cuda:0' if torch.cuda.is_available() else 'cpu') + # 调整批处理大小为1,减少内存占用 + if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'): + ov_model.args.batch = 1 + # 使用half精度,提高性能 + if torch.cuda.is_available() and hasattr(ov_model, 'model'): + try: + ov_model.model = ov_model.model.half() + except Exception as half_err: + print(f"半精度转换失败: {half_err}") + except Exception as e: + print(f"模型优化配置警告: {e}") + + # 缓存先前的检测结果,用于提高稳定性 + last_results = None + skip_counter = 0 + max_skip = 2 # 最多跳过几帧不处理 + + while not stop_event.is_set(): + try: + if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8: + # 如果输出队列接近满,等待一段时间 + time.sleep(0.01) + continue + + frame, timestamp = frame_queue.get(timeout=0.2) + + # 处理延迟过大的帧 + if time.time() - timestamp > 0.3: # 减少延迟阈值 + continue + + frame_count += 1 + + # 间隔采样,每n帧处理一次,减少计算量 + if skip_counter > 0 and last_results is not None: + skip_counter -= 1 + # 使用上次的检测结果 + annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5) + processed_frame_queue.put((annotated_frame, timestamp)) + continue + + process_start = time.time() + + # 动态调整处理尺寸,根据队列积压情况 + resize_scale = 1.0 + if frame_queue.qsize() > frame_queue.maxsize * 0.7: + resize_scale = 0.4 # 高负载时大幅降低分辨率 + elif frame_queue.qsize() > frame_queue.maxsize * 0.5: + resize_scale = 0.6 # 中等负载时适度降低分辨率 + elif frame_queue.qsize() > frame_queue.maxsize * 0.3: + resize_scale = 0.8 # 轻微负载时轻微降低分辨率 + + # 调整图像尺寸以加快处理 + if resize_scale < 1.0: + process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale) + else: + process_frame = frame + + # 执行推理 + try: + results = ov_model(process_frame, classes=cls, show=False) + last_results = results[0] # 保存检测结果用于后续帧 + + # 如果尺寸调整过,需要将结果转换回原始尺寸 + if resize_scale < 1.0: + # 绘制检测框 + annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5), + (frame.shape[1], frame.shape[0])) + else: + annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5) + + # 在负载高时启用跳帧处理 + if frame_queue.qsize() > frame_queue.maxsize * 0.5: + skip_counter = max_skip + except Exception as infer_err: + print(f"推理错误: {infer_err}") + if last_results is not None: + # 使用上次的结果 + annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5) + else: + # 如果没有上次的结果,简单返回原始帧 + annotated_frame = frame.copy() + + process_end = time.time() + process_times.append(process_end - process_start) + if len(process_times) > 30: + process_times.pop(0) + + if frame_count % 30 == 0: + avg_process_time = sum(process_times) / len(process_times) + fps = 1.0 / avg_process_time if avg_process_time > 0 else 0 + print(f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time*1000:.2f}ms, 队列大小: {frame_queue.qsize()}, 缩放比例: {resize_scale:.2f}") + + processed_frame_queue.put((annotated_frame, timestamp)) + error_count = 0 # 成功处理后重置错误计数 + except queue.Empty: + continue + except Exception as e: + error_count += 1 + print(f"处理帧错误: {e}") + if error_count >= max_errors: + print(f"连续处理错误达到{max_errors}次,正在停止处理...") + stop_event.set() + break + +def write_frames(processed_frame_queue, pipe, size): + """写入帧的线程,添加平滑处理""" + last_write_time = time.time() + target_time_per_frame = 1.0 / 30.0 # 30fps + pipe_error_count = 0 # 添加错误计数 + max_pipe_errors = 3 # 最大容忍错误数 + frame_count = 0 + last_fps_time = time.time() + skipped_frames = 0 + + # 使用队列存储最近几帧,用于在需要时进行插值 + recent_frames = [] + max_recent_frames = 5 # 增加缓存帧数量,提高平滑性 + + # 使用双缓冲机制提高写入速度 + buffer1 = bytearray(size[0] * size[1] * 3) + buffer2 = bytearray(size[0] * size[1] * 3) + current_buffer = buffer1 + + # 帧率控制参数 + min_frame_interval = target_time_per_frame * 0.5 # 允许的最小帧间隔 + max_frame_interval = target_time_per_frame * 2.0 # 允许的最大帧间隔 + + while not stop_event.is_set(): + try: + # 获取处理后的帧,超时时间较短以便更平滑地处理 + frame, timestamp = processed_frame_queue.get(timeout=0.05) + + # 存储最近的帧用于插值 + recent_frames.append(frame) + if len(recent_frames) > max_recent_frames: + recent_frames.pop(0) + + current_time = time.time() + elapsed = current_time - last_write_time + + # 如果两帧间隔太短,考虑合并或跳过 + if elapsed < min_frame_interval and len(recent_frames) >= 2: + skipped_frames += 1 + continue + + # 如果两帧间隔太长,进行插值平滑 + if elapsed > max_frame_interval and len(recent_frames) >= 2: + # 计算需要插入的帧数 + frames_to_insert = min(int(elapsed / target_time_per_frame), 3) + + for i in range(frames_to_insert): + # 创建插值帧 + weight = (i + 1) / (frames_to_insert + 1) + interpolated_frame = cv2.addWeighted(recent_frames[-2], 1-weight, recent_frames[-1], weight, 0) + + # 切换双缓冲 + current_buffer = buffer2 if current_buffer is buffer1 else buffer1 + + # 高效调整大小并写入 + interpolated_resized = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR) + img_bytes = interpolated_resized.tobytes() + + # 写入管道 + pipe.stdin.write(img_bytes) + pipe.stdin.flush() + + # 切换双缓冲 + current_buffer = buffer2 if current_buffer is buffer1 else buffer1 + + # 正常写入当前帧 - 使用高效的调整大小方法 + resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR) + img_bytes = resized_frame.tobytes() + pipe.stdin.write(img_bytes) + pipe.stdin.flush() + + frame_count += 1 + if frame_count % 30 == 0: + current_fps_time = time.time() + fps = 30 / (current_fps_time - last_fps_time) + print(f"推流FPS: {fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}") + last_fps_time = current_fps_time + skipped_frames = 0 + + last_write_time = time.time() + pipe_error_count = 0 # 成功写入后重置错误计数 + + except queue.Empty: + # 队列为空且有足够的最近帧时,考虑生成插值帧以保持流畅 + if len(recent_frames) >= 2 and time.time() - last_write_time > target_time_per_frame: + try: + # 创建插值帧 + interpolated_frame = cv2.addWeighted(recent_frames[-2], 0.5, recent_frames[-1], 0.5, 0) + + # 切换双缓冲 + current_buffer = buffer2 if current_buffer is buffer1 else buffer1 + + resized_frame = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR) + img_bytes = resized_frame.tobytes() + pipe.stdin.write(img_bytes) + pipe.stdin.flush() + last_write_time = time.time() + except Exception: + pass + continue + except Exception as e: + print(f"写入帧错误: {e}") + pipe_error_count += 1 + if pipe_error_count >= max_pipe_errors: + print("FFmpeg管道错误过多,正在终止进程...") + stop_event.set() # 主动结束所有线程 + break + +def startAIVideo2(video_path, output_url, m1, cls, confidence): + rtmp = output_url + setIfAI(True) + + cap = None + pipe = None + read_thread = None + process_thread = None + write_thread = None + ov_model = None + + try: + import os, cv2, torch, time, queue, subprocess + import numpy as np + from threading import Thread, Event + from ultralytics import YOLO + + global frame_queue, processed_frame_queue, stop_event + stop_event = Event() + + os.environ["OMP_NUM_THREADS"] = "4" + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + torch.cuda.empty_cache() + + print("预加载YOLO模型...") + model_params = {} + try: + test_model = YOLO(m1) + if hasattr(test_model, "task"): + model_params["task"] = "detect" + if torch.cuda.is_available(): + model_params["half"] = True + import inspect + if "verbose" in inspect.signature(YOLO.__init__).parameters: + model_params["verbose"] = False + except Exception as e: + print(f"参数检测失败,将使用默认参数: {e}") + model_params = {} + + retry_count = 0 + while retry_count < 3: + try: + ov_model = YOLO(m1, **model_params) + dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8) + for _ in range(3): + ov_model(dummy_frame, classes=cls, conf=confidence, show=False) + print("YOLO模型加载成功并预热完成") + break + except Exception as e: + retry_count += 1 + print(f"YOLO模型加载失败(尝试 {retry_count}/3): {e}") + if "unexpected keyword" in str(e): + param = str(e).split("'")[-2] + if param in model_params: + print(f"移除不支持的参数: {param}") + del model_params[param] + time.sleep(2) + + if ov_model is None: + raise Exception("无法加载YOLO模型") + + ov_model.conf = confidence + + print(f"正在连接视频流: {video_path}") + cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG) + cap.set(cv2.CAP_PROP_BUFFERSIZE, 5) + cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264')) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080) + + if not cap.isOpened(): + raise Exception(f"无法打开视频流: {video_path}") + + try: + cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0) + except Exception as e: + print(f"无法设置自动曝光参数: {e}") + + frame_queue = queue.Queue(maxsize=80) + processed_frame_queue = queue.Queue(maxsize=40) + + size = (1920, 1080) + sizeStr = f"{size[0]}x{size[1]}" + command = ['ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo', + '-pix_fmt', 'bgr24', '-s', sizeStr, '-r', '30', '-i', '-', + '-c:v', 'h264', '-pix_fmt', 'yuv420p', + '-preset', 'ultrafast', '-tune', 'zerolatency', + '-f', 'flv', '-g', '30', '-bufsize', '4000k', + '-maxrate', '4000k', '-b:v', '2500k', '-vsync', '1', + '-threads', '4', rtmp] + + print(f"启动FFmpeg推流到: {rtmp}") + pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + bufsize=10*1024*1024) + + def monitor_ffmpeg_output(pipe): + while not stop_event.is_set(): + line = pipe.stderr.readline().decode('utf-8', errors='ignore') + if line and ("error" in line.lower()): + print(f"FFmpeg错误: {line.strip()}") + if "Cannot open connection" in line: + stop_event.set() + break + + Thread(target=monitor_ffmpeg_output, args=(pipe,), daemon=True).start() + + def read_frames(cap, frame_queue): + while not stop_event.is_set(): + ret, frame = cap.read() + if not ret: + print("读取失败") + break + timestamp = time.time() + try: + frame_queue.put((timestamp, frame), timeout=1) + except queue.Full: + print("帧队列满,跳帧") + + def process_frames(frame_queue, processed_frame_queue, model, cls, confidence): + while not stop_event.is_set(): + try: + timestamp, frame = frame_queue.get(timeout=1) + result = model.predict(source=frame, classes=cls, conf=confidence, verbose=False) + processed = result[0].plot() + processed_frame_queue.put((timestamp, processed), timeout=1) + except queue.Empty: + continue + except Exception as e: + print(f"处理帧错误: {e}") + + def write_frames(processed_frame_queue, pipe, size): + last_timestamp = 0 + while not stop_event.is_set(): + try: + timestamp, frame = processed_frame_queue.get(timeout=1) + if timestamp < last_timestamp: + print(f"跳过闪回帧 {timestamp} < {last_timestamp}") + continue + last_timestamp = timestamp + frame = cv2.resize(frame, size) + pipe.stdin.write(frame.tobytes()) + except Exception as e: + print(f"写入帧错误: {e}") + break + + read_thread = Thread(target=read_frames, args=(cap, frame_queue), daemon=True, name="ReadThread") + process_thread = Thread(target=process_frames, args=(frame_queue, processed_frame_queue, ov_model, cls, confidence), daemon=True, name="ProcessThread") + write_thread = Thread(target=write_frames, args=(processed_frame_queue, pipe, size), daemon=True, name="WriteThread") + + print("开始推流处理...") + read_thread.start() + process_thread.start() + write_thread.start() + + last_check = time.time() + while getIfAI() and not stop_event.is_set(): + if not all([t.is_alive() for t in [read_thread, process_thread, write_thread]]): + print("检测到线程停止,退出") + stop_event.set() + break + if pipe.poll() is not None: + print("FFmpeg退出") + stop_event.set() + break + if time.time() - last_check > 30: + print(f"输入队列: {frame_queue.qsize()}/{frame_queue.maxsize} | 输出队列: {processed_frame_queue.qsize()}/{processed_frame_queue.maxsize}") + last_check = time.time() + time.sleep(0.1) + + except Exception as e: + print(f"错误: {e}") + finally: + print("清理资源...") + stop_event.set() + setIfAI(False) + + for t in [read_thread, process_thread, write_thread]: + if t and t.is_alive(): + t.join(timeout=2) + + try: + if cap: cap.release() + if pipe: + try: + import signal + os.kill(pipe.pid, signal.SIGTERM) + except: pass + pipe.stdin.close() + pipe.terminate() + try: + pipe.wait(timeout=2) + except: + pipe.kill() + except Exception as e: + print(f"释放资源时出错: {e}") + + try: + cv2.destroyAllWindows() + except: + pass + print("资源释放完毕") + + +if __name__ == '__main__': + sn = "1581F6QAD243C00BP71E" + video_path = f"rtmp://222.212.85.86:1935/live/{sn}" + # FFmpeg 推流地址 + rtmp = f"rtmp://222.212.85.86:1935/live/{sn}ai" + + try: + startAIVideo2(video_path, rtmp, "best.pt", [0, 1, 2, 3, 4],0.4) + except KeyboardInterrupt: + print("程序被用户中断") + stopAIVideo() + except Exception as e: + print(f"程序异常: {e}") + diff --git a/ai2/cv_video.py b/ai2/cv_video.py new file mode 100644 index 0000000..7a73c05 --- /dev/null +++ b/ai2/cv_video.py @@ -0,0 +1,603 @@ +from threading import Thread, Lock, Event +import time +import queue +from ultralytics import YOLO # 导入 Ultralytics YOLO 模型 +import os, cv2, torch, time, queue, subprocess +import numpy as np + +# 全局变量 +ifAI = {'status': False} +deskLock = Lock() +frame_queue = queue.Queue(maxsize=60) # 增加帧缓冲队列大小 +processed_frame_queue = queue.Queue(maxsize=30) # 处理后的帧队列 +stop_event = Event() + +def setIfAI(pb1): + deskLock.acquire() + ifAI['status'] = pb1 + deskLock.release() + +def getIfAI(): + return ifAI['status'] + +def stopAIVideo(): + print("正在停止AI视频处理...") + setIfAI(False) + stop_event.set() + + # 等待足够长的时间确保资源释放 + wait_count = 0 + max_wait = 5 # 减少最大等待时间到5秒 + + while stop_event.is_set() and wait_count < max_wait: + time.sleep(0.5) + wait_count += 1 + + if wait_count >= max_wait: + print("警告: 停止AI视频处理超时,强制终止") + # 不使用_thread._interrupt_main(),改用其他方式强制终止 + try: + # 尝试终止可能运行的进程 + import os + import signal + import psutil + + # 查找并终止可能的FFmpeg进程 + current_process = psutil.Process(os.getpid()) + for child in current_process.children(recursive=True): + try: + child_name = child.name().lower() + if 'ffmpeg' in child_name: + print(f"正在终止子进程: {child.pid} ({child_name})") + child.send_signal(signal.SIGTERM) + except: + pass + except: + pass + else: + print("AI视频处理已停止") + +def startAIVideo(video_path, output_url, m1, cls, confidence): + if ifAI['status']: + stopAIVideo() + time.sleep(1) + stop_event.clear() + thread = Thread(target=startAIVideo2, + args=(video_path, output_url, m1, cls, confidence)) + # cls2_thread = Thread(target=cls2_find, args=(video_path,m1, cls, confidence)) + # cls2_thread.daemon = True # 守护线程,主程序退出时线程也会退出 + thread.daemon = True # 守护线程,主程序退出时线程也会退出 + + + thread.start() + # cls2_thread.start() + +def read_frames(cap, frame_queue): + """优化的帧读取线程""" + frame_count = 0 + last_time = time.time() + last_fps_time = time.time() + # 减小目标帧间隔时间,提高读取帧率 + target_time_per_frame = 1.0 / 60.0 # 目标帧间隔时间(提高到60fps) + + # 添加连接断开检测 + connection_error_count = 0 + max_connection_errors = 10 # 最多允许连续10次连接错误 + last_successful_read = time.time() + max_read_wait = 30.0 # 30秒无法读取则认为连接断开 + + # 预先丢弃几帧,确保从新帧开始处理 + for _ in range(5): + cap.grab() + + while not stop_event.is_set(): + current_time = time.time() + elapsed_time = current_time - last_time + + # 检查是否长时间无法读取帧 + if current_time - last_successful_read > max_read_wait: + print(f"警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开") + stop_event.set() + break + + # 帧率控制,但更积极地读取 + if elapsed_time < target_time_per_frame: + time.sleep(target_time_per_frame - elapsed_time) + continue + + # 当队列快满时,跳过一些帧以避免延迟累积 + if frame_queue.qsize() > frame_queue.maxsize * 0.8: + # 跳过一些帧 + cap.grab() + last_time = time.time() + continue + + ret, frame = cap.read() + if not ret: + print("拉流错误:无法读取帧") + connection_error_count += 1 + if connection_error_count >= max_connection_errors: + print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...") + stop_event.set() + break + time.sleep(0.5) # 短暂等待后重试 + continue + + # 成功读取了帧,重置错误计数 + connection_error_count = 0 + last_successful_read = time.time() + + frame_count += 1 + if frame_count % 60 == 0: # 每60帧计算一次FPS + current_fps_time = time.time() + fps = 60 / (current_fps_time - last_fps_time) + print(f"拉流FPS: {fps:.2f}") + last_fps_time = current_fps_time + + last_time = time.time() + frame_queue.put((frame, time.time())) # 添加时间戳 + +def process_frames(frame_queue, processed_frame_queue, ov_model, cls, confidence): + """处理帧的线程,添加帧率控制""" + error_count = 0 # 添加错误计数器 + max_errors = 5 # 最大容许错误次数 + frame_count = 0 + process_times = [] # 用于计算平均处理时间 + + # 设置YOLO模型配置,提高性能 + ov_model.conf = confidence # 设置置信度阈值 + + # 优化推理性能 + try: + # 导入torch库 + import torch + # 尝试启用ONNX Runtime加速 + ov_model.to('cuda:0' if torch.cuda.is_available() else 'cpu') + # 调整批处理大小为1,减少内存占用 + if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'): + ov_model.args.batch = 1 + # 使用half精度,提高性能 + if torch.cuda.is_available() and hasattr(ov_model, 'model'): + try: + ov_model.model = ov_model.model.half() + except Exception as half_err: + print(f"半精度转换失败: {half_err}") + except Exception as e: + print(f"模型优化配置警告: {e}") + + # 缓存先前的检测结果,用于提高稳定性 + last_results = None + skip_counter = 0 + max_skip = 2 # 最多跳过几帧不处理 + + while not stop_event.is_set(): + try: + if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8: + # 如果输出队列接近满,等待一段时间 + time.sleep(0.01) + continue + + frame, timestamp = frame_queue.get(timeout=0.2) + + # 处理延迟过大的帧 + if time.time() - timestamp > 0.3: # 减少延迟阈值 + continue + + frame_count += 1 + + # 间隔采样,每n帧处理一次,减少计算量 + if skip_counter > 0 and last_results is not None: + skip_counter -= 1 + # 使用上次的检测结果 + annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5) + processed_frame_queue.put((annotated_frame, timestamp)) + continue + + process_start = time.time() + + # 动态调整处理尺寸,根据队列积压情况 + resize_scale = 1.0 + if frame_queue.qsize() > frame_queue.maxsize * 0.7: + resize_scale = 0.4 # 高负载时大幅降低分辨率 + elif frame_queue.qsize() > frame_queue.maxsize * 0.5: + resize_scale = 0.6 # 中等负载时适度降低分辨率 + elif frame_queue.qsize() > frame_queue.maxsize * 0.3: + resize_scale = 0.8 # 轻微负载时轻微降低分辨率 + + # 调整图像尺寸以加快处理 + if resize_scale < 1.0: + process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale) + else: + process_frame = frame + + # 执行推理 + try: + results = ov_model(process_frame, classes=cls, show=False) + last_results = results[0] # 保存检测结果用于后续帧 + + # 如果尺寸调整过,需要将结果转换回原始尺寸 + if resize_scale < 1.0: + # 绘制检测框 + annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5), + (frame.shape[1], frame.shape[0])) + else: + annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5) + + # 在负载高时启用跳帧处理 + if frame_queue.qsize() > frame_queue.maxsize * 0.5: + skip_counter = max_skip + except Exception as infer_err: + print(f"推理错误: {infer_err}") + if last_results is not None: + # 使用上次的结果 + annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5) + else: + # 如果没有上次的结果,简单返回原始帧 + annotated_frame = frame.copy() + + process_end = time.time() + process_times.append(process_end - process_start) + if len(process_times) > 30: + process_times.pop(0) + + if frame_count % 30 == 0: + avg_process_time = sum(process_times) / len(process_times) + fps = 1.0 / avg_process_time if avg_process_time > 0 else 0 + print(f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time*1000:.2f}ms, 队列大小: {frame_queue.qsize()}, 缩放比例: {resize_scale:.2f}") + + processed_frame_queue.put((annotated_frame, timestamp)) + error_count = 0 # 成功处理后重置错误计数 + except queue.Empty: + continue + except Exception as e: + error_count += 1 + print(f"处理帧错误: {e}") + if error_count >= max_errors: + print(f"连续处理错误达到{max_errors}次,正在停止处理...") + stop_event.set() + break + +def write_frames(processed_frame_queue, pipe, size): + """写入帧的线程,添加平滑处理""" + last_write_time = time.time() + target_time_per_frame = 1.0 / 30.0 # 30fps + pipe_error_count = 0 # 添加错误计数 + max_pipe_errors = 3 # 最大容忍错误数 + frame_count = 0 + last_fps_time = time.time() + skipped_frames = 0 + + # 使用队列存储最近几帧,用于在需要时进行插值 + recent_frames = [] + max_recent_frames = 5 # 增加缓存帧数量,提高平滑性 + + # 使用双缓冲机制提高写入速度 + buffer1 = bytearray(size[0] * size[1] * 3) + buffer2 = bytearray(size[0] * size[1] * 3) + current_buffer = buffer1 + + # 帧率控制参数 + min_frame_interval = target_time_per_frame * 0.5 # 允许的最小帧间隔 + max_frame_interval = target_time_per_frame * 2.0 # 允许的最大帧间隔 + + while not stop_event.is_set(): + try: + # 获取处理后的帧,超时时间较短以便更平滑地处理 + frame, timestamp = processed_frame_queue.get(timeout=0.05) + + # 存储最近的帧用于插值 + recent_frames.append(frame) + if len(recent_frames) > max_recent_frames: + recent_frames.pop(0) + + current_time = time.time() + elapsed = current_time - last_write_time + + # 如果两帧间隔太短,考虑合并或跳过 + if elapsed < min_frame_interval and len(recent_frames) >= 2: + skipped_frames += 1 + continue + + # 如果两帧间隔太长,进行插值平滑 + if elapsed > max_frame_interval and len(recent_frames) >= 2: + # 计算需要插入的帧数 + frames_to_insert = min(int(elapsed / target_time_per_frame), 3) + + for i in range(frames_to_insert): + # 创建插值帧 + weight = (i + 1) / (frames_to_insert + 1) + interpolated_frame = cv2.addWeighted(recent_frames[-2], 1-weight, recent_frames[-1], weight, 0) + + # 切换双缓冲 + current_buffer = buffer2 if current_buffer is buffer1 else buffer1 + + # 高效调整大小并写入 + interpolated_resized = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR) + img_bytes = interpolated_resized.tobytes() + + # 写入管道 + pipe.stdin.write(img_bytes) + pipe.stdin.flush() + + # 切换双缓冲 + current_buffer = buffer2 if current_buffer is buffer1 else buffer1 + + # 正常写入当前帧 - 使用高效的调整大小方法 + resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR) + img_bytes = resized_frame.tobytes() + pipe.stdin.write(img_bytes) + pipe.stdin.flush() + + frame_count += 1 + if frame_count % 30 == 0: + current_fps_time = time.time() + fps = 30 / (current_fps_time - last_fps_time) + print(f"推流FPS: {fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}") + last_fps_time = current_fps_time + skipped_frames = 0 + + last_write_time = time.time() + pipe_error_count = 0 # 成功写入后重置错误计数 + + except queue.Empty: + # 队列为空且有足够的最近帧时,考虑生成插值帧以保持流畅 + if len(recent_frames) >= 2 and time.time() - last_write_time > target_time_per_frame: + try: + # 创建插值帧 + interpolated_frame = cv2.addWeighted(recent_frames[-2], 0.5, recent_frames[-1], 0.5, 0) + + # 切换双缓冲 + current_buffer = buffer2 if current_buffer is buffer1 else buffer1 + + resized_frame = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR) + img_bytes = resized_frame.tobytes() + pipe.stdin.write(img_bytes) + pipe.stdin.flush() + last_write_time = time.time() + except Exception: + pass + continue + except Exception as e: + print(f"写入帧错误: {e}") + pipe_error_count += 1 + if pipe_error_count >= max_pipe_errors: + print("FFmpeg管道错误过多,正在终止进程...") + stop_event.set() # 主动结束所有线程 + break + +def startAIVideo2(video_path, output_url, m1, cls, confidence): + rtmp = output_url + setIfAI(True) + + cap = None + pipe = None + read_thread = None + process_thread = None + write_thread = None + ov_model = None + + try: + + global frame_queue, processed_frame_queue, stop_event + stop_event = Event() + + os.environ["OMP_NUM_THREADS"] = "4" + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + torch.cuda.empty_cache() + + print("预加载YOLO模型...") + model_params = {} + try: + test_model = YOLO(m1) + if hasattr(test_model, "task"): + model_params["task"] = "detect" + if torch.cuda.is_available(): + model_params["half"] = True + import inspect + if "verbose" in inspect.signature(YOLO.__init__).parameters: + model_params["verbose"] = False + except Exception as e: + print(f"参数检测失败,将使用默认参数: {e}") + model_params = {} + + retry_count = 0 + while retry_count < 3: + try: + ov_model = YOLO(m1, **model_params) + dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8) + for _ in range(3): + ov_model(dummy_frame, classes=cls, conf=confidence, show=False) + print("YOLO模型加载成功并预热完成") + break + except Exception as e: + retry_count += 1 + print(f"YOLO模型加载失败(尝试 {retry_count}/3): {e}") + if "unexpected keyword" in str(e): + param = str(e).split("'")[-2] + if param in model_params: + print(f"移除不支持的参数: {param}") + del model_params[param] + time.sleep(2) + + if ov_model is None: + raise Exception("无法加载YOLO模型") + + ov_model.conf = confidence + + print(f"正在连接视频流: {video_path}") + cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG) + cap.set(cv2.CAP_PROP_BUFFERSIZE, 5) + cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264')) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080) + + if not cap.isOpened(): + raise Exception(f"无法打开视频流: {video_path}") + + try: + cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0) + except Exception as e: + print(f"无法设置自动曝光参数: {e}") + + frame_queue = queue.Queue(maxsize=80) + processed_frame_queue = queue.Queue(maxsize=40) + + size = (1920, 1080) + sizeStr = f"{size[0]}x{size[1]}" + + command = [ + 'ffmpeg', '-y', + '-f', 'rawvideo', '-vcodec', 'rawvideo', + '-pix_fmt', 'bgr24', + '-s', sizeStr, + '-r', '30', + '-i', '-', + '-c:v', 'libx264', + '-preset', 'ultrafast', + '-tune', 'zerolatency', + '-x264-params', 'sei=0', + '-pix_fmt', 'yuv420p', + '-f', 'flv', + '-g', '30', + '-keyint_min', '30', + '-sc_threshold', '0', + '-b:v', '2500k', + '-maxrate', '3000k', + '-bufsize', '3000k', + '-threads', '4', + '-vsync', '1', + rtmp + ] + + + print(f"启动FFmpeg推流到: {rtmp}") + pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + bufsize=10*1024*1024) + + def monitor_ffmpeg_output(pipe): + while not stop_event.is_set(): + line = pipe.stderr.readline().decode('utf-8', errors='ignore') + if line and ("error" in line.lower()): + print(f"FFmpeg错误: {line.strip()}") + if "Cannot open connection" in line: + stop_event.set() + break + + Thread(target=monitor_ffmpeg_output, args=(pipe,), daemon=True).start() + + def read_frames(cap, frame_queue): + while not stop_event.is_set(): + ret, frame = cap.read() + if not ret: + print("读取失败") + break + timestamp = time.time() + try: + frame_queue.put((timestamp, frame), timeout=1) + except queue.Full: + print("帧队列满,跳帧") + + def process_frames(frame_queue, processed_frame_queue, model, cls, confidence): + while not stop_event.is_set(): + try: + timestamp, frame = frame_queue.get(timeout=1) + result = model.predict(source=frame, classes=cls, conf=confidence, verbose=False) + processed = result[0].plot() + processed_frame_queue.put((timestamp, processed), timeout=1) + except queue.Empty: + continue + except Exception as e: + print(f"处理帧错误: {e}") + + def write_frames(processed_frame_queue, pipe, size): + last_timestamp = 0 + while not stop_event.is_set(): + try: + timestamp, frame = processed_frame_queue.get(timeout=1) + if timestamp < last_timestamp: + print(f"跳过闪回帧 {timestamp} < {last_timestamp}") + continue + last_timestamp = timestamp + frame = cv2.resize(frame, size) + pipe.stdin.write(frame.tobytes()) + except Exception as e: + print(f"写入帧错误: {e}") + break + + read_thread = Thread(target=read_frames, args=(cap, frame_queue), daemon=True, name="ReadThread") + process_thread = Thread(target=process_frames, args=(frame_queue, processed_frame_queue, ov_model, cls, confidence), daemon=True, name="ProcessThread") + write_thread = Thread(target=write_frames, args=(processed_frame_queue, pipe, size), daemon=True, name="WriteThread") + + print("开始推流处理...") + read_thread.start() + process_thread.start() + write_thread.start() + + last_check = time.time() + while getIfAI() and not stop_event.is_set(): + if not all([t.is_alive() for t in [read_thread, process_thread, write_thread]]): + print("检测到线程停止,退出") + stop_event.set() + break + if pipe.poll() is not None: + print("FFmpeg退出") + stop_event.set() + break + if time.time() - last_check > 30: + print(f"输入队列: {frame_queue.qsize()}/{frame_queue.maxsize} | 输出队列: {processed_frame_queue.qsize()}/{processed_frame_queue.maxsize}") + last_check = time.time() + time.sleep(0.1) + + except Exception as e: + print(f"错误: {e}") + finally: + print("清理资源...") + stop_event.set() + setIfAI(False) + + for t in [read_thread, process_thread, write_thread]: + if t and t.is_alive(): + t.join(timeout=2) + + try: + if cap: cap.release() + if pipe: + try: + import signal + os.kill(pipe.pid, signal.SIGTERM) + except: pass + pipe.stdin.close() + pipe.terminate() + try: + pipe.wait(timeout=2) + except: + pipe.kill() + except Exception as e: + print(f"释放资源时出错: {e}") + + try: + cv2.destroyAllWindows() + except: + pass + print("资源释放完毕") + + +if __name__ == '__main__': + sn = "1581F6QAD243C00BP71E" + video_path = f"rtmp://222.212.85.86:1935/live/{sn}" + # FFmpeg 推流地址 + rtmp = f"rtmp://222.212.85.86:1935/live/{sn}ai" + + try: + startAIVideo2(video_path, rtmp, "best.pt", [0, 1, 2, 3, 4],0.4) + except KeyboardInterrupt: + print("程序被用户中断") + stopAIVideo() + except Exception as e: + print(f"程序异常: {e}") + diff --git a/ai2/fire.pt b/ai2/fire.pt new file mode 100644 index 0000000..fad2dff Binary files /dev/null and b/ai2/fire.pt differ diff --git a/ai2/gdaq.pt b/ai2/gdaq.pt new file mode 100644 index 0000000..309fae0 Binary files /dev/null and b/ai2/gdaq.pt differ diff --git a/ai2/last.pt b/ai2/last.pt new file mode 100644 index 0000000..497e15f Binary files /dev/null and b/ai2/last.pt differ diff --git a/ai2/minio_helper.py b/ai2/minio_helper.py new file mode 100644 index 0000000..3ec3a2b --- /dev/null +++ b/ai2/minio_helper.py @@ -0,0 +1,48 @@ + +import os +from minio import Minio +from minio.error import S3Error + +bucket="300bdf2b-a150-406e-be63-d28bd29b409f" +# 替换为你的MinIO服务器地址、访问密钥和秘密密钥 +def getClient(): + minio_client = Minio( + "222.212.85.86:9000", + access_key="WuRenJi", + secure=False, + secret_key="WRJ@2024",) + return minio_client + +def getPath2(object): + #dir="C:/sy/movies/" + dir=os.getcwd()+"/" + baseName=object + s1=baseName.rfind("/") + dir2=(dir+baseName[0:s1+1]).replace("/","\\") + fName=baseName[s1+1:int(len(baseName))] + os.makedirs(dir2, exist_ok=True) + file_path = os.path.join(dir2, fName) + return file_path + +def upLoad(obj,path): + try: + minio_client=getClient() + minio_client.fput_object(bucket, obj, path) + return True + except S3Error as e: + return False + +def downLoad(obj): + path=getPath2(obj) + if os.path.exists(path): + return path +# 从MinIO的存储桶和对象名称下载 + try: + minio_client=getClient() + minio_client.fget_object(bucket, obj, path) + return path + except S3Error as e: + return "" + +if __name__ == '__main__': + upLoad("aaa/yolo_api.py","yolo_api.py") diff --git a/ai2/smoke.pt b/ai2/smoke.pt new file mode 100644 index 0000000..39be32c Binary files /dev/null and b/ai2/smoke.pt differ diff --git a/ai2/trash.pt b/ai2/trash.pt new file mode 100644 index 0000000..257f1b3 Binary files /dev/null and b/ai2/trash.pt differ diff --git a/ai2/yanwu2.pt b/ai2/yanwu2.pt new file mode 100644 index 0000000..85da7fd Binary files /dev/null and b/ai2/yanwu2.pt differ diff --git a/ai2/yolo11n.pt b/ai2/yolo11n.pt new file mode 100644 index 0000000..45b273b Binary files /dev/null and b/ai2/yolo11n.pt differ diff --git a/ai2/yolo_api copy.py b/ai2/yolo_api copy.py new file mode 100644 index 0000000..9dd7fd7 --- /dev/null +++ b/ai2/yolo_api copy.py @@ -0,0 +1,509 @@ +from sanic import Sanic, json +from sanic.response import json as json_response +from sanic.exceptions import Unauthorized, NotFound, SanicException +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +import uuid +import logging +import asyncio +import traceback +from datetime import datetime +from cv_video import startAIVideo,stopAIVideo,getIfAI +from sanic_cors import CORS + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# 服务状态标志 +service_status = {"is_healthy": True, "last_error": None, "error_time": None} + +# 配置类 +class Config: + VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa" + MAX_ACTIVE_TASKS = 10 + DEFAULT_CONFIDENCE = 0.5 + RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒) + + +@dataclass +class StreamRequest: + source_url: str + push_url: str + model_path: str + detect_classes: List[str] + confidence: float = Config.DEFAULT_CONFIDENCE + + def validate(self) -> None: + """验证请求参数""" + if not self.source_url or not self.push_url: + raise ValueError("Source URL and Push URL are required") + + if not self.detect_classes: + raise ValueError("At least one detection class must be specified") + if not 0 < self.confidence < 1: + raise ValueError("Confidence must be between 0 and 1") + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest': + try: + instance = cls( + source_url=data['source_url'], + push_url=data['push_url'], + model_path=data['model_path'], + detect_classes=data['detect_classes'], + confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE) + ) + instance.validate() + return instance + except KeyError as e: + raise ValueError(f"Missing required field: {str(e)}") + +class TaskManager: + def __init__(self): + self.active_tasks: Dict[str, Dict[str, Any]] = {} + self.task_status: Dict[str, str] = {} + self.task_timestamps: Dict[str, datetime] = {} + + def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None: + """添加新任务""" + if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS: + raise ValueError("Maximum number of active tasks reached") + + self.active_tasks[task_id] = task_info + self.task_status[task_id] = "running" + self.task_timestamps[task_id] = datetime.now() + logger.info(f"Task {task_id} started") + + def remove_task(self, task_id: str) -> None: + """移除任务""" + if task_id in self.active_tasks: + del self.active_tasks[task_id] + del self.task_status[task_id] + del self.task_timestamps[task_id] + logger.info(f"Task {task_id} removed") + + def get_task_info(self, task_id: str) -> Dict[str, Any]: + """获取任务信息""" + if task_id not in self.active_tasks: + raise NotFound("Task not found") + + return { + "task_info": self.active_tasks[task_id], + "status": self.task_status[task_id], + "start_time": self.task_timestamps[task_id].isoformat() + } + + def check_tasks_health(self) -> Dict[str, str]: + """检查任务健康状态""" + unhealthy_tasks = {} + for task_id in list(self.active_tasks.keys()): + # 检查任务是否还在运行(通过getIfAI()函数) + if not getIfAI(): + unhealthy_tasks[task_id] = "stopped" + logger.warning(f"Task {task_id} appears to be stopped unexpectedly") + return unhealthy_tasks + + def mark_all_tasks_as_stopped(self): + """标记所有任务为已停止状态""" + for task_id in list(self.active_tasks.keys()): + self.task_status[task_id] = "stopped" + logger.warning("已将所有任务标记为停止状态") + +app = Sanic("YoloStreamService") +CORS(app) +task_manager = TaskManager() + +async def safe_stop_ai_video(): + """安全地停止AI视频处理,带有错误处理和恢复机制""" + try: + await asyncio.to_thread(stopAIVideo) + return True + except Exception as e: + error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + + # 标记服务状态为不健康 + service_status["is_healthy"] = False + service_status["last_error"] = str(e) + service_status["error_time"] = datetime.now().isoformat() + + # 强制结束所有任务 + task_manager.mark_all_tasks_as_stopped() + + # 尝试通过其他方式杀死可能存在的进程 + try: + import os + import signal + import psutil + + current_process = psutil.Process(os.getpid()) + # 查找并终止ffmpeg子进程 + for child in current_process.children(recursive=True): + try: + child_name = child.name().lower() + if 'ffmpeg' in child_name: + logger.info(f"强制终止子进程: {child.pid} ({child_name})") + child.send_signal(signal.SIGTERM) + except Exception as child_e: + logger.error(f"终止子进程出错: {str(child_e)}") + except Exception as kill_e: + logger.error(f"尝试清理进程时出错: {str(kill_e)}") + + # 等待一段时间让系统恢复 + await asyncio.sleep(Config.RESTART_DELAY) + + # 重置服务状态 + service_status["is_healthy"] = True + return False + +def verify_token(request) -> None: + """验证请求token""" + token = request.headers.get('X-API-Token') + if not token or token != Config.VALID_TOKEN: + logger.warning("Invalid token attempt") + raise Unauthorized("Invalid token") + +@app.post("/ai/stream/detect") +async def start_detection(request): + try: + verify_token(request) + + # 检查服务健康状态 + if not service_status["is_healthy"]: + logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}") + # 尝试恢复服务 + service_status["is_healthy"] = True + + # 先停止所有现有任务 + for task_id in list(task_manager.active_tasks.keys()): + logger.info(f"停止现有任务 {task_id} 以启动新任务") + try: + success = await safe_stop_ai_video() + if success: + task_manager.remove_task(task_id) + else: + logger.warning(f"无法正常停止任务 {task_id},但仍将继续") + task_manager.mark_all_tasks_as_stopped() + except Exception as e: + logger.error(f"停止任务时出错: {e}") + # 继续执行,尝试启动新任务 + + # 解析并验证请求数据 + stream_request = StreamRequest.from_dict(request.json) + + task_id = str(uuid.uuid4()) + # 修正视频源地址 + new_source_url = stream_request.source_url.replace("222.212.85.86", "192.168.10.5") + new_push_url = stream_request.push_url.replace("222.212.85.86", "192.168.10.5") + # 启动YOLO检测 + try: + await asyncio.to_thread( + startAIVideo, + new_source_url, + new_push_url, + stream_request.model_path, + stream_request.detect_classes, + stream_request.confidence + ) + except Exception as e: + logger.error(f"启动AI视频处理失败: {e}") + return json_response({ + "status": "error", + "message": f"Failed to start AI video processing: {str(e)}" + }, status=500) + + # 记录任务信息 + task_manager.add_task(task_id, { + "source_url": stream_request.source_url, + "push_url": stream_request.push_url, + "model_path": stream_request.model_path, + "detect_classes": stream_request.detect_classes, + "confidence": stream_request.confidence + }) + + return json_response({ + "status": "success", + "task_id": task_id, + "message": "Detection started successfully" + }) + + except ValueError as e: + logger.error(f"Validation error: {str(e)}") + return json_response({"status": "error", "message": str(e)}, status=400) + except Exception as e: + logger.error(f"Unexpected error: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.post("/ai/stream/") +async def stop_detection(request, task_id: str): + try: + verify_token(request) + + # 检查任务是否存在 + try: + task_info = task_manager.get_task_info(task_id) + except NotFound: + return json_response({"status": "error", "message": "Task not found"}, status=404) + + # 停止AI视频处理,使用安全的停止方法 + success = await safe_stop_ai_video() + + # 即使停止失败,也要移除任务 + task_manager.remove_task(task_id) + + if not success: + logger.warning("虽然停止过程出现错误,但任务已被标记为结束") + return json_response({ + "status": "warning", + "message": "Task removal completed with warnings" + }) + + return json_response({ + "status": "success", + "message": "Detection stopped successfully" + }) + except NotFound as e: + return json_response({"status": "error", "message": str(e)}, status=404) + except Exception as e: + logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True) + # 尝试标记任务为停止状态 + try: + if task_id in task_manager.task_status: + task_manager.task_status[task_id] = "error_during_stop" + except: + pass + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/stream/") +async def get_task_status(request, task_id: str): + try: + verify_token(request) + task_info = task_manager.get_task_info(task_id) + + # 检查任务是否真的在运行 + if not getIfAI() and task_info["status"] == "running": + task_info["status"] = "stopped_unexpectedly" + logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止") + + return json_response({ + "status": "success", + "task_id": task_id, + **task_info + }) + except NotFound as e: + return json_response({"status": "error", "message": str(e)}, status=404) + except Exception as e: + logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/stream/tasks") +async def list_tasks(request): + """获取所有活动任务列表""" + try: + verify_token(request) + + # 检查所有任务的健康状态 + unhealthy_tasks = task_manager.check_tasks_health() + for task_id, status in unhealthy_tasks.items(): + if task_id in task_manager.task_status: + task_manager.task_status[task_id] = status + + tasks = { + task_id: task_manager.get_task_info(task_id) + for task_id in task_manager.active_tasks.keys() + } + return json_response({ + "status": "success", + "tasks": tasks + }) + except Exception as e: + logger.error(f"Error listing tasks: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.post("/ai/stream/stopTasks") +async def stop_all_detections(request): + """停止所有活动任务""" + try: + verify_token(request) + + if not task_manager.active_tasks: + return json_response({ + "status": "success", + "message": "No active tasks to stop" + }) + + # 停止所有任务 + success = await safe_stop_ai_video() + + # 无论成功与否,都移除所有任务 + for task_id in list(task_manager.active_tasks.keys()): + task_manager.remove_task(task_id) + + if not success: + return json_response({ + "status": "warning", + "message": "Tasks stopped with warnings" + }) + + return json_response({ + "status": "success", + "message": "All detections stopped successfully" + }) + except Exception as e: + logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True) + # 尝试标记所有任务为停止状态 + task_manager.mark_all_tasks_as_stopped() + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/health") +async def health_check(request): + """服务健康检查端点""" + try: + # 不需要验证token,这个接口可以用于监控系统检查服务状态 + unhealthy_tasks = task_manager.check_tasks_health() + + return json_response({ + "status": "success", + "service": "running" if service_status["is_healthy"] else "degraded", + "active_tasks": len(task_manager.active_tasks), + "unhealthy_tasks": unhealthy_tasks, + "last_error": service_status["last_error"], + "error_time": service_status["error_time"], + "timestamp": datetime.now().isoformat() + }) + except Exception as e: + logger.error(f"健康检查失败: {str(e)}", exc_info=True) + return json_response({ + "status": "error", + "service": "degraded", + "error": str(e), + "timestamp": datetime.now().isoformat() + }, status=500) + +@app.route("/ai/reset", methods=["POST"]) +async def reset_service(request): + """重置服务状态,清理所有任务和进程""" + try: + verify_token(request) + + # 尝试停止AI视频处理 + await safe_stop_ai_video() + + # 清理所有任务 + for task_id in list(task_manager.active_tasks.keys()): + task_manager.remove_task(task_id) + + # 重置服务状态 + service_status["is_healthy"] = True + service_status["last_error"] = None + service_status["error_time"] = None + + # 尝试清理可能存在的僵尸进程 + try: + import os + import signal + import psutil + + current_process = psutil.Process(os.getpid()) + zombie_count = 0 + + for child in current_process.children(recursive=True): + try: + if child.status() == psutil.STATUS_ZOMBIE: + zombie_count += 1 + child.send_signal(signal.SIGKILL) + except: + pass + + return json_response({ + "status": "success", + "message": f"Service reset successful. Cleaned {zombie_count} zombie processes." + }) + except Exception as e: + logger.error(f"清理僵尸进程时出错: {e}") + return json_response({ + "status": "warning", + "message": "Service reset with warnings" + }) + + except Exception as e: + logger.error(f"重置服务时出错: {str(e)}", exc_info=True) + return json_response({ + "status": "error", + "message": f"Failed to reset service: {str(e)}" + }, status=500) + +@app.route("/ai/stream/restart/", methods=["POST"]) +async def restart_task(request, task_id: str): + """重启指定任务""" + try: + verify_token(request) + + # 获取任务信息 + try: + task_info = task_manager.get_task_info(task_id)["task_info"] + except NotFound: + return json_response({"status": "error", "message": "Task not found"}, status=404) + + # 先停止任务 + success = await safe_stop_ai_video() + task_manager.remove_task(task_id) + + if not success: + logger.warning("停止任务出现问题,尝试继续重启") + + # 重新启动任务 + new_task_id = str(uuid.uuid4()) + + try: + await asyncio.to_thread( + startAIVideo, + task_info["source_url"], + task_info["push_url"], + task_info["model_path"], + task_info["detect_classes"], + task_info["confidence"] + ) + + # 记录新任务信息 + task_manager.add_task(new_task_id, task_info) + + return json_response({ + "status": "success", + "old_task_id": task_id, + "new_task_id": new_task_id, + "message": "Task restarted successfully" + }) + except Exception as e: + logger.error(f"重启任务失败: {e}") + return json_response({ + "status": "error", + "message": f"Failed to restart task: {str(e)}" + }, status=500) + + except Exception as e: + logger.error(f"重启任务时出错: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +if __name__ == "__main__": + # 保证服务启动前没有残留任务 + try: + stopAIVideo() + print("服务启动前清理完成") + except: + print("服务启动前清理失败,但仍将继续") + + # 安装psutil库,用于进程管理 + try: + import psutil + except ImportError: + import subprocess + import sys + print("正在安装psutil库...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"]) + + app.run(host="0.0.0.0", port=12315, debug=False, access_log=True) diff --git a/ai2/yolo_api.py b/ai2/yolo_api.py new file mode 100644 index 0000000..9dd7fd7 --- /dev/null +++ b/ai2/yolo_api.py @@ -0,0 +1,509 @@ +from sanic import Sanic, json +from sanic.response import json as json_response +from sanic.exceptions import Unauthorized, NotFound, SanicException +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +import uuid +import logging +import asyncio +import traceback +from datetime import datetime +from cv_video import startAIVideo,stopAIVideo,getIfAI +from sanic_cors import CORS + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# 服务状态标志 +service_status = {"is_healthy": True, "last_error": None, "error_time": None} + +# 配置类 +class Config: + VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa" + MAX_ACTIVE_TASKS = 10 + DEFAULT_CONFIDENCE = 0.5 + RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒) + + +@dataclass +class StreamRequest: + source_url: str + push_url: str + model_path: str + detect_classes: List[str] + confidence: float = Config.DEFAULT_CONFIDENCE + + def validate(self) -> None: + """验证请求参数""" + if not self.source_url or not self.push_url: + raise ValueError("Source URL and Push URL are required") + + if not self.detect_classes: + raise ValueError("At least one detection class must be specified") + if not 0 < self.confidence < 1: + raise ValueError("Confidence must be between 0 and 1") + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest': + try: + instance = cls( + source_url=data['source_url'], + push_url=data['push_url'], + model_path=data['model_path'], + detect_classes=data['detect_classes'], + confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE) + ) + instance.validate() + return instance + except KeyError as e: + raise ValueError(f"Missing required field: {str(e)}") + +class TaskManager: + def __init__(self): + self.active_tasks: Dict[str, Dict[str, Any]] = {} + self.task_status: Dict[str, str] = {} + self.task_timestamps: Dict[str, datetime] = {} + + def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None: + """添加新任务""" + if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS: + raise ValueError("Maximum number of active tasks reached") + + self.active_tasks[task_id] = task_info + self.task_status[task_id] = "running" + self.task_timestamps[task_id] = datetime.now() + logger.info(f"Task {task_id} started") + + def remove_task(self, task_id: str) -> None: + """移除任务""" + if task_id in self.active_tasks: + del self.active_tasks[task_id] + del self.task_status[task_id] + del self.task_timestamps[task_id] + logger.info(f"Task {task_id} removed") + + def get_task_info(self, task_id: str) -> Dict[str, Any]: + """获取任务信息""" + if task_id not in self.active_tasks: + raise NotFound("Task not found") + + return { + "task_info": self.active_tasks[task_id], + "status": self.task_status[task_id], + "start_time": self.task_timestamps[task_id].isoformat() + } + + def check_tasks_health(self) -> Dict[str, str]: + """检查任务健康状态""" + unhealthy_tasks = {} + for task_id in list(self.active_tasks.keys()): + # 检查任务是否还在运行(通过getIfAI()函数) + if not getIfAI(): + unhealthy_tasks[task_id] = "stopped" + logger.warning(f"Task {task_id} appears to be stopped unexpectedly") + return unhealthy_tasks + + def mark_all_tasks_as_stopped(self): + """标记所有任务为已停止状态""" + for task_id in list(self.active_tasks.keys()): + self.task_status[task_id] = "stopped" + logger.warning("已将所有任务标记为停止状态") + +app = Sanic("YoloStreamService") +CORS(app) +task_manager = TaskManager() + +async def safe_stop_ai_video(): + """安全地停止AI视频处理,带有错误处理和恢复机制""" + try: + await asyncio.to_thread(stopAIVideo) + return True + except Exception as e: + error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + + # 标记服务状态为不健康 + service_status["is_healthy"] = False + service_status["last_error"] = str(e) + service_status["error_time"] = datetime.now().isoformat() + + # 强制结束所有任务 + task_manager.mark_all_tasks_as_stopped() + + # 尝试通过其他方式杀死可能存在的进程 + try: + import os + import signal + import psutil + + current_process = psutil.Process(os.getpid()) + # 查找并终止ffmpeg子进程 + for child in current_process.children(recursive=True): + try: + child_name = child.name().lower() + if 'ffmpeg' in child_name: + logger.info(f"强制终止子进程: {child.pid} ({child_name})") + child.send_signal(signal.SIGTERM) + except Exception as child_e: + logger.error(f"终止子进程出错: {str(child_e)}") + except Exception as kill_e: + logger.error(f"尝试清理进程时出错: {str(kill_e)}") + + # 等待一段时间让系统恢复 + await asyncio.sleep(Config.RESTART_DELAY) + + # 重置服务状态 + service_status["is_healthy"] = True + return False + +def verify_token(request) -> None: + """验证请求token""" + token = request.headers.get('X-API-Token') + if not token or token != Config.VALID_TOKEN: + logger.warning("Invalid token attempt") + raise Unauthorized("Invalid token") + +@app.post("/ai/stream/detect") +async def start_detection(request): + try: + verify_token(request) + + # 检查服务健康状态 + if not service_status["is_healthy"]: + logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}") + # 尝试恢复服务 + service_status["is_healthy"] = True + + # 先停止所有现有任务 + for task_id in list(task_manager.active_tasks.keys()): + logger.info(f"停止现有任务 {task_id} 以启动新任务") + try: + success = await safe_stop_ai_video() + if success: + task_manager.remove_task(task_id) + else: + logger.warning(f"无法正常停止任务 {task_id},但仍将继续") + task_manager.mark_all_tasks_as_stopped() + except Exception as e: + logger.error(f"停止任务时出错: {e}") + # 继续执行,尝试启动新任务 + + # 解析并验证请求数据 + stream_request = StreamRequest.from_dict(request.json) + + task_id = str(uuid.uuid4()) + # 修正视频源地址 + new_source_url = stream_request.source_url.replace("222.212.85.86", "192.168.10.5") + new_push_url = stream_request.push_url.replace("222.212.85.86", "192.168.10.5") + # 启动YOLO检测 + try: + await asyncio.to_thread( + startAIVideo, + new_source_url, + new_push_url, + stream_request.model_path, + stream_request.detect_classes, + stream_request.confidence + ) + except Exception as e: + logger.error(f"启动AI视频处理失败: {e}") + return json_response({ + "status": "error", + "message": f"Failed to start AI video processing: {str(e)}" + }, status=500) + + # 记录任务信息 + task_manager.add_task(task_id, { + "source_url": stream_request.source_url, + "push_url": stream_request.push_url, + "model_path": stream_request.model_path, + "detect_classes": stream_request.detect_classes, + "confidence": stream_request.confidence + }) + + return json_response({ + "status": "success", + "task_id": task_id, + "message": "Detection started successfully" + }) + + except ValueError as e: + logger.error(f"Validation error: {str(e)}") + return json_response({"status": "error", "message": str(e)}, status=400) + except Exception as e: + logger.error(f"Unexpected error: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.post("/ai/stream/") +async def stop_detection(request, task_id: str): + try: + verify_token(request) + + # 检查任务是否存在 + try: + task_info = task_manager.get_task_info(task_id) + except NotFound: + return json_response({"status": "error", "message": "Task not found"}, status=404) + + # 停止AI视频处理,使用安全的停止方法 + success = await safe_stop_ai_video() + + # 即使停止失败,也要移除任务 + task_manager.remove_task(task_id) + + if not success: + logger.warning("虽然停止过程出现错误,但任务已被标记为结束") + return json_response({ + "status": "warning", + "message": "Task removal completed with warnings" + }) + + return json_response({ + "status": "success", + "message": "Detection stopped successfully" + }) + except NotFound as e: + return json_response({"status": "error", "message": str(e)}, status=404) + except Exception as e: + logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True) + # 尝试标记任务为停止状态 + try: + if task_id in task_manager.task_status: + task_manager.task_status[task_id] = "error_during_stop" + except: + pass + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/stream/") +async def get_task_status(request, task_id: str): + try: + verify_token(request) + task_info = task_manager.get_task_info(task_id) + + # 检查任务是否真的在运行 + if not getIfAI() and task_info["status"] == "running": + task_info["status"] = "stopped_unexpectedly" + logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止") + + return json_response({ + "status": "success", + "task_id": task_id, + **task_info + }) + except NotFound as e: + return json_response({"status": "error", "message": str(e)}, status=404) + except Exception as e: + logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/stream/tasks") +async def list_tasks(request): + """获取所有活动任务列表""" + try: + verify_token(request) + + # 检查所有任务的健康状态 + unhealthy_tasks = task_manager.check_tasks_health() + for task_id, status in unhealthy_tasks.items(): + if task_id in task_manager.task_status: + task_manager.task_status[task_id] = status + + tasks = { + task_id: task_manager.get_task_info(task_id) + for task_id in task_manager.active_tasks.keys() + } + return json_response({ + "status": "success", + "tasks": tasks + }) + except Exception as e: + logger.error(f"Error listing tasks: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.post("/ai/stream/stopTasks") +async def stop_all_detections(request): + """停止所有活动任务""" + try: + verify_token(request) + + if not task_manager.active_tasks: + return json_response({ + "status": "success", + "message": "No active tasks to stop" + }) + + # 停止所有任务 + success = await safe_stop_ai_video() + + # 无论成功与否,都移除所有任务 + for task_id in list(task_manager.active_tasks.keys()): + task_manager.remove_task(task_id) + + if not success: + return json_response({ + "status": "warning", + "message": "Tasks stopped with warnings" + }) + + return json_response({ + "status": "success", + "message": "All detections stopped successfully" + }) + except Exception as e: + logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True) + # 尝试标记所有任务为停止状态 + task_manager.mark_all_tasks_as_stopped() + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/health") +async def health_check(request): + """服务健康检查端点""" + try: + # 不需要验证token,这个接口可以用于监控系统检查服务状态 + unhealthy_tasks = task_manager.check_tasks_health() + + return json_response({ + "status": "success", + "service": "running" if service_status["is_healthy"] else "degraded", + "active_tasks": len(task_manager.active_tasks), + "unhealthy_tasks": unhealthy_tasks, + "last_error": service_status["last_error"], + "error_time": service_status["error_time"], + "timestamp": datetime.now().isoformat() + }) + except Exception as e: + logger.error(f"健康检查失败: {str(e)}", exc_info=True) + return json_response({ + "status": "error", + "service": "degraded", + "error": str(e), + "timestamp": datetime.now().isoformat() + }, status=500) + +@app.route("/ai/reset", methods=["POST"]) +async def reset_service(request): + """重置服务状态,清理所有任务和进程""" + try: + verify_token(request) + + # 尝试停止AI视频处理 + await safe_stop_ai_video() + + # 清理所有任务 + for task_id in list(task_manager.active_tasks.keys()): + task_manager.remove_task(task_id) + + # 重置服务状态 + service_status["is_healthy"] = True + service_status["last_error"] = None + service_status["error_time"] = None + + # 尝试清理可能存在的僵尸进程 + try: + import os + import signal + import psutil + + current_process = psutil.Process(os.getpid()) + zombie_count = 0 + + for child in current_process.children(recursive=True): + try: + if child.status() == psutil.STATUS_ZOMBIE: + zombie_count += 1 + child.send_signal(signal.SIGKILL) + except: + pass + + return json_response({ + "status": "success", + "message": f"Service reset successful. Cleaned {zombie_count} zombie processes." + }) + except Exception as e: + logger.error(f"清理僵尸进程时出错: {e}") + return json_response({ + "status": "warning", + "message": "Service reset with warnings" + }) + + except Exception as e: + logger.error(f"重置服务时出错: {str(e)}", exc_info=True) + return json_response({ + "status": "error", + "message": f"Failed to reset service: {str(e)}" + }, status=500) + +@app.route("/ai/stream/restart/", methods=["POST"]) +async def restart_task(request, task_id: str): + """重启指定任务""" + try: + verify_token(request) + + # 获取任务信息 + try: + task_info = task_manager.get_task_info(task_id)["task_info"] + except NotFound: + return json_response({"status": "error", "message": "Task not found"}, status=404) + + # 先停止任务 + success = await safe_stop_ai_video() + task_manager.remove_task(task_id) + + if not success: + logger.warning("停止任务出现问题,尝试继续重启") + + # 重新启动任务 + new_task_id = str(uuid.uuid4()) + + try: + await asyncio.to_thread( + startAIVideo, + task_info["source_url"], + task_info["push_url"], + task_info["model_path"], + task_info["detect_classes"], + task_info["confidence"] + ) + + # 记录新任务信息 + task_manager.add_task(new_task_id, task_info) + + return json_response({ + "status": "success", + "old_task_id": task_id, + "new_task_id": new_task_id, + "message": "Task restarted successfully" + }) + except Exception as e: + logger.error(f"重启任务失败: {e}") + return json_response({ + "status": "error", + "message": f"Failed to restart task: {str(e)}" + }, status=500) + + except Exception as e: + logger.error(f"重启任务时出错: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +if __name__ == "__main__": + # 保证服务启动前没有残留任务 + try: + stopAIVideo() + print("服务启动前清理完成") + except: + print("服务启动前清理失败,但仍将继续") + + # 安装psutil库,用于进程管理 + try: + import psutil + except ImportError: + import subprocess + import sys + print("正在安装psutil库...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"]) + + app.run(host="0.0.0.0", port=12315, debug=False, access_log=True) diff --git a/ai2/yolo_api.zip b/ai2/yolo_api.zip new file mode 100644 index 0000000..7d4dd8a Binary files /dev/null and b/ai2/yolo_api.zip differ diff --git a/ai2/yolo_api_HANGZOUAPI.py b/ai2/yolo_api_HANGZOUAPI.py new file mode 100644 index 0000000..b7c75a5 --- /dev/null +++ b/ai2/yolo_api_HANGZOUAPI.py @@ -0,0 +1,540 @@ +from sanic import Sanic, json +from sanic.response import json as json_response +from sanic.exceptions import Unauthorized, NotFound, SanicException +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +import uuid +import logging +import asyncio +import traceback +from datetime import datetime +from cv_video import startAIVideo,stopAIVideo,getIfAI +from sanic_cors import CORS + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# 服务状态标志 +service_status = {"is_healthy": True, "last_error": None, "error_time": None} + +# 配置类 +class Config: + VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa" + MAX_ACTIVE_TASKS = 10 + DEFAULT_CONFIDENCE = 0.5 + RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒) + + +@dataclass +class StreamRequest: + source_url: str + push_url: str + model_path: str + detect_classes: List[str] + confidence: float = Config.DEFAULT_CONFIDENCE + + def validate(self) -> None: + """验证请求参数""" + if not self.source_url or not self.push_url: + raise ValueError("Source URL and Push URL are required") + + if not self.detect_classes: + raise ValueError("At least one detection class must be specified") + if not 0 < self.confidence < 1: + raise ValueError("Confidence must be between 0 and 1") + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest': + try: + instance = cls( + source_url=data['source_url'], + push_url=data['push_url'], + model_path=data['model_path'], + detect_classes=data['detect_classes'], + confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE) + ) + instance.validate() + return instance + except KeyError as e: + raise ValueError(f"Missing required field: {str(e)}") + +class TaskManager: + def __init__(self): + self.active_tasks: Dict[str, Dict[str, Any]] = {} + self.task_status: Dict[str, str] = {} + self.task_timestamps: Dict[str, datetime] = {} + + def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None: + """添加新任务""" + if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS: + raise ValueError("Maximum number of active tasks reached") + + self.active_tasks[task_id] = task_info + self.task_status[task_id] = "running" + self.task_timestamps[task_id] = datetime.now() + logger.info(f"Task {task_id} started") + + def remove_task(self, task_id: str) -> None: + """移除任务""" + if task_id in self.active_tasks: + del self.active_tasks[task_id] + del self.task_status[task_id] + del self.task_timestamps[task_id] + logger.info(f"Task {task_id} removed") + + def get_task_info(self, task_id: str) -> Dict[str, Any]: + """获取任务信息""" + if task_id not in self.active_tasks: + raise NotFound("Task not found") + + return { + "task_info": self.active_tasks[task_id], + "status": self.task_status[task_id], + "start_time": self.task_timestamps[task_id].isoformat() + } + + def check_tasks_health(self) -> Dict[str, str]: + """检查任务健康状态""" + unhealthy_tasks = {} + for task_id in list(self.active_tasks.keys()): + # 检查任务是否还在运行(通过getIfAI()函数) + if not getIfAI(): + unhealthy_tasks[task_id] = "stopped" + logger.warning(f"Task {task_id} appears to be stopped unexpectedly") + return unhealthy_tasks + + def mark_all_tasks_as_stopped(self): + """标记所有任务为已停止状态""" + for task_id in list(self.active_tasks.keys()): + self.task_status[task_id] = "stopped" + logger.warning("已将所有任务标记为停止状态") + +app = Sanic("YoloStreamService") +CORS(app) +task_manager = TaskManager() + +async def safe_stop_ai_video(): + """安全地停止AI视频处理,带有错误处理和恢复机制""" + try: + await asyncio.to_thread(stopAIVideo) + return True + except Exception as e: + error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + + # 标记服务状态为不健康 + service_status["is_healthy"] = False + service_status["last_error"] = str(e) + service_status["error_time"] = datetime.now().isoformat() + + # 强制结束所有任务 + task_manager.mark_all_tasks_as_stopped() + + # 尝试通过其他方式杀死可能存在的进程 + try: + import os + import signal + import psutil + + current_process = psutil.Process(os.getpid()) + # 查找并终止ffmpeg子进程 + for child in current_process.children(recursive=True): + try: + child_name = child.name().lower() + if 'ffmpeg' in child_name: + logger.info(f"强制终止子进程: {child.pid} ({child_name})") + child.send_signal(signal.SIGTERM) + except Exception as child_e: + logger.error(f"终止子进程出错: {str(child_e)}") + except Exception as kill_e: + logger.error(f"尝试清理进程时出错: {str(kill_e)}") + + # 等待一段时间让系统恢复 + await asyncio.sleep(Config.RESTART_DELAY) + + # 重置服务状态 + service_status["is_healthy"] = True + return False + +def verify_token(request) -> None: + """验证请求token""" + token = request.headers.get('X-API-Token') + if not token or token != Config.VALID_TOKEN: + logger.warning("Invalid token attempt") + raise Unauthorized("Invalid token") + +@app.post("/ai/stream/detect") +async def start_detection(request): + try: + verify_token(request) + + # 检查服务健康状态 + if not service_status["is_healthy"]: + logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}") + service_status["is_healthy"] = True # 尝试恢复 + + # 先停止所有现有任务 + for task_id in list(task_manager.active_tasks.keys()): + logger.info(f"停止现有任务 {task_id} 以启动新任务") + try: + success = await safe_stop_ai_video() + if success: + task_manager.remove_task(task_id) + else: + logger.warning(f"无法正常停止任务 {task_id},但仍将继续") + task_manager.mark_all_tasks_as_stopped() + except Exception as e: + logger.error(f"停止任务时出错: {e}") + + # 解析并验证请求数据 + stream_request = StreamRequest.from_dict(request.json) + task_id = str(uuid.uuid4()) + + # 如果是“yanwu.pt”模型,调用外部接口获取 liveUrl + if stream_request.model_path == "yanwu.pt": + try: + import aiohttp + import tempfile + + async with aiohttp.ClientSession() as session: + # 创建一个临时空白文件 + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + tmpfile.write(b'') + + with open(tmpfile.name, 'rb') as f: + form_data = aiohttp.FormData() + form_data.add_field("file", f, filename="empty.txt", content_type="text/plain") + form_data.add_field("model", "yanwu") + form_data.add_field("taskType", "2") + form_data.add_field("source", stream_request.source_url) + form_data.add_field("notifyUrl", f"{stream_request.source_url}+ai") + + url = "https://flightcontrol.huaiying-xunjian.com/prod-api/third/api/v1/task/startTask" + async with session.post(url, data=form_data) as resp: + if resp.status != 200: + raise Exception(f"外部服务状态码异常: {resp.status}") + result = await resp.json() + if result.get("code") != 0 or not result.get("data", {}).get("liveUrl"): + raise Exception(f"接口响应错误: {result}") + live_url = result["data"]["liveUrl"] + logger.info(f"外部接口返回推流地址: {live_url}") + stream_request.push_url = live_url # 替换推流地址 + except Exception as ext: + logger.error(f"调用外部直播任务接口失败: {ext}") + return json_response({ + "status": "error", + "message": f"调用直播任务接口失败: {str(ext)}" + }, status=500) + else: + # 启动YOLO检测 + try: + await asyncio.to_thread( + startAIVideo, + stream_request.source_url, + stream_request.push_url, + stream_request.model_path, + stream_request.detect_classes, + stream_request.confidence + ) + except Exception as e: + logger.error(f"启动AI视频处理失败: {e}") + return json_response({ + "status": "error", + "message": f"Failed to start AI video processing: {str(e)}" + }, status=500) + + # 记录任务信息 + task_manager.add_task(task_id, { + "source_url": stream_request.source_url, + "push_url": stream_request.push_url, + "model_path": stream_request.model_path, + "detect_classes": stream_request.detect_classes, + "confidence": stream_request.confidence + }) + + return json_response({ + "status": "success", + "task_id": task_id, + "message": "Detection started successfully" + }) + + except ValueError as e: + logger.error(f"Validation error: {str(e)}") + return json_response({"status": "error", "message": str(e)}, status=400) + except Exception as e: + logger.error(f"Unexpected error: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + + +@app.post("/ai/stream/") +async def stop_detection(request, task_id: str): + try: + verify_token(request) + + # 检查任务是否存在 + try: + task_info = task_manager.get_task_info(task_id) + except NotFound: + return json_response({"status": "error", "message": "Task not found"}, status=404) + + # 停止AI视频处理,使用安全的停止方法 + success = await safe_stop_ai_video() + + # 即使停止失败,也要移除任务 + task_manager.remove_task(task_id) + + if not success: + logger.warning("虽然停止过程出现错误,但任务已被标记为结束") + return json_response({ + "status": "warning", + "message": "Task removal completed with warnings" + }) + + return json_response({ + "status": "success", + "message": "Detection stopped successfully" + }) + except NotFound as e: + return json_response({"status": "error", "message": str(e)}, status=404) + except Exception as e: + logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True) + # 尝试标记任务为停止状态 + try: + if task_id in task_manager.task_status: + task_manager.task_status[task_id] = "error_during_stop" + except: + pass + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/stream/") +async def get_task_status(request, task_id: str): + try: + verify_token(request) + task_info = task_manager.get_task_info(task_id) + + # 检查任务是否真的在运行 + if not getIfAI() and task_info["status"] == "running": + task_info["status"] = "stopped_unexpectedly" + logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止") + + return json_response({ + "status": "success", + "task_id": task_id, + **task_info + }) + except NotFound as e: + return json_response({"status": "error", "message": str(e)}, status=404) + except Exception as e: + logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/stream/tasks") +async def list_tasks(request): + """获取所有活动任务列表""" + try: + verify_token(request) + + # 检查所有任务的健康状态 + unhealthy_tasks = task_manager.check_tasks_health() + for task_id, status in unhealthy_tasks.items(): + if task_id in task_manager.task_status: + task_manager.task_status[task_id] = status + + tasks = { + task_id: task_manager.get_task_info(task_id) + for task_id in task_manager.active_tasks.keys() + } + return json_response({ + "status": "success", + "tasks": tasks + }) + except Exception as e: + logger.error(f"Error listing tasks: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.post("/ai/stream/stopTasks") +async def stop_all_detections(request): + """停止所有活动任务""" + try: + verify_token(request) + + if not task_manager.active_tasks: + return json_response({ + "status": "success", + "message": "No active tasks to stop" + }) + + # 停止所有任务 + success = await safe_stop_ai_video() + + # 无论成功与否,都移除所有任务 + for task_id in list(task_manager.active_tasks.keys()): + task_manager.remove_task(task_id) + + if not success: + return json_response({ + "status": "warning", + "message": "Tasks stopped with warnings" + }) + + return json_response({ + "status": "success", + "message": "All detections stopped successfully" + }) + except Exception as e: + logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True) + # 尝试标记所有任务为停止状态 + task_manager.mark_all_tasks_as_stopped() + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +@app.get("/ai/health") +async def health_check(request): + """服务健康检查端点""" + try: + # 不需要验证token,这个接口可以用于监控系统检查服务状态 + unhealthy_tasks = task_manager.check_tasks_health() + + return json_response({ + "status": "success", + "service": "running" if service_status["is_healthy"] else "degraded", + "active_tasks": len(task_manager.active_tasks), + "unhealthy_tasks": unhealthy_tasks, + "last_error": service_status["last_error"], + "error_time": service_status["error_time"], + "timestamp": datetime.now().isoformat() + }) + except Exception as e: + logger.error(f"健康检查失败: {str(e)}", exc_info=True) + return json_response({ + "status": "error", + "service": "degraded", + "error": str(e), + "timestamp": datetime.now().isoformat() + }, status=500) + +@app.route("/ai/reset", methods=["POST"]) +async def reset_service(request): + """重置服务状态,清理所有任务和进程""" + try: + verify_token(request) + + # 尝试停止AI视频处理 + await safe_stop_ai_video() + + # 清理所有任务 + for task_id in list(task_manager.active_tasks.keys()): + task_manager.remove_task(task_id) + + # 重置服务状态 + service_status["is_healthy"] = True + service_status["last_error"] = None + service_status["error_time"] = None + + # 尝试清理可能存在的僵尸进程 + try: + import os + import signal + import psutil + + current_process = psutil.Process(os.getpid()) + zombie_count = 0 + + for child in current_process.children(recursive=True): + try: + if child.status() == psutil.STATUS_ZOMBIE: + zombie_count += 1 + child.send_signal(signal.SIGKILL) + except: + pass + + return json_response({ + "status": "success", + "message": f"Service reset successful. Cleaned {zombie_count} zombie processes." + }) + except Exception as e: + logger.error(f"清理僵尸进程时出错: {e}") + return json_response({ + "status": "warning", + "message": "Service reset with warnings" + }) + + except Exception as e: + logger.error(f"重置服务时出错: {str(e)}", exc_info=True) + return json_response({ + "status": "error", + "message": f"Failed to reset service: {str(e)}" + }, status=500) + +@app.route("/ai/stream/restart/", methods=["POST"]) +async def restart_task(request, task_id: str): + """重启指定任务""" + try: + verify_token(request) + + # 获取任务信息 + try: + task_info = task_manager.get_task_info(task_id)["task_info"] + except NotFound: + return json_response({"status": "error", "message": "Task not found"}, status=404) + + # 先停止任务 + success = await safe_stop_ai_video() + task_manager.remove_task(task_id) + + if not success: + logger.warning("停止任务出现问题,尝试继续重启") + + # 重新启动任务 + new_task_id = str(uuid.uuid4()) + try: + await asyncio.to_thread( + startAIVideo, + task_info["source_url"], + task_info["push_url"], + task_info["model_path"], + task_info["detect_classes"], + task_info["confidence"] + ) + + # 记录新任务信息 + task_manager.add_task(new_task_id, task_info) + + return json_response({ + "status": "success", + "old_task_id": task_id, + "new_task_id": new_task_id, + "message": "Task restarted successfully" + }) + except Exception as e: + logger.error(f"重启任务失败: {e}") + return json_response({ + "status": "error", + "message": f"Failed to restart task: {str(e)}" + }, status=500) + + except Exception as e: + logger.error(f"重启任务时出错: {str(e)}", exc_info=True) + return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + +if __name__ == "__main__": + # 保证服务启动前没有残留任务 + try: + stopAIVideo() + print("服务启动前清理完成") + except: + print("服务启动前清理失败,但仍将继续") + + # 安装psutil库,用于进程管理 + try: + import psutil + except ImportError: + import subprocess + import sys + print("正在安装psutil库...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"]) + + app.run(host="0.0.0.0", port=12315, debug=False, access_log=True) diff --git a/even/A_lot_of.py b/even/A_lot_of.py new file mode 100644 index 0000000..3a4a6d3 --- /dev/null +++ b/even/A_lot_of.py @@ -0,0 +1,355 @@ +import threading +import os +from ultralytics import YOLO +import cv2 +import threading +import os +from ultralytics import YOLO +import datetime +#@yoooger +#----------------------------------------------------------------------------------------------------------------------------------------------- +# 定义颜色映射 +colors = { + "0": (0, 255, 0), # 绿色 + "1": (0, 0, 255), # 蓝色 + "2": (255, 0, 0), # 红色 + "3": (255, 255, 0), # 黄色 + "4": (0, 255, 255), # 青色 + "5": (255, 0, 255), # 紫色 + "6": (128, 0, 0), # 紫色 + "7": (0, 128, 0), # 绿色 + "8": (0, 0, 128), # 蓝色 + "9": (128, 128, 0), # 黄色 + "10": (128, 0, 128), # 紫色 + "11": (0, 128, 128), # 青色 +} +#----------------------------------------------------------------------------------------------------------------------------------------------- +def draw_and_save(frame, boxes, save_folder, frame_number, model_id): + """绘制检测框并保存当前帧""" + + for box in boxes: + class_id, (x1, y1, x2, y2) = box + color = colors[class_id] # 根据类别选择颜色 + cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2) + + # 保存当前帧到指定文件夹 + cv2.imwrite(os.path.join(save_folder, f"model_{model_id}_frame_{frame_number}.jpg"), frame) +#----------------------------------------------------------------------------------------------------------------------------------------------- +def process_images(image_folder_path, model, save_folder, model_id, classes_str, event): + """处理文件夹中的图片,检测物体并触发自定义事件""" + os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在) + classes = list(map(int, classes_str.split(','))) # 将逗号分隔的字符串转换为整数列表 + + # 回调函数用于绘制并保存帧 + def callback(frame, boxes, save_folder, frame_number): + draw_and_save(frame, boxes, save_folder, frame_number, model_id) + + event.subscribe(callback) # 订阅绘制和保存作业 + + # 遍历文件夹中的每个图片文件 + for frame_number, image_name in enumerate(os.listdir(image_folder_path)): + + + image_path = os.path.join(image_folder_path, image_name) + frame = cv2.imread(image_path) # 读取图片 + if frame is None: + continue # 跳过无法读取的图片 + + # 使用 YOLO 进行推理 + results = model(frame) # 输入当前帧 + boxes = [] # 定义一个空列表,用于存放检测框 + + # 处理 YOLO 的输出 + for result in results: + for box in result.boxes: + class_id = int(box.cls) # 获取类别ID + if class_id in classes: # 检查类别是否在指定类别数组中 + x1, y1, x2, y2 = map(int, box.xyxy[0]) # 获取框坐标 + boxes.append((class_id, (x1, y1, x2, y2))) + +#------------------------------------------------------------------------------------------------------------------------------------- +def process_video(video_path, model, save_folder, model_id, classes_str, event): + """处理视频,检测物体并触发自定义事件""" + os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在) + classes = list(map(int, classes_str.split(',')))# 将逗号分隔的字符串转换为整数列表 + cap = cv2.VideoCapture(video_path)# 打开视频文件,使用cv2.VideoCapture()函数,截取视频的帧 + skip_frames = 0 # 初始化跳帧数 + + # 修改callback函数以接受3个参数 + def callback(frame, boxes, save_folder): + draw_and_save(frame, boxes, save_folder, int(cap.get(cv2.CAP_PROP_POS_FRAMES)), model_id) + + event.subscribe(callback) # 订阅绘制和保存作业 + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + if skip_frames > 0: # 跳过指定帧数,设置skip_frames为0可取消跳帧 + skip_frames -= 1 # 倒数计时器 + continue + + # 使用 YOLO 进行推理 + results = model(frame) # 输入当前帧 + boxes = [] # 定义一个空列表,用于存放检测框 + + # 处理 YOLO 的输出 + for result in results: + for box in result.boxes: + class_id = int(box.cls) # 获取类别ID + if class_id in classes: # 检查类别是否在指定类别数组中 + x1, y1, x2, y2 = map(int, box.xyxy[0]) # 获取框坐标 + boxes.append((class_id, (x1, y1, x2, y2))) + + #如果检测到物体,则跳过30帧 + if boxes: + event.trigger(frame, boxes, save_folder) + skip_frames = 100 # 检测到物体后跳过30帧 + + cap.release() + +def process_video_threshold(video_path, model, save_folder, model_id, classes_str, event, threshold): + """ + 处理视频,并在指定区域绘制检测框,并在指定阈值触发事件 + 参数说明: + video_path: 视频路径(应当改为视频流??) + model: YOLO 模型 + save_folder: 保存帧的文件夹 + model_id: 模型ID + classes_str: 要检测的类别(字符串形式) + event: 自定义事件(CustomEvent0、CustomEvent1 等) + threshold: 触发事件的阈值 + """ + os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在) + classes = list(map(int, classes_str.split(','))) # 将逗号分隔的字符串转换为整数列表 + cap = cv2.VideoCapture(video_path) # 打开视频文件 + skip_frames = 0 # 初始化跳帧数 + + # 订阅绘制和保存作业 + event.subscribe(lambda frame, boxes, save_folder: draw_and_save(frame, boxes, save_folder, int(cap.get(cv2.CAP_PROP_POS_FRAMES)), model_id)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + if skip_frames > 0: # 跳过指定帧数 + skip_frames -= 1 + continue + + # 使用 YOLO 进行推理 + results = model(frame, + conf=0.4, # 置信度阈值 + iou=0.5, # 交并比阈值 + ) # 输入当前帧 + boxes = [] # 定义一个空列表,用于存放检测框 + + # 处理 YOLO 的输出 + for result in results: + for box in result.boxes: + class_id = int(box.cls) # 获取类别ID + if class_id in classes: # 检查类别是否在指定类别数组中 + x1, y1, x2, y2 = map(int, box.xyxy[0]) # 获取框坐标 + boxes.append((class_id, (x1, y1, x2, y2))) + + # 计算区域内目标数量 + object_count = len(boxes) + + # 如果目标数量超过阈值,则触发事件 + if object_count > threshold: + event.trigger(frame, boxes, save_folder) + skip_frames = 100 # 跳过100帧 + + cap.release() +#------------------------------------------------------------------------------------------------------------------------------------------------------ +# 自定义事件类(用于订阅和触发作业,检测到物体时触发作业) +class CustomEvent0: + def __init__(self): + self._listeners = [] + + def subscribe(self, listener): + """订阅一个作业,当事件触发时会执行这个作业""" + self._listeners.append(listener) + + def trigger(self, frame, boxes, save_folder): + """触发事件,执行所有订阅的作业""" + for listener in self._listeners: + listener(frame, boxes, save_folder) + +#自定义事件类(用于订阅和触发作业,检测到物体大于一定数量时触发作业) +class CustomEvent1: + def __init__(self, threshold=5): + """初始化事件类,并设置触发阈值""" + self._listeners = [] + self.threshold = threshold # 设置触发作业的物体数量阈值 + + def subscribe(self, listener): + """订阅一个作业,当事件触发时会执行这个作业""" + self._listeners.append(listener) + + def trigger(self, frame, boxes, save_folder): + """当检测物体数量超过阈值时,触发事件""" + if len(boxes) >= self.threshold: # 检查是否超过阈值 + for listener in self._listeners: + listener(frame, boxes, save_folder) + + def in_area_trigger(self, frame, boxes, save_folder, class_id=None, x_min=100, y_min=100, x_max=500, y_max=500): + """检测屏幕范围内特定类别的物体数量是否超过阈值""" + count_in_area = 0 + for box in boxes: + obj_class_id, (x1, y1, x2, y2) = box + # 检查类别(如果指定)和坐标范围 + if (class_id is None or obj_class_id == class_id) and x1 >= x_min and y1 >= y_min and x2 <= x_max and y2 <= y_max: + count_in_area += 1 + + # 如果数量超过阈值,触发事件 + if count_in_area >= self.threshold: + for listener in self._listeners: + listener(frame, boxes, save_folder) + +#自定义事件类(用于订阅和触发作业,检测到物体小于一定数量时触发作业) +class CustomEvent2: + def __init__(self, threshold=5): + """初始化事件类,并设置触发阈值""" + self._listeners = [] + self.threshold = threshold # 设置触发作业的物体数量阈值 + + def subscribe(self, listener): + """订阅一个作业,当事件触发时会执行这个作业""" + self._listeners.append(listener) + + def trigger(self, frame, boxes, save_folder): + """当检测物体数量超过阈值时,触发事件""" + if len(boxes) >= self.threshold: # 检查是否超过阈值 + for listener in self._listeners: + listener(frame, boxes, save_folder) + +#----------------------------------------------------------------------------------------------------------------------------------------------- + +# 获取当前日期和时间精确到分钟作为run_name +current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M") +# 定义处理任务的函数 +def process_task(video_path, model_path, save_folder, classes_str, event_type, threshold=None, model_id=None): + """ + 统一处理任务的函数 + """ + model = YOLO(model_path) # 加载 YOLO 模型 + os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在) + + if threshold is not None: + # 使用阈值处理视频 + threading.Thread( + target=process_video_threshold, + args=(video_path, model, save_folder, model_id, classes_str, event_type, threshold) + ).start() + else: + # 不使用阈值处理视频 + threading.Thread( + target=process_video, + args=(video_path, model, save_folder, model_id, classes_str, event_type) + ).start() + +def Start_program(video_path, Start_String): + """ + 启动程序,根据 Start_String 启动特定的模型 + :param video_path: 视频路径 + :param Start_String: 启动模型的字符串 + """ + task_configurations = { + + ''' + "模型名称": { + "model_path": "模型路径", + "save_folder": "保存文件夹路径", + "classes": "检测物体的类别", + "event": 自定义事件类实例, Customevent0:if have aim /1: aim number more than one number /2: aim number less than one number + "threshold": 触发阈值(可选), + "model_id": 模型ID(可选) + } + ''' + + # 出现了人 + "have_peoples": { + "model_path": "yolo11n.pt", + "save_folder": f"output/{current_date}_People_in_the_area", + "classes": "0", + "event": CustomEvent0(), + "threshold": None, + "model_id": 1, + }, + #发现了人员聚集 + "many_peoples": { + "model_path": "yolo11n.pt", + "save_folder": f"output/{current_date}_Many_people", + "classes": "0", + "event": CustomEvent1(threshold=5), + "threshold": 5, + "model_id": 2, + }, + #发现了无安全帽人员 + "no_helmet": { + "model_path": "gdaq.pt", + "save_folder": f"output/{current_date}_Have_no_helmet", + "classes": "1", + "event": CustomEvent0(), + "threshold": None, + "model_id": 3, + }, + #无安全绳 + "no_safety_line": { + "model_path": "gdaq.pt", + "save_folder": f"output/{current_date}_Have_no_safety_line", + "classes": "3", + "event": CustomEvent0(), + "threshold": None, + "model_id": 4, + }, + "smoke": { + "model_path": "smoke.pt", + "save_folder": f"output/{current_date}_Have_smoke", + "classes": "1", + "event": CustomEvent0(), + "threshold": None, + "model_id": 5, + }, + #发现了烟雾 + "fire": { + "model_path": "smoke.pt", + "save_folder": f"output/{current_date}_Have_fire", + "classes": "0", + "event": CustomEvent0(), + "threshold": None, + "model_id": 6, + }, + #护栏破损 + "HULAN_POSUN": { + "model_path": "gdaq.pt", + "save_folder": f"output/{current_date}_HULAN_POSUN", + "classes": "6", + "event": CustomEvent0(), + "threshold": None, + "model_id": 7, + }, + } + # 将 Start_String 拆分为列表 + start_flags = Start_String.split(",") + + for flag in start_flags: + if flag in task_configurations: + config = task_configurations[flag] + process_task( + video_path, + config["model_path"], + config["save_folder"], + config["classes"], + config["event"], + config["threshold"], + config["model_id"], #we difine model_id to identify the model + ) + +if __name__ == '__main__': + # 使用线程处理视频 + video_path = r"D:\work\DJI_20241122093213_0043_V.mp4" # 替换为你的视频路径 + String = "have_peoples,many_peoples,no_helmet,no_safety_line" # 设置要启动的模型 + Start_program(video_path, String) # 启动程序 diff --git a/even/__pycache__/A_lot_of.cpython-312.pyc b/even/__pycache__/A_lot_of.cpython-312.pyc new file mode 100644 index 0000000..437fe9d Binary files /dev/null and b/even/__pycache__/A_lot_of.cpython-312.pyc differ diff --git a/even/__pycache__/colour_map.cpython-312.pyc b/even/__pycache__/colour_map.cpython-312.pyc new file mode 100644 index 0000000..da8c3ba Binary files /dev/null and b/even/__pycache__/colour_map.cpython-312.pyc differ diff --git a/even/__pycache__/even_program_video.cpython-312.pyc b/even/__pycache__/even_program_video.cpython-312.pyc new file mode 100644 index 0000000..baf25d7 Binary files /dev/null and b/even/__pycache__/even_program_video.cpython-312.pyc differ diff --git a/even/__pycache__/even_program_video_back.cpython-312.pyc b/even/__pycache__/even_program_video_back.cpython-312.pyc new file mode 100644 index 0000000..c7e1a37 Binary files /dev/null and b/even/__pycache__/even_program_video_back.cpython-312.pyc differ diff --git a/even/__pycache__/even_rule.cpython-312.pyc b/even/__pycache__/even_rule.cpython-312.pyc new file mode 100644 index 0000000..010f839 Binary files /dev/null and b/even/__pycache__/even_rule.cpython-312.pyc differ diff --git a/even/gdaq.pt b/even/gdaq.pt new file mode 100644 index 0000000..309fae0 Binary files /dev/null and b/even/gdaq.pt differ diff --git a/even/smoke.pt b/even/smoke.pt new file mode 100644 index 0000000..a7759d8 Binary files /dev/null and b/even/smoke.pt differ diff --git a/even/yolo11n.pt b/even/yolo11n.pt new file mode 100644 index 0000000..45b273b Binary files /dev/null and b/even/yolo11n.pt differ diff --git a/even/zhuanhuan1.py b/even/zhuanhuan1.py new file mode 100644 index 0000000..a86caab --- /dev/null +++ b/even/zhuanhuan1.py @@ -0,0 +1,167 @@ +import numpy as np +import cv2 +from pyproj import Transformer +import math +from PIL import Image +from pylab import * +from dataclasses import dataclass + +def rotate_points_around_axis(points, axis, angle): + # 将轴转换为单位向量 + axis =axis.squeeze() + axis = axis / np.linalg.norm(axis) + + # 将角度转换为弧度 + theta = np.radians(angle) + + # 计算旋转矩阵 + K = np.array([[0, -axis[2], axis[1]], + [axis[2], 0, -axis[0]], + [-axis[1], axis[0], 0]]) + + I = np.eye(3) # 单位矩阵 + R = I + np.sin(theta) * K + (1 - np.cos(theta)) * np.dot(K, K) + + # 将点平移到旋转轴的起点 + + # 旋转多个点 + rotated_points = np.dot(R.T,points.T) # 矩阵乘法 + # 再平移回去 + + return rotated_points + +def rotate_points(points, center, theta): + """ + 批量旋转二维点 + :param points: 输入的二维点矩阵,形状为 (N, 2) + :param center: 旋转中心 (x_c, y_c) + :param theta: 旋转角度(弧度) + :return: 旋转后的点矩阵,形状为 (N, 2) + """ + # 平移到旋转中心 + translated = points - center # 平移,使得中心在原点 + rotation_matrix = np.array([[np.cos(theta), np.sin(theta)], + [-np.sin(theta), np.cos(theta)]]) + + # 应用旋转 + rotated = translated @ rotation_matrix.T # 矩阵乘法 + return rotated + center # 将旋转后的点移回原位置 +def compute_rotation_matrix(pitch, roll1, yaw): + c2w_x = np.array([[1, 0, 0], + [0, math.cos(pitch), -math.sin(pitch)], + [0, math.sin(pitch), math.cos(pitch)]], dtype=np.float32) + c2w_y = np.array([[math.cos(roll1), 0, math.sin(roll1)], + [0, 1, 0], + [-math.sin(roll1), 0, math.cos(roll1)]], dtype=np.float32) + c2w_z = np.array([[math.cos(yaw), -math.sin(yaw), 0], + [math.sin(yaw), math.cos(yaw), 0], + [0, 0, 1]], dtype=np.float32) + c2w = c2w_x @ c2w_y @ c2w_z + # 按 Z-Y-X 顺序组合旋转矩阵 + y=np.array([[0],[0],[0]]) + c2w1 = np.hstack([c2w,y]) + new_row = np.array([[0, 0, 0, 1]]) + c2w1 = np.vstack([c2w1, new_row]) + return c2w, c2w1 +def create_transformation_matrix(tvec): + transform_matrix = np.eye(4) # 创建单位矩阵 + transform_matrix[:3, 3] = tvec.flatten() # 平移向量 + return transform_matrix + +def dms2dd(x,y,z): + return x+y/60+z/3600 + + +@dataclass +class ProjectTrans: + ''' + 计算坐标 + ''' + def __init__(self, f_mm, sensor_width_mm, sensor_height_mm, image_width_px, image_height_px, k1, k2, k3, k4, k5): + self.image_points = None + self.tvec = None + self.rvec = None + self.object_point = None + self.f_mm = f_mm # 物理焦距 (mm) + self.sensor_width_mm= sensor_width_mm # 传感器宽度 (mm) + self.sensor_height_mm=sensor_height_mm # 传感器高度 (mm) + self.image_width_px= image_width_px # 图像宽度 (px) + self.image_height_px= image_height_px # 图像高度 (px) + self.Ji = np.array([[k1], [k2], [k3], [k4],[k5]], dtype=np.float32) + self.camera_matrix = None + + def apply_transformation(self, c2w, transform_matrix): + # 齐次化 + homogeneous_points = np.hstack([self.object_point, np.ones((self.object_point.shape[0], 1))]) + # 旋转平移 + transformed_points = transform_matrix @ homogeneous_points.T + transformed_points = c2w.T @ transformed_points + return transformed_points[:3, :] # 返回前3列(x, y, z坐标) + + def compute_camera_matrix(self): #返回内参矩阵 + """计算并返回相机内参矩阵.""" + fx = floor((self.f_mm * self.image_width_px) / self.sensor_width_mm) + fy = floor((self.f_mm * self.image_height_px) / self.sensor_height_mm) + cx = self.image_width_px / 2 + cy = self.image_height_px / 2 + fx=3725 + fy=3725 + cx= 2640 + cy=1978 + # 构造内参矩阵 + self.camera_matrix = np.array([ + [fx, 0, cx], + [0, fy, cy], + [0, 0, 1] + ], dtype=np.float32) + # def transform_Y(self): + # num_rows = self.image_points.shape[0] + # new_column = np.ones((num_rows, 1)) + # self.image_points[:,0] +=80 + # self.image_points[:,1] +=100 + # self.image_points = np.hstack([self.image_points, new_column]) + # transform = np.array([[math.cos(radians(-84.8)), -math.sin(radians(-84.8)), 2728*(1-math.cos(radians(-84.8)))+1980.*math.sin(radians(-84.8))], + # [math.sin(radians(-84.8)), math.cos(radians(-84.8)), 1980*(1-math.cos(radians(-84.8)))-2728*math.sin(radians(-84.8))], + # [0, 0, 1]], dtype=np.float32) + # self.image_points = (transform @ self.image_points.T).T + def projectPoints(self, object_point, tvec, rvec): + self.object_point = object_point + self.tvec=tvec + self.rvec=np.radians(rvec) + # transformer = Transformer.from_crs("epsg:4326", "epsg:4544") + # for index, point in enumerate(object_point): + # lat,lon,z = point + # self.object_poin t[index,1],self.object_point[index,0] = transformer.transform(lat, lon) + # self.tvec[1] ,self.tvec[0] = transformer.transform(tvec[0], tvec[1]) + self.tvec = -self.tvec + # c2w = compute_rotation_matrix(np.pi - 0.1 / 180 * np.pi, 0.85 / 180 * np.pi, 0) # 计算旋转矩阵 + # transform_matrix = create_transformation_matrix(self.tvec) # 创建平移矩阵 + # self.object_point = self.apply_transformation(c2w, transform_matrix) # 旋转平移 + # self.object_point = rotate_points_around_axis(self.object_point, np.array([0.0148148, -0.00174533, 0.999901]), -rvec[1]) + c2w1, c2w = compute_rotation_matrix(np.pi-0.1/180*np.pi, -1/180*np.pi, 1/180*np.pi) #计算旋转矩阵 + transform_matrix = create_transformation_matrix(self.tvec) #创建平移矩阵 + self.object_point = self.apply_transformation(c2w, transform_matrix) #旋转平移 + axis = c2w1.T @ np.array([[0],[0],[1]],dtype=np.float32) + self.object_point = rotate_points_around_axis(self.object_point.T,axis.T,- rvec[1]) + self.object_point = rotate_points_around_axis(self.object_point.T, np.array((0.9999,-0.00174533,-0.0148)), (rvec[0]+90)) + self.object_point = self.object_point.T + self.compute_camera_matrix() #计算相机内参矩 + BUP=np.array([0,0,0],dtype=np.float32) + self.rvec = np.array([0,0,0],dtype=np.float32) + self.image_points, _ = cv2.projectPoints(self.object_point, self.rvec, BUP, self.camera_matrix, self.Ji) + self.image_points = self.image_points.squeeze() + # self.transform_Y() + # self.image_points = rotate_points(self.image_points, np.array([2640, 1978]), rvec[1]/180*np.pi) + return self.image_points + + def plot(self): + x_coords=self.image_points[:,0] + y_coords=self.image_points[:,1] + print(self.image_points,self.rvec) + # 打开图像并绘制投影点 + im = Image.open('DJI_正.jpeg') + imshow(im) + # plot(x_coords, y_coords, 'r*') + plot(2647.02,1969.38,'b*')# 绘制红色星标 + show() + cv2.imwrite('result1.jp eg', im) \ No newline at end of file