20250709
This commit is contained in:
commit
45fde807d8
1
ai2/.codemap/main-panel.json
Normal file
1
ai2/.codemap/main-panel.json
Normal file
@ -0,0 +1 @@
|
||||
[]
|
BIN
ai2/GDCL.pt
Normal file
BIN
ai2/GDCL.pt
Normal file
Binary file not shown.
BIN
ai2/HWRC.pt
Normal file
BIN
ai2/HWRC.pt
Normal file
Binary file not shown.
BIN
ai2/__pycache__/ai_open_api.cpython-312.pyc
Normal file
BIN
ai2/__pycache__/ai_open_api.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/color.cpython-312.pyc
Normal file
BIN
ai2/__pycache__/color.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/cv_video.cpython-311.pyc
Normal file
BIN
ai2/__pycache__/cv_video.cpython-311.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/cv_video.cpython-312.pyc
Normal file
BIN
ai2/__pycache__/cv_video.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/minio_oss.cpython-311.pyc
Normal file
BIN
ai2/__pycache__/minio_oss.cpython-311.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/minio_oss.cpython-312.pyc
Normal file
BIN
ai2/__pycache__/minio_oss.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/pgadmin_helper.cpython-311.pyc
Normal file
BIN
ai2/__pycache__/pgadmin_helper.cpython-311.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/pgadmin_helper.cpython-312.pyc
Normal file
BIN
ai2/__pycache__/pgadmin_helper.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/yolo_findaim.cpython-311.pyc
Normal file
BIN
ai2/__pycache__/yolo_findaim.cpython-311.pyc
Normal file
Binary file not shown.
BIN
ai2/__pycache__/yolo_findaim.cpython-312.pyc
Normal file
BIN
ai2/__pycache__/yolo_findaim.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai2/best.pt
Normal file
BIN
ai2/best.pt
Normal file
Binary file not shown.
593
ai2/cv_video copy.py
Normal file
593
ai2/cv_video copy.py
Normal file
@ -0,0 +1,593 @@
|
||||
import cv2
|
||||
import subprocess
|
||||
from threading import Thread, Lock, Event
|
||||
import time
|
||||
import queue
|
||||
import numpy as np
|
||||
import datetime
|
||||
import os
|
||||
from ultralytics import YOLO # 导入 Ultralytics YOLO 模型
|
||||
|
||||
# 全局变量
|
||||
ifAI = {'status': False}
|
||||
deskLock = Lock()
|
||||
frame_queue = queue.Queue(maxsize=60) # 增加帧缓冲队列大小
|
||||
processed_frame_queue = queue.Queue(maxsize=30) # 处理后的帧队列
|
||||
stop_event = Event()
|
||||
|
||||
def setIfAI(pb1):
|
||||
deskLock.acquire()
|
||||
ifAI['status'] = pb1
|
||||
deskLock.release()
|
||||
|
||||
def getIfAI():
|
||||
return ifAI['status']
|
||||
|
||||
def stopAIVideo():
|
||||
print("正在停止AI视频处理...")
|
||||
setIfAI(False)
|
||||
stop_event.set()
|
||||
|
||||
# 等待足够长的时间确保资源释放
|
||||
wait_count = 0
|
||||
max_wait = 5 # 减少最大等待时间到5秒
|
||||
|
||||
while stop_event.is_set() and wait_count < max_wait:
|
||||
time.sleep(0.5)
|
||||
wait_count += 1
|
||||
|
||||
if wait_count >= max_wait:
|
||||
print("警告: 停止AI视频处理超时,强制终止")
|
||||
# 不使用_thread._interrupt_main(),改用其他方式强制终止
|
||||
try:
|
||||
# 尝试终止可能运行的进程
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
# 查找并终止可能的FFmpeg进程
|
||||
current_process = psutil.Process(os.getpid())
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
print(f"正在终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print("AI视频处理已停止")
|
||||
|
||||
def startAIVideo(video_path, output_url, m1, cls, confidence):
|
||||
if ifAI['status']:
|
||||
stopAIVideo()
|
||||
time.sleep(1)
|
||||
stop_event.clear()
|
||||
thread = Thread(target=startAIVideo2,
|
||||
args=(video_path, output_url, m1, cls, confidence))
|
||||
# cls2_thread = Thread(target=cls2_find, args=(video_path,m1, cls, confidence))
|
||||
# cls2_thread.daemon = True # 守护线程,主程序退出时线程也会退出
|
||||
thread.daemon = True # 守护线程,主程序退出时线程也会退出
|
||||
|
||||
|
||||
thread.start()
|
||||
# cls2_thread.start()
|
||||
|
||||
def read_frames(cap, frame_queue):
|
||||
"""优化的帧读取线程"""
|
||||
frame_count = 0
|
||||
last_time = time.time()
|
||||
last_fps_time = time.time()
|
||||
# 减小目标帧间隔时间,提高读取帧率
|
||||
target_time_per_frame = 1.0 / 60.0 # 目标帧间隔时间(提高到60fps)
|
||||
|
||||
# 添加连接断开检测
|
||||
connection_error_count = 0
|
||||
max_connection_errors = 10 # 最多允许连续10次连接错误
|
||||
last_successful_read = time.time()
|
||||
max_read_wait = 30.0 # 30秒无法读取则认为连接断开
|
||||
|
||||
# 预先丢弃几帧,确保从新帧开始处理
|
||||
for _ in range(5):
|
||||
cap.grab()
|
||||
|
||||
while not stop_event.is_set():
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - last_time
|
||||
|
||||
# 检查是否长时间无法读取帧
|
||||
if current_time - last_successful_read > max_read_wait:
|
||||
print(f"警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开")
|
||||
stop_event.set()
|
||||
break
|
||||
|
||||
# 帧率控制,但更积极地读取
|
||||
if elapsed_time < target_time_per_frame:
|
||||
time.sleep(target_time_per_frame - elapsed_time)
|
||||
continue
|
||||
|
||||
# 当队列快满时,跳过一些帧以避免延迟累积
|
||||
if frame_queue.qsize() > frame_queue.maxsize * 0.8:
|
||||
# 跳过一些帧
|
||||
cap.grab()
|
||||
last_time = time.time()
|
||||
continue
|
||||
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("拉流错误:无法读取帧")
|
||||
connection_error_count += 1
|
||||
if connection_error_count >= max_connection_errors:
|
||||
print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...")
|
||||
stop_event.set()
|
||||
break
|
||||
time.sleep(0.5) # 短暂等待后重试
|
||||
continue
|
||||
|
||||
# 成功读取了帧,重置错误计数
|
||||
connection_error_count = 0
|
||||
last_successful_read = time.time()
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % 60 == 0: # 每60帧计算一次FPS
|
||||
current_fps_time = time.time()
|
||||
fps = 60 / (current_fps_time - last_fps_time)
|
||||
print(f"拉流FPS: {fps:.2f}")
|
||||
last_fps_time = current_fps_time
|
||||
|
||||
last_time = time.time()
|
||||
frame_queue.put((frame, time.time())) # 添加时间戳
|
||||
|
||||
def process_frames(frame_queue, processed_frame_queue, ov_model, cls, confidence):
|
||||
"""处理帧的线程,添加帧率控制"""
|
||||
error_count = 0 # 添加错误计数器
|
||||
max_errors = 5 # 最大容许错误次数
|
||||
frame_count = 0
|
||||
last_process_time = time.time()
|
||||
process_times = [] # 用于计算平均处理时间
|
||||
|
||||
# 设置YOLO模型配置,提高性能
|
||||
ov_model.conf = confidence # 设置置信度阈值
|
||||
|
||||
# 优化推理性能
|
||||
try:
|
||||
# 导入torch库
|
||||
import torch
|
||||
# 尝试启用ONNX Runtime加速
|
||||
ov_model.to('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
# 调整批处理大小为1,减少内存占用
|
||||
if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'):
|
||||
ov_model.args.batch = 1
|
||||
# 使用half精度,提高性能
|
||||
if torch.cuda.is_available() and hasattr(ov_model, 'model'):
|
||||
try:
|
||||
ov_model.model = ov_model.model.half()
|
||||
except Exception as half_err:
|
||||
print(f"半精度转换失败: {half_err}")
|
||||
except Exception as e:
|
||||
print(f"模型优化配置警告: {e}")
|
||||
|
||||
# 缓存先前的检测结果,用于提高稳定性
|
||||
last_results = None
|
||||
skip_counter = 0
|
||||
max_skip = 2 # 最多跳过几帧不处理
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8:
|
||||
# 如果输出队列接近满,等待一段时间
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
frame, timestamp = frame_queue.get(timeout=0.2)
|
||||
|
||||
# 处理延迟过大的帧
|
||||
if time.time() - timestamp > 0.3: # 减少延迟阈值
|
||||
continue
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 间隔采样,每n帧处理一次,减少计算量
|
||||
if skip_counter > 0 and last_results is not None:
|
||||
skip_counter -= 1
|
||||
# 使用上次的检测结果
|
||||
annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
|
||||
processed_frame_queue.put((annotated_frame, timestamp))
|
||||
continue
|
||||
|
||||
process_start = time.time()
|
||||
|
||||
# 动态调整处理尺寸,根据队列积压情况
|
||||
resize_scale = 1.0
|
||||
if frame_queue.qsize() > frame_queue.maxsize * 0.7:
|
||||
resize_scale = 0.4 # 高负载时大幅降低分辨率
|
||||
elif frame_queue.qsize() > frame_queue.maxsize * 0.5:
|
||||
resize_scale = 0.6 # 中等负载时适度降低分辨率
|
||||
elif frame_queue.qsize() > frame_queue.maxsize * 0.3:
|
||||
resize_scale = 0.8 # 轻微负载时轻微降低分辨率
|
||||
|
||||
# 调整图像尺寸以加快处理
|
||||
if resize_scale < 1.0:
|
||||
process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale)
|
||||
else:
|
||||
process_frame = frame
|
||||
|
||||
# 执行推理
|
||||
try:
|
||||
results = ov_model(process_frame, classes=cls, show=False)
|
||||
last_results = results[0] # 保存检测结果用于后续帧
|
||||
|
||||
# 如果尺寸调整过,需要将结果转换回原始尺寸
|
||||
if resize_scale < 1.0:
|
||||
# 绘制检测框
|
||||
annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5),
|
||||
(frame.shape[1], frame.shape[0]))
|
||||
else:
|
||||
annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5)
|
||||
|
||||
# 在负载高时启用跳帧处理
|
||||
if frame_queue.qsize() > frame_queue.maxsize * 0.5:
|
||||
skip_counter = max_skip
|
||||
except Exception as infer_err:
|
||||
print(f"推理错误: {infer_err}")
|
||||
if last_results is not None:
|
||||
# 使用上次的结果
|
||||
annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
|
||||
else:
|
||||
# 如果没有上次的结果,简单返回原始帧
|
||||
annotated_frame = frame.copy()
|
||||
|
||||
process_end = time.time()
|
||||
process_times.append(process_end - process_start)
|
||||
if len(process_times) > 30:
|
||||
process_times.pop(0)
|
||||
|
||||
if frame_count % 30 == 0:
|
||||
avg_process_time = sum(process_times) / len(process_times)
|
||||
fps = 1.0 / avg_process_time if avg_process_time > 0 else 0
|
||||
print(f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time*1000:.2f}ms, 队列大小: {frame_queue.qsize()}, 缩放比例: {resize_scale:.2f}")
|
||||
|
||||
processed_frame_queue.put((annotated_frame, timestamp))
|
||||
error_count = 0 # 成功处理后重置错误计数
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
print(f"处理帧错误: {e}")
|
||||
if error_count >= max_errors:
|
||||
print(f"连续处理错误达到{max_errors}次,正在停止处理...")
|
||||
stop_event.set()
|
||||
break
|
||||
|
||||
def write_frames(processed_frame_queue, pipe, size):
|
||||
"""写入帧的线程,添加平滑处理"""
|
||||
last_write_time = time.time()
|
||||
target_time_per_frame = 1.0 / 30.0 # 30fps
|
||||
pipe_error_count = 0 # 添加错误计数
|
||||
max_pipe_errors = 3 # 最大容忍错误数
|
||||
frame_count = 0
|
||||
last_fps_time = time.time()
|
||||
skipped_frames = 0
|
||||
|
||||
# 使用队列存储最近几帧,用于在需要时进行插值
|
||||
recent_frames = []
|
||||
max_recent_frames = 5 # 增加缓存帧数量,提高平滑性
|
||||
|
||||
# 使用双缓冲机制提高写入速度
|
||||
buffer1 = bytearray(size[0] * size[1] * 3)
|
||||
buffer2 = bytearray(size[0] * size[1] * 3)
|
||||
current_buffer = buffer1
|
||||
|
||||
# 帧率控制参数
|
||||
min_frame_interval = target_time_per_frame * 0.5 # 允许的最小帧间隔
|
||||
max_frame_interval = target_time_per_frame * 2.0 # 允许的最大帧间隔
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
# 获取处理后的帧,超时时间较短以便更平滑地处理
|
||||
frame, timestamp = processed_frame_queue.get(timeout=0.05)
|
||||
|
||||
# 存储最近的帧用于插值
|
||||
recent_frames.append(frame)
|
||||
if len(recent_frames) > max_recent_frames:
|
||||
recent_frames.pop(0)
|
||||
|
||||
current_time = time.time()
|
||||
elapsed = current_time - last_write_time
|
||||
|
||||
# 如果两帧间隔太短,考虑合并或跳过
|
||||
if elapsed < min_frame_interval and len(recent_frames) >= 2:
|
||||
skipped_frames += 1
|
||||
continue
|
||||
|
||||
# 如果两帧间隔太长,进行插值平滑
|
||||
if elapsed > max_frame_interval and len(recent_frames) >= 2:
|
||||
# 计算需要插入的帧数
|
||||
frames_to_insert = min(int(elapsed / target_time_per_frame), 3)
|
||||
|
||||
for i in range(frames_to_insert):
|
||||
# 创建插值帧
|
||||
weight = (i + 1) / (frames_to_insert + 1)
|
||||
interpolated_frame = cv2.addWeighted(recent_frames[-2], 1-weight, recent_frames[-1], weight, 0)
|
||||
|
||||
# 切换双缓冲
|
||||
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
|
||||
|
||||
# 高效调整大小并写入
|
||||
interpolated_resized = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR)
|
||||
img_bytes = interpolated_resized.tobytes()
|
||||
|
||||
# 写入管道
|
||||
pipe.stdin.write(img_bytes)
|
||||
pipe.stdin.flush()
|
||||
|
||||
# 切换双缓冲
|
||||
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
|
||||
|
||||
# 正常写入当前帧 - 使用高效的调整大小方法
|
||||
resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR)
|
||||
img_bytes = resized_frame.tobytes()
|
||||
pipe.stdin.write(img_bytes)
|
||||
pipe.stdin.flush()
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % 30 == 0:
|
||||
current_fps_time = time.time()
|
||||
fps = 30 / (current_fps_time - last_fps_time)
|
||||
print(f"推流FPS: {fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}")
|
||||
last_fps_time = current_fps_time
|
||||
skipped_frames = 0
|
||||
|
||||
last_write_time = time.time()
|
||||
pipe_error_count = 0 # 成功写入后重置错误计数
|
||||
|
||||
except queue.Empty:
|
||||
# 队列为空且有足够的最近帧时,考虑生成插值帧以保持流畅
|
||||
if len(recent_frames) >= 2 and time.time() - last_write_time > target_time_per_frame:
|
||||
try:
|
||||
# 创建插值帧
|
||||
interpolated_frame = cv2.addWeighted(recent_frames[-2], 0.5, recent_frames[-1], 0.5, 0)
|
||||
|
||||
# 切换双缓冲
|
||||
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
|
||||
|
||||
resized_frame = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR)
|
||||
img_bytes = resized_frame.tobytes()
|
||||
pipe.stdin.write(img_bytes)
|
||||
pipe.stdin.flush()
|
||||
last_write_time = time.time()
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"写入帧错误: {e}")
|
||||
pipe_error_count += 1
|
||||
if pipe_error_count >= max_pipe_errors:
|
||||
print("FFmpeg管道错误过多,正在终止进程...")
|
||||
stop_event.set() # 主动结束所有线程
|
||||
break
|
||||
|
||||
def startAIVideo2(video_path, output_url, m1, cls, confidence):
|
||||
rtmp = output_url
|
||||
setIfAI(True)
|
||||
|
||||
cap = None
|
||||
pipe = None
|
||||
read_thread = None
|
||||
process_thread = None
|
||||
write_thread = None
|
||||
ov_model = None
|
||||
|
||||
try:
|
||||
import os, cv2, torch, time, queue, subprocess
|
||||
import numpy as np
|
||||
from threading import Thread, Event
|
||||
from ultralytics import YOLO
|
||||
|
||||
global frame_queue, processed_frame_queue, stop_event
|
||||
stop_event = Event()
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = "4"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print("预加载YOLO模型...")
|
||||
model_params = {}
|
||||
try:
|
||||
test_model = YOLO(m1)
|
||||
if hasattr(test_model, "task"):
|
||||
model_params["task"] = "detect"
|
||||
if torch.cuda.is_available():
|
||||
model_params["half"] = True
|
||||
import inspect
|
||||
if "verbose" in inspect.signature(YOLO.__init__).parameters:
|
||||
model_params["verbose"] = False
|
||||
except Exception as e:
|
||||
print(f"参数检测失败,将使用默认参数: {e}")
|
||||
model_params = {}
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < 3:
|
||||
try:
|
||||
ov_model = YOLO(m1, **model_params)
|
||||
dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
|
||||
for _ in range(3):
|
||||
ov_model(dummy_frame, classes=cls, conf=confidence, show=False)
|
||||
print("YOLO模型加载成功并预热完成")
|
||||
break
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
print(f"YOLO模型加载失败(尝试 {retry_count}/3): {e}")
|
||||
if "unexpected keyword" in str(e):
|
||||
param = str(e).split("'")[-2]
|
||||
if param in model_params:
|
||||
print(f"移除不支持的参数: {param}")
|
||||
del model_params[param]
|
||||
time.sleep(2)
|
||||
|
||||
if ov_model is None:
|
||||
raise Exception("无法加载YOLO模型")
|
||||
|
||||
ov_model.conf = confidence
|
||||
|
||||
print(f"正在连接视频流: {video_path}")
|
||||
cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 5)
|
||||
cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
|
||||
|
||||
if not cap.isOpened():
|
||||
raise Exception(f"无法打开视频流: {video_path}")
|
||||
|
||||
try:
|
||||
cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0)
|
||||
except Exception as e:
|
||||
print(f"无法设置自动曝光参数: {e}")
|
||||
|
||||
frame_queue = queue.Queue(maxsize=80)
|
||||
processed_frame_queue = queue.Queue(maxsize=40)
|
||||
|
||||
size = (1920, 1080)
|
||||
sizeStr = f"{size[0]}x{size[1]}"
|
||||
command = ['ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
|
||||
'-pix_fmt', 'bgr24', '-s', sizeStr, '-r', '30', '-i', '-',
|
||||
'-c:v', 'h264', '-pix_fmt', 'yuv420p',
|
||||
'-preset', 'ultrafast', '-tune', 'zerolatency',
|
||||
'-f', 'flv', '-g', '30', '-bufsize', '4000k',
|
||||
'-maxrate', '4000k', '-b:v', '2500k', '-vsync', '1',
|
||||
'-threads', '4', rtmp]
|
||||
|
||||
print(f"启动FFmpeg推流到: {rtmp}")
|
||||
pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
bufsize=10*1024*1024)
|
||||
|
||||
def monitor_ffmpeg_output(pipe):
|
||||
while not stop_event.is_set():
|
||||
line = pipe.stderr.readline().decode('utf-8', errors='ignore')
|
||||
if line and ("error" in line.lower()):
|
||||
print(f"FFmpeg错误: {line.strip()}")
|
||||
if "Cannot open connection" in line:
|
||||
stop_event.set()
|
||||
break
|
||||
|
||||
Thread(target=monitor_ffmpeg_output, args=(pipe,), daemon=True).start()
|
||||
|
||||
def read_frames(cap, frame_queue):
|
||||
while not stop_event.is_set():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("读取失败")
|
||||
break
|
||||
timestamp = time.time()
|
||||
try:
|
||||
frame_queue.put((timestamp, frame), timeout=1)
|
||||
except queue.Full:
|
||||
print("帧队列满,跳帧")
|
||||
|
||||
def process_frames(frame_queue, processed_frame_queue, model, cls, confidence):
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
timestamp, frame = frame_queue.get(timeout=1)
|
||||
result = model.predict(source=frame, classes=cls, conf=confidence, verbose=False)
|
||||
processed = result[0].plot()
|
||||
processed_frame_queue.put((timestamp, processed), timeout=1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"处理帧错误: {e}")
|
||||
|
||||
def write_frames(processed_frame_queue, pipe, size):
|
||||
last_timestamp = 0
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
timestamp, frame = processed_frame_queue.get(timeout=1)
|
||||
if timestamp < last_timestamp:
|
||||
print(f"跳过闪回帧 {timestamp} < {last_timestamp}")
|
||||
continue
|
||||
last_timestamp = timestamp
|
||||
frame = cv2.resize(frame, size)
|
||||
pipe.stdin.write(frame.tobytes())
|
||||
except Exception as e:
|
||||
print(f"写入帧错误: {e}")
|
||||
break
|
||||
|
||||
read_thread = Thread(target=read_frames, args=(cap, frame_queue), daemon=True, name="ReadThread")
|
||||
process_thread = Thread(target=process_frames, args=(frame_queue, processed_frame_queue, ov_model, cls, confidence), daemon=True, name="ProcessThread")
|
||||
write_thread = Thread(target=write_frames, args=(processed_frame_queue, pipe, size), daemon=True, name="WriteThread")
|
||||
|
||||
print("开始推流处理...")
|
||||
read_thread.start()
|
||||
process_thread.start()
|
||||
write_thread.start()
|
||||
|
||||
last_check = time.time()
|
||||
while getIfAI() and not stop_event.is_set():
|
||||
if not all([t.is_alive() for t in [read_thread, process_thread, write_thread]]):
|
||||
print("检测到线程停止,退出")
|
||||
stop_event.set()
|
||||
break
|
||||
if pipe.poll() is not None:
|
||||
print("FFmpeg退出")
|
||||
stop_event.set()
|
||||
break
|
||||
if time.time() - last_check > 30:
|
||||
print(f"输入队列: {frame_queue.qsize()}/{frame_queue.maxsize} | 输出队列: {processed_frame_queue.qsize()}/{processed_frame_queue.maxsize}")
|
||||
last_check = time.time()
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
finally:
|
||||
print("清理资源...")
|
||||
stop_event.set()
|
||||
setIfAI(False)
|
||||
|
||||
for t in [read_thread, process_thread, write_thread]:
|
||||
if t and t.is_alive():
|
||||
t.join(timeout=2)
|
||||
|
||||
try:
|
||||
if cap: cap.release()
|
||||
if pipe:
|
||||
try:
|
||||
import signal
|
||||
os.kill(pipe.pid, signal.SIGTERM)
|
||||
except: pass
|
||||
pipe.stdin.close()
|
||||
pipe.terminate()
|
||||
try:
|
||||
pipe.wait(timeout=2)
|
||||
except:
|
||||
pipe.kill()
|
||||
except Exception as e:
|
||||
print(f"释放资源时出错: {e}")
|
||||
|
||||
try:
|
||||
cv2.destroyAllWindows()
|
||||
except:
|
||||
pass
|
||||
print("资源释放完毕")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sn = "1581F6QAD243C00BP71E"
|
||||
video_path = f"rtmp://222.212.85.86:1935/live/{sn}"
|
||||
# FFmpeg 推流地址
|
||||
rtmp = f"rtmp://222.212.85.86:1935/live/{sn}ai"
|
||||
|
||||
try:
|
||||
startAIVideo2(video_path, rtmp, "best.pt", [0, 1, 2, 3, 4],0.4)
|
||||
except KeyboardInterrupt:
|
||||
print("程序被用户中断")
|
||||
stopAIVideo()
|
||||
except Exception as e:
|
||||
print(f"程序异常: {e}")
|
||||
|
603
ai2/cv_video.py
Normal file
603
ai2/cv_video.py
Normal file
@ -0,0 +1,603 @@
|
||||
from threading import Thread, Lock, Event
|
||||
import time
|
||||
import queue
|
||||
from ultralytics import YOLO # 导入 Ultralytics YOLO 模型
|
||||
import os, cv2, torch, time, queue, subprocess
|
||||
import numpy as np
|
||||
|
||||
# 全局变量
|
||||
ifAI = {'status': False}
|
||||
deskLock = Lock()
|
||||
frame_queue = queue.Queue(maxsize=60) # 增加帧缓冲队列大小
|
||||
processed_frame_queue = queue.Queue(maxsize=30) # 处理后的帧队列
|
||||
stop_event = Event()
|
||||
|
||||
def setIfAI(pb1):
|
||||
deskLock.acquire()
|
||||
ifAI['status'] = pb1
|
||||
deskLock.release()
|
||||
|
||||
def getIfAI():
|
||||
return ifAI['status']
|
||||
|
||||
def stopAIVideo():
|
||||
print("正在停止AI视频处理...")
|
||||
setIfAI(False)
|
||||
stop_event.set()
|
||||
|
||||
# 等待足够长的时间确保资源释放
|
||||
wait_count = 0
|
||||
max_wait = 5 # 减少最大等待时间到5秒
|
||||
|
||||
while stop_event.is_set() and wait_count < max_wait:
|
||||
time.sleep(0.5)
|
||||
wait_count += 1
|
||||
|
||||
if wait_count >= max_wait:
|
||||
print("警告: 停止AI视频处理超时,强制终止")
|
||||
# 不使用_thread._interrupt_main(),改用其他方式强制终止
|
||||
try:
|
||||
# 尝试终止可能运行的进程
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
# 查找并终止可能的FFmpeg进程
|
||||
current_process = psutil.Process(os.getpid())
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
print(f"正在终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print("AI视频处理已停止")
|
||||
|
||||
def startAIVideo(video_path, output_url, m1, cls, confidence):
|
||||
if ifAI['status']:
|
||||
stopAIVideo()
|
||||
time.sleep(1)
|
||||
stop_event.clear()
|
||||
thread = Thread(target=startAIVideo2,
|
||||
args=(video_path, output_url, m1, cls, confidence))
|
||||
# cls2_thread = Thread(target=cls2_find, args=(video_path,m1, cls, confidence))
|
||||
# cls2_thread.daemon = True # 守护线程,主程序退出时线程也会退出
|
||||
thread.daemon = True # 守护线程,主程序退出时线程也会退出
|
||||
|
||||
|
||||
thread.start()
|
||||
# cls2_thread.start()
|
||||
|
||||
def read_frames(cap, frame_queue):
|
||||
"""优化的帧读取线程"""
|
||||
frame_count = 0
|
||||
last_time = time.time()
|
||||
last_fps_time = time.time()
|
||||
# 减小目标帧间隔时间,提高读取帧率
|
||||
target_time_per_frame = 1.0 / 60.0 # 目标帧间隔时间(提高到60fps)
|
||||
|
||||
# 添加连接断开检测
|
||||
connection_error_count = 0
|
||||
max_connection_errors = 10 # 最多允许连续10次连接错误
|
||||
last_successful_read = time.time()
|
||||
max_read_wait = 30.0 # 30秒无法读取则认为连接断开
|
||||
|
||||
# 预先丢弃几帧,确保从新帧开始处理
|
||||
for _ in range(5):
|
||||
cap.grab()
|
||||
|
||||
while not stop_event.is_set():
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - last_time
|
||||
|
||||
# 检查是否长时间无法读取帧
|
||||
if current_time - last_successful_read > max_read_wait:
|
||||
print(f"警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开")
|
||||
stop_event.set()
|
||||
break
|
||||
|
||||
# 帧率控制,但更积极地读取
|
||||
if elapsed_time < target_time_per_frame:
|
||||
time.sleep(target_time_per_frame - elapsed_time)
|
||||
continue
|
||||
|
||||
# 当队列快满时,跳过一些帧以避免延迟累积
|
||||
if frame_queue.qsize() > frame_queue.maxsize * 0.8:
|
||||
# 跳过一些帧
|
||||
cap.grab()
|
||||
last_time = time.time()
|
||||
continue
|
||||
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("拉流错误:无法读取帧")
|
||||
connection_error_count += 1
|
||||
if connection_error_count >= max_connection_errors:
|
||||
print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...")
|
||||
stop_event.set()
|
||||
break
|
||||
time.sleep(0.5) # 短暂等待后重试
|
||||
continue
|
||||
|
||||
# 成功读取了帧,重置错误计数
|
||||
connection_error_count = 0
|
||||
last_successful_read = time.time()
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % 60 == 0: # 每60帧计算一次FPS
|
||||
current_fps_time = time.time()
|
||||
fps = 60 / (current_fps_time - last_fps_time)
|
||||
print(f"拉流FPS: {fps:.2f}")
|
||||
last_fps_time = current_fps_time
|
||||
|
||||
last_time = time.time()
|
||||
frame_queue.put((frame, time.time())) # 添加时间戳
|
||||
|
||||
def process_frames(frame_queue, processed_frame_queue, ov_model, cls, confidence):
|
||||
"""处理帧的线程,添加帧率控制"""
|
||||
error_count = 0 # 添加错误计数器
|
||||
max_errors = 5 # 最大容许错误次数
|
||||
frame_count = 0
|
||||
process_times = [] # 用于计算平均处理时间
|
||||
|
||||
# 设置YOLO模型配置,提高性能
|
||||
ov_model.conf = confidence # 设置置信度阈值
|
||||
|
||||
# 优化推理性能
|
||||
try:
|
||||
# 导入torch库
|
||||
import torch
|
||||
# 尝试启用ONNX Runtime加速
|
||||
ov_model.to('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
# 调整批处理大小为1,减少内存占用
|
||||
if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'):
|
||||
ov_model.args.batch = 1
|
||||
# 使用half精度,提高性能
|
||||
if torch.cuda.is_available() and hasattr(ov_model, 'model'):
|
||||
try:
|
||||
ov_model.model = ov_model.model.half()
|
||||
except Exception as half_err:
|
||||
print(f"半精度转换失败: {half_err}")
|
||||
except Exception as e:
|
||||
print(f"模型优化配置警告: {e}")
|
||||
|
||||
# 缓存先前的检测结果,用于提高稳定性
|
||||
last_results = None
|
||||
skip_counter = 0
|
||||
max_skip = 2 # 最多跳过几帧不处理
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8:
|
||||
# 如果输出队列接近满,等待一段时间
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
frame, timestamp = frame_queue.get(timeout=0.2)
|
||||
|
||||
# 处理延迟过大的帧
|
||||
if time.time() - timestamp > 0.3: # 减少延迟阈值
|
||||
continue
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 间隔采样,每n帧处理一次,减少计算量
|
||||
if skip_counter > 0 and last_results is not None:
|
||||
skip_counter -= 1
|
||||
# 使用上次的检测结果
|
||||
annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
|
||||
processed_frame_queue.put((annotated_frame, timestamp))
|
||||
continue
|
||||
|
||||
process_start = time.time()
|
||||
|
||||
# 动态调整处理尺寸,根据队列积压情况
|
||||
resize_scale = 1.0
|
||||
if frame_queue.qsize() > frame_queue.maxsize * 0.7:
|
||||
resize_scale = 0.4 # 高负载时大幅降低分辨率
|
||||
elif frame_queue.qsize() > frame_queue.maxsize * 0.5:
|
||||
resize_scale = 0.6 # 中等负载时适度降低分辨率
|
||||
elif frame_queue.qsize() > frame_queue.maxsize * 0.3:
|
||||
resize_scale = 0.8 # 轻微负载时轻微降低分辨率
|
||||
|
||||
# 调整图像尺寸以加快处理
|
||||
if resize_scale < 1.0:
|
||||
process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale)
|
||||
else:
|
||||
process_frame = frame
|
||||
|
||||
# 执行推理
|
||||
try:
|
||||
results = ov_model(process_frame, classes=cls, show=False)
|
||||
last_results = results[0] # 保存检测结果用于后续帧
|
||||
|
||||
# 如果尺寸调整过,需要将结果转换回原始尺寸
|
||||
if resize_scale < 1.0:
|
||||
# 绘制检测框
|
||||
annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5),
|
||||
(frame.shape[1], frame.shape[0]))
|
||||
else:
|
||||
annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5)
|
||||
|
||||
# 在负载高时启用跳帧处理
|
||||
if frame_queue.qsize() > frame_queue.maxsize * 0.5:
|
||||
skip_counter = max_skip
|
||||
except Exception as infer_err:
|
||||
print(f"推理错误: {infer_err}")
|
||||
if last_results is not None:
|
||||
# 使用上次的结果
|
||||
annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
|
||||
else:
|
||||
# 如果没有上次的结果,简单返回原始帧
|
||||
annotated_frame = frame.copy()
|
||||
|
||||
process_end = time.time()
|
||||
process_times.append(process_end - process_start)
|
||||
if len(process_times) > 30:
|
||||
process_times.pop(0)
|
||||
|
||||
if frame_count % 30 == 0:
|
||||
avg_process_time = sum(process_times) / len(process_times)
|
||||
fps = 1.0 / avg_process_time if avg_process_time > 0 else 0
|
||||
print(f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time*1000:.2f}ms, 队列大小: {frame_queue.qsize()}, 缩放比例: {resize_scale:.2f}")
|
||||
|
||||
processed_frame_queue.put((annotated_frame, timestamp))
|
||||
error_count = 0 # 成功处理后重置错误计数
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
print(f"处理帧错误: {e}")
|
||||
if error_count >= max_errors:
|
||||
print(f"连续处理错误达到{max_errors}次,正在停止处理...")
|
||||
stop_event.set()
|
||||
break
|
||||
|
||||
def write_frames(processed_frame_queue, pipe, size):
|
||||
"""写入帧的线程,添加平滑处理"""
|
||||
last_write_time = time.time()
|
||||
target_time_per_frame = 1.0 / 30.0 # 30fps
|
||||
pipe_error_count = 0 # 添加错误计数
|
||||
max_pipe_errors = 3 # 最大容忍错误数
|
||||
frame_count = 0
|
||||
last_fps_time = time.time()
|
||||
skipped_frames = 0
|
||||
|
||||
# 使用队列存储最近几帧,用于在需要时进行插值
|
||||
recent_frames = []
|
||||
max_recent_frames = 5 # 增加缓存帧数量,提高平滑性
|
||||
|
||||
# 使用双缓冲机制提高写入速度
|
||||
buffer1 = bytearray(size[0] * size[1] * 3)
|
||||
buffer2 = bytearray(size[0] * size[1] * 3)
|
||||
current_buffer = buffer1
|
||||
|
||||
# 帧率控制参数
|
||||
min_frame_interval = target_time_per_frame * 0.5 # 允许的最小帧间隔
|
||||
max_frame_interval = target_time_per_frame * 2.0 # 允许的最大帧间隔
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
# 获取处理后的帧,超时时间较短以便更平滑地处理
|
||||
frame, timestamp = processed_frame_queue.get(timeout=0.05)
|
||||
|
||||
# 存储最近的帧用于插值
|
||||
recent_frames.append(frame)
|
||||
if len(recent_frames) > max_recent_frames:
|
||||
recent_frames.pop(0)
|
||||
|
||||
current_time = time.time()
|
||||
elapsed = current_time - last_write_time
|
||||
|
||||
# 如果两帧间隔太短,考虑合并或跳过
|
||||
if elapsed < min_frame_interval and len(recent_frames) >= 2:
|
||||
skipped_frames += 1
|
||||
continue
|
||||
|
||||
# 如果两帧间隔太长,进行插值平滑
|
||||
if elapsed > max_frame_interval and len(recent_frames) >= 2:
|
||||
# 计算需要插入的帧数
|
||||
frames_to_insert = min(int(elapsed / target_time_per_frame), 3)
|
||||
|
||||
for i in range(frames_to_insert):
|
||||
# 创建插值帧
|
||||
weight = (i + 1) / (frames_to_insert + 1)
|
||||
interpolated_frame = cv2.addWeighted(recent_frames[-2], 1-weight, recent_frames[-1], weight, 0)
|
||||
|
||||
# 切换双缓冲
|
||||
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
|
||||
|
||||
# 高效调整大小并写入
|
||||
interpolated_resized = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR)
|
||||
img_bytes = interpolated_resized.tobytes()
|
||||
|
||||
# 写入管道
|
||||
pipe.stdin.write(img_bytes)
|
||||
pipe.stdin.flush()
|
||||
|
||||
# 切换双缓冲
|
||||
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
|
||||
|
||||
# 正常写入当前帧 - 使用高效的调整大小方法
|
||||
resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR)
|
||||
img_bytes = resized_frame.tobytes()
|
||||
pipe.stdin.write(img_bytes)
|
||||
pipe.stdin.flush()
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % 30 == 0:
|
||||
current_fps_time = time.time()
|
||||
fps = 30 / (current_fps_time - last_fps_time)
|
||||
print(f"推流FPS: {fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}")
|
||||
last_fps_time = current_fps_time
|
||||
skipped_frames = 0
|
||||
|
||||
last_write_time = time.time()
|
||||
pipe_error_count = 0 # 成功写入后重置错误计数
|
||||
|
||||
except queue.Empty:
|
||||
# 队列为空且有足够的最近帧时,考虑生成插值帧以保持流畅
|
||||
if len(recent_frames) >= 2 and time.time() - last_write_time > target_time_per_frame:
|
||||
try:
|
||||
# 创建插值帧
|
||||
interpolated_frame = cv2.addWeighted(recent_frames[-2], 0.5, recent_frames[-1], 0.5, 0)
|
||||
|
||||
# 切换双缓冲
|
||||
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
|
||||
|
||||
resized_frame = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR)
|
||||
img_bytes = resized_frame.tobytes()
|
||||
pipe.stdin.write(img_bytes)
|
||||
pipe.stdin.flush()
|
||||
last_write_time = time.time()
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"写入帧错误: {e}")
|
||||
pipe_error_count += 1
|
||||
if pipe_error_count >= max_pipe_errors:
|
||||
print("FFmpeg管道错误过多,正在终止进程...")
|
||||
stop_event.set() # 主动结束所有线程
|
||||
break
|
||||
|
||||
def startAIVideo2(video_path, output_url, m1, cls, confidence):
|
||||
rtmp = output_url
|
||||
setIfAI(True)
|
||||
|
||||
cap = None
|
||||
pipe = None
|
||||
read_thread = None
|
||||
process_thread = None
|
||||
write_thread = None
|
||||
ov_model = None
|
||||
|
||||
try:
|
||||
|
||||
global frame_queue, processed_frame_queue, stop_event
|
||||
stop_event = Event()
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = "4"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print("预加载YOLO模型...")
|
||||
model_params = {}
|
||||
try:
|
||||
test_model = YOLO(m1)
|
||||
if hasattr(test_model, "task"):
|
||||
model_params["task"] = "detect"
|
||||
if torch.cuda.is_available():
|
||||
model_params["half"] = True
|
||||
import inspect
|
||||
if "verbose" in inspect.signature(YOLO.__init__).parameters:
|
||||
model_params["verbose"] = False
|
||||
except Exception as e:
|
||||
print(f"参数检测失败,将使用默认参数: {e}")
|
||||
model_params = {}
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < 3:
|
||||
try:
|
||||
ov_model = YOLO(m1, **model_params)
|
||||
dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
|
||||
for _ in range(3):
|
||||
ov_model(dummy_frame, classes=cls, conf=confidence, show=False)
|
||||
print("YOLO模型加载成功并预热完成")
|
||||
break
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
print(f"YOLO模型加载失败(尝试 {retry_count}/3): {e}")
|
||||
if "unexpected keyword" in str(e):
|
||||
param = str(e).split("'")[-2]
|
||||
if param in model_params:
|
||||
print(f"移除不支持的参数: {param}")
|
||||
del model_params[param]
|
||||
time.sleep(2)
|
||||
|
||||
if ov_model is None:
|
||||
raise Exception("无法加载YOLO模型")
|
||||
|
||||
ov_model.conf = confidence
|
||||
|
||||
print(f"正在连接视频流: {video_path}")
|
||||
cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 5)
|
||||
cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
|
||||
|
||||
if not cap.isOpened():
|
||||
raise Exception(f"无法打开视频流: {video_path}")
|
||||
|
||||
try:
|
||||
cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0)
|
||||
except Exception as e:
|
||||
print(f"无法设置自动曝光参数: {e}")
|
||||
|
||||
frame_queue = queue.Queue(maxsize=80)
|
||||
processed_frame_queue = queue.Queue(maxsize=40)
|
||||
|
||||
size = (1920, 1080)
|
||||
sizeStr = f"{size[0]}x{size[1]}"
|
||||
|
||||
command = [
|
||||
'ffmpeg', '-y',
|
||||
'-f', 'rawvideo', '-vcodec', 'rawvideo',
|
||||
'-pix_fmt', 'bgr24',
|
||||
'-s', sizeStr,
|
||||
'-r', '30',
|
||||
'-i', '-',
|
||||
'-c:v', 'libx264',
|
||||
'-preset', 'ultrafast',
|
||||
'-tune', 'zerolatency',
|
||||
'-x264-params', 'sei=0',
|
||||
'-pix_fmt', 'yuv420p',
|
||||
'-f', 'flv',
|
||||
'-g', '30',
|
||||
'-keyint_min', '30',
|
||||
'-sc_threshold', '0',
|
||||
'-b:v', '2500k',
|
||||
'-maxrate', '3000k',
|
||||
'-bufsize', '3000k',
|
||||
'-threads', '4',
|
||||
'-vsync', '1',
|
||||
rtmp
|
||||
]
|
||||
|
||||
|
||||
print(f"启动FFmpeg推流到: {rtmp}")
|
||||
pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
bufsize=10*1024*1024)
|
||||
|
||||
def monitor_ffmpeg_output(pipe):
|
||||
while not stop_event.is_set():
|
||||
line = pipe.stderr.readline().decode('utf-8', errors='ignore')
|
||||
if line and ("error" in line.lower()):
|
||||
print(f"FFmpeg错误: {line.strip()}")
|
||||
if "Cannot open connection" in line:
|
||||
stop_event.set()
|
||||
break
|
||||
|
||||
Thread(target=monitor_ffmpeg_output, args=(pipe,), daemon=True).start()
|
||||
|
||||
def read_frames(cap, frame_queue):
|
||||
while not stop_event.is_set():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("读取失败")
|
||||
break
|
||||
timestamp = time.time()
|
||||
try:
|
||||
frame_queue.put((timestamp, frame), timeout=1)
|
||||
except queue.Full:
|
||||
print("帧队列满,跳帧")
|
||||
|
||||
def process_frames(frame_queue, processed_frame_queue, model, cls, confidence):
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
timestamp, frame = frame_queue.get(timeout=1)
|
||||
result = model.predict(source=frame, classes=cls, conf=confidence, verbose=False)
|
||||
processed = result[0].plot()
|
||||
processed_frame_queue.put((timestamp, processed), timeout=1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"处理帧错误: {e}")
|
||||
|
||||
def write_frames(processed_frame_queue, pipe, size):
|
||||
last_timestamp = 0
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
timestamp, frame = processed_frame_queue.get(timeout=1)
|
||||
if timestamp < last_timestamp:
|
||||
print(f"跳过闪回帧 {timestamp} < {last_timestamp}")
|
||||
continue
|
||||
last_timestamp = timestamp
|
||||
frame = cv2.resize(frame, size)
|
||||
pipe.stdin.write(frame.tobytes())
|
||||
except Exception as e:
|
||||
print(f"写入帧错误: {e}")
|
||||
break
|
||||
|
||||
read_thread = Thread(target=read_frames, args=(cap, frame_queue), daemon=True, name="ReadThread")
|
||||
process_thread = Thread(target=process_frames, args=(frame_queue, processed_frame_queue, ov_model, cls, confidence), daemon=True, name="ProcessThread")
|
||||
write_thread = Thread(target=write_frames, args=(processed_frame_queue, pipe, size), daemon=True, name="WriteThread")
|
||||
|
||||
print("开始推流处理...")
|
||||
read_thread.start()
|
||||
process_thread.start()
|
||||
write_thread.start()
|
||||
|
||||
last_check = time.time()
|
||||
while getIfAI() and not stop_event.is_set():
|
||||
if not all([t.is_alive() for t in [read_thread, process_thread, write_thread]]):
|
||||
print("检测到线程停止,退出")
|
||||
stop_event.set()
|
||||
break
|
||||
if pipe.poll() is not None:
|
||||
print("FFmpeg退出")
|
||||
stop_event.set()
|
||||
break
|
||||
if time.time() - last_check > 30:
|
||||
print(f"输入队列: {frame_queue.qsize()}/{frame_queue.maxsize} | 输出队列: {processed_frame_queue.qsize()}/{processed_frame_queue.maxsize}")
|
||||
last_check = time.time()
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
finally:
|
||||
print("清理资源...")
|
||||
stop_event.set()
|
||||
setIfAI(False)
|
||||
|
||||
for t in [read_thread, process_thread, write_thread]:
|
||||
if t and t.is_alive():
|
||||
t.join(timeout=2)
|
||||
|
||||
try:
|
||||
if cap: cap.release()
|
||||
if pipe:
|
||||
try:
|
||||
import signal
|
||||
os.kill(pipe.pid, signal.SIGTERM)
|
||||
except: pass
|
||||
pipe.stdin.close()
|
||||
pipe.terminate()
|
||||
try:
|
||||
pipe.wait(timeout=2)
|
||||
except:
|
||||
pipe.kill()
|
||||
except Exception as e:
|
||||
print(f"释放资源时出错: {e}")
|
||||
|
||||
try:
|
||||
cv2.destroyAllWindows()
|
||||
except:
|
||||
pass
|
||||
print("资源释放完毕")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sn = "1581F6QAD243C00BP71E"
|
||||
video_path = f"rtmp://222.212.85.86:1935/live/{sn}"
|
||||
# FFmpeg 推流地址
|
||||
rtmp = f"rtmp://222.212.85.86:1935/live/{sn}ai"
|
||||
|
||||
try:
|
||||
startAIVideo2(video_path, rtmp, "best.pt", [0, 1, 2, 3, 4],0.4)
|
||||
except KeyboardInterrupt:
|
||||
print("程序被用户中断")
|
||||
stopAIVideo()
|
||||
except Exception as e:
|
||||
print(f"程序异常: {e}")
|
||||
|
BIN
ai2/fire.pt
Normal file
BIN
ai2/fire.pt
Normal file
Binary file not shown.
BIN
ai2/gdaq.pt
Normal file
BIN
ai2/gdaq.pt
Normal file
Binary file not shown.
BIN
ai2/last.pt
Normal file
BIN
ai2/last.pt
Normal file
Binary file not shown.
48
ai2/minio_helper.py
Normal file
48
ai2/minio_helper.py
Normal file
@ -0,0 +1,48 @@
|
||||
|
||||
import os
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
bucket="300bdf2b-a150-406e-be63-d28bd29b409f"
|
||||
# 替换为你的MinIO服务器地址、访问密钥和秘密密钥
|
||||
def getClient():
|
||||
minio_client = Minio(
|
||||
"222.212.85.86:9000",
|
||||
access_key="WuRenJi",
|
||||
secure=False,
|
||||
secret_key="WRJ@2024",)
|
||||
return minio_client
|
||||
|
||||
def getPath2(object):
|
||||
#dir="C:/sy/movies/"
|
||||
dir=os.getcwd()+"/"
|
||||
baseName=object
|
||||
s1=baseName.rfind("/")
|
||||
dir2=(dir+baseName[0:s1+1]).replace("/","\\")
|
||||
fName=baseName[s1+1:int(len(baseName))]
|
||||
os.makedirs(dir2, exist_ok=True)
|
||||
file_path = os.path.join(dir2, fName)
|
||||
return file_path
|
||||
|
||||
def upLoad(obj,path):
|
||||
try:
|
||||
minio_client=getClient()
|
||||
minio_client.fput_object(bucket, obj, path)
|
||||
return True
|
||||
except S3Error as e:
|
||||
return False
|
||||
|
||||
def downLoad(obj):
|
||||
path=getPath2(obj)
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
# 从MinIO的存储桶和对象名称下载
|
||||
try:
|
||||
minio_client=getClient()
|
||||
minio_client.fget_object(bucket, obj, path)
|
||||
return path
|
||||
except S3Error as e:
|
||||
return ""
|
||||
|
||||
if __name__ == '__main__':
|
||||
upLoad("aaa/yolo_api.py","yolo_api.py")
|
BIN
ai2/smoke.pt
Normal file
BIN
ai2/smoke.pt
Normal file
Binary file not shown.
BIN
ai2/trash.pt
Normal file
BIN
ai2/trash.pt
Normal file
Binary file not shown.
BIN
ai2/yanwu2.pt
Normal file
BIN
ai2/yanwu2.pt
Normal file
Binary file not shown.
BIN
ai2/yolo11n.pt
Normal file
BIN
ai2/yolo11n.pt
Normal file
Binary file not shown.
509
ai2/yolo_api copy.py
Normal file
509
ai2/yolo_api copy.py
Normal file
@ -0,0 +1,509 @@
|
||||
from sanic import Sanic, json
|
||||
from sanic.response import json as json_response
|
||||
from sanic.exceptions import Unauthorized, NotFound, SanicException
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from cv_video import startAIVideo,stopAIVideo,getIfAI
|
||||
from sanic_cors import CORS
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 服务状态标志
|
||||
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
|
||||
|
||||
# 配置类
|
||||
class Config:
|
||||
VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
||||
MAX_ACTIVE_TASKS = 10
|
||||
DEFAULT_CONFIDENCE = 0.5
|
||||
RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamRequest:
|
||||
source_url: str
|
||||
push_url: str
|
||||
model_path: str
|
||||
detect_classes: List[str]
|
||||
confidence: float = Config.DEFAULT_CONFIDENCE
|
||||
|
||||
def validate(self) -> None:
|
||||
"""验证请求参数"""
|
||||
if not self.source_url or not self.push_url:
|
||||
raise ValueError("Source URL and Push URL are required")
|
||||
|
||||
if not self.detect_classes:
|
||||
raise ValueError("At least one detection class must be specified")
|
||||
if not 0 < self.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest':
|
||||
try:
|
||||
instance = cls(
|
||||
source_url=data['source_url'],
|
||||
push_url=data['push_url'],
|
||||
model_path=data['model_path'],
|
||||
detect_classes=data['detect_classes'],
|
||||
confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE)
|
||||
)
|
||||
instance.validate()
|
||||
return instance
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {str(e)}")
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self):
|
||||
self.active_tasks: Dict[str, Dict[str, Any]] = {}
|
||||
self.task_status: Dict[str, str] = {}
|
||||
self.task_timestamps: Dict[str, datetime] = {}
|
||||
|
||||
def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None:
|
||||
"""添加新任务"""
|
||||
if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS:
|
||||
raise ValueError("Maximum number of active tasks reached")
|
||||
|
||||
self.active_tasks[task_id] = task_info
|
||||
self.task_status[task_id] = "running"
|
||||
self.task_timestamps[task_id] = datetime.now()
|
||||
logger.info(f"Task {task_id} started")
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
"""移除任务"""
|
||||
if task_id in self.active_tasks:
|
||||
del self.active_tasks[task_id]
|
||||
del self.task_status[task_id]
|
||||
del self.task_timestamps[task_id]
|
||||
logger.info(f"Task {task_id} removed")
|
||||
|
||||
def get_task_info(self, task_id: str) -> Dict[str, Any]:
|
||||
"""获取任务信息"""
|
||||
if task_id not in self.active_tasks:
|
||||
raise NotFound("Task not found")
|
||||
|
||||
return {
|
||||
"task_info": self.active_tasks[task_id],
|
||||
"status": self.task_status[task_id],
|
||||
"start_time": self.task_timestamps[task_id].isoformat()
|
||||
}
|
||||
|
||||
def check_tasks_health(self) -> Dict[str, str]:
|
||||
"""检查任务健康状态"""
|
||||
unhealthy_tasks = {}
|
||||
for task_id in list(self.active_tasks.keys()):
|
||||
# 检查任务是否还在运行(通过getIfAI()函数)
|
||||
if not getIfAI():
|
||||
unhealthy_tasks[task_id] = "stopped"
|
||||
logger.warning(f"Task {task_id} appears to be stopped unexpectedly")
|
||||
return unhealthy_tasks
|
||||
|
||||
def mark_all_tasks_as_stopped(self):
|
||||
"""标记所有任务为已停止状态"""
|
||||
for task_id in list(self.active_tasks.keys()):
|
||||
self.task_status[task_id] = "stopped"
|
||||
logger.warning("已将所有任务标记为停止状态")
|
||||
|
||||
app = Sanic("YoloStreamService")
|
||||
CORS(app)
|
||||
task_manager = TaskManager()
|
||||
|
||||
async def safe_stop_ai_video():
|
||||
"""安全地停止AI视频处理,带有错误处理和恢复机制"""
|
||||
try:
|
||||
await asyncio.to_thread(stopAIVideo)
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 标记服务状态为不健康
|
||||
service_status["is_healthy"] = False
|
||||
service_status["last_error"] = str(e)
|
||||
service_status["error_time"] = datetime.now().isoformat()
|
||||
|
||||
# 强制结束所有任务
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
|
||||
# 尝试通过其他方式杀死可能存在的进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
# 查找并终止ffmpeg子进程
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
logger.info(f"强制终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except Exception as child_e:
|
||||
logger.error(f"终止子进程出错: {str(child_e)}")
|
||||
except Exception as kill_e:
|
||||
logger.error(f"尝试清理进程时出错: {str(kill_e)}")
|
||||
|
||||
# 等待一段时间让系统恢复
|
||||
await asyncio.sleep(Config.RESTART_DELAY)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
return False
|
||||
|
||||
def verify_token(request) -> None:
|
||||
"""验证请求token"""
|
||||
token = request.headers.get('X-API-Token')
|
||||
if not token or token != Config.VALID_TOKEN:
|
||||
logger.warning("Invalid token attempt")
|
||||
raise Unauthorized("Invalid token")
|
||||
|
||||
@app.post("/ai/stream/detect")
|
||||
async def start_detection(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查服务健康状态
|
||||
if not service_status["is_healthy"]:
|
||||
logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
|
||||
# 尝试恢复服务
|
||||
service_status["is_healthy"] = True
|
||||
|
||||
# 先停止所有现有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
logger.info(f"停止现有任务 {task_id} 以启动新任务")
|
||||
try:
|
||||
success = await safe_stop_ai_video()
|
||||
if success:
|
||||
task_manager.remove_task(task_id)
|
||||
else:
|
||||
logger.warning(f"无法正常停止任务 {task_id},但仍将继续")
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
except Exception as e:
|
||||
logger.error(f"停止任务时出错: {e}")
|
||||
# 继续执行,尝试启动新任务
|
||||
|
||||
# 解析并验证请求数据
|
||||
stream_request = StreamRequest.from_dict(request.json)
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
# 修正视频源地址
|
||||
new_source_url = stream_request.source_url.replace("222.212.85.86", "192.168.10.5")
|
||||
new_push_url = stream_request.push_url.replace("222.212.85.86", "192.168.10.5")
|
||||
# 启动YOLO检测
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
startAIVideo,
|
||||
new_source_url,
|
||||
new_push_url,
|
||||
stream_request.model_path,
|
||||
stream_request.detect_classes,
|
||||
stream_request.confidence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"启动AI视频处理失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to start AI video processing: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# 记录任务信息
|
||||
task_manager.add_task(task_id, {
|
||||
"source_url": stream_request.source_url,
|
||||
"push_url": stream_request.push_url,
|
||||
"model_path": stream_request.model_path,
|
||||
"detect_classes": stream_request.detect_classes,
|
||||
"confidence": stream_request.confidence
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
"message": "Detection started successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
return json_response({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.post("/ai/stream/<task_id>")
|
||||
async def stop_detection(request, task_id: str):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查任务是否存在
|
||||
try:
|
||||
task_info = task_manager.get_task_info(task_id)
|
||||
except NotFound:
|
||||
return json_response({"status": "error", "message": "Task not found"}, status=404)
|
||||
|
||||
# 停止AI视频处理,使用安全的停止方法
|
||||
success = await safe_stop_ai_video()
|
||||
|
||||
# 即使停止失败,也要移除任务
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
logger.warning("虽然停止过程出现错误,但任务已被标记为结束")
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Task removal completed with warnings"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "Detection stopped successfully"
|
||||
})
|
||||
except NotFound as e:
|
||||
return json_response({"status": "error", "message": str(e)}, status=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True)
|
||||
# 尝试标记任务为停止状态
|
||||
try:
|
||||
if task_id in task_manager.task_status:
|
||||
task_manager.task_status[task_id] = "error_during_stop"
|
||||
except:
|
||||
pass
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/stream/<task_id>")
|
||||
async def get_task_status(request, task_id: str):
|
||||
try:
|
||||
verify_token(request)
|
||||
task_info = task_manager.get_task_info(task_id)
|
||||
|
||||
# 检查任务是否真的在运行
|
||||
if not getIfAI() and task_info["status"] == "running":
|
||||
task_info["status"] = "stopped_unexpectedly"
|
||||
logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止")
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
**task_info
|
||||
})
|
||||
except NotFound as e:
|
||||
return json_response({"status": "error", "message": str(e)}, status=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/stream/tasks")
|
||||
async def list_tasks(request):
|
||||
"""获取所有活动任务列表"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查所有任务的健康状态
|
||||
unhealthy_tasks = task_manager.check_tasks_health()
|
||||
for task_id, status in unhealthy_tasks.items():
|
||||
if task_id in task_manager.task_status:
|
||||
task_manager.task_status[task_id] = status
|
||||
|
||||
tasks = {
|
||||
task_id: task_manager.get_task_info(task_id)
|
||||
for task_id in task_manager.active_tasks.keys()
|
||||
}
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"tasks": tasks
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.post("/ai/stream/stopTasks")
|
||||
async def stop_all_detections(request):
|
||||
"""停止所有活动任务"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
if not task_manager.active_tasks:
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "No active tasks to stop"
|
||||
})
|
||||
|
||||
# 停止所有任务
|
||||
success = await safe_stop_ai_video()
|
||||
|
||||
# 无论成功与否,都移除所有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Tasks stopped with warnings"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "All detections stopped successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True)
|
||||
# 尝试标记所有任务为停止状态
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/health")
|
||||
async def health_check(request):
|
||||
"""服务健康检查端点"""
|
||||
try:
|
||||
# 不需要验证token,这个接口可以用于监控系统检查服务状态
|
||||
unhealthy_tasks = task_manager.check_tasks_health()
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"service": "running" if service_status["is_healthy"] else "degraded",
|
||||
"active_tasks": len(task_manager.active_tasks),
|
||||
"unhealthy_tasks": unhealthy_tasks,
|
||||
"last_error": service_status["last_error"],
|
||||
"error_time": service_status["error_time"],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查失败: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"service": "degraded",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, status=500)
|
||||
|
||||
@app.route("/ai/reset", methods=["POST"])
|
||||
async def reset_service(request):
|
||||
"""重置服务状态,清理所有任务和进程"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 尝试停止AI视频处理
|
||||
await safe_stop_ai_video()
|
||||
|
||||
# 清理所有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
service_status["last_error"] = None
|
||||
service_status["error_time"] = None
|
||||
|
||||
# 尝试清理可能存在的僵尸进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
zombie_count = 0
|
||||
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
if child.status() == psutil.STATUS_ZOMBIE:
|
||||
zombie_count += 1
|
||||
child.send_signal(signal.SIGKILL)
|
||||
except:
|
||||
pass
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": f"Service reset successful. Cleaned {zombie_count} zombie processes."
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"清理僵尸进程时出错: {e}")
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Service reset with warnings"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重置服务时出错: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to reset service: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
@app.route("/ai/stream/restart/<task_id>", methods=["POST"])
|
||||
async def restart_task(request, task_id: str):
|
||||
"""重启指定任务"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 获取任务信息
|
||||
try:
|
||||
task_info = task_manager.get_task_info(task_id)["task_info"]
|
||||
except NotFound:
|
||||
return json_response({"status": "error", "message": "Task not found"}, status=404)
|
||||
|
||||
# 先停止任务
|
||||
success = await safe_stop_ai_video()
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
logger.warning("停止任务出现问题,尝试继续重启")
|
||||
|
||||
# 重新启动任务
|
||||
new_task_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
startAIVideo,
|
||||
task_info["source_url"],
|
||||
task_info["push_url"],
|
||||
task_info["model_path"],
|
||||
task_info["detect_classes"],
|
||||
task_info["confidence"]
|
||||
)
|
||||
|
||||
# 记录新任务信息
|
||||
task_manager.add_task(new_task_id, task_info)
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"old_task_id": task_id,
|
||||
"new_task_id": new_task_id,
|
||||
"message": "Task restarted successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"重启任务失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to restart task: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重启任务时出错: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 保证服务启动前没有残留任务
|
||||
try:
|
||||
stopAIVideo()
|
||||
print("服务启动前清理完成")
|
||||
except:
|
||||
print("服务启动前清理失败,但仍将继续")
|
||||
|
||||
# 安装psutil库,用于进程管理
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
import subprocess
|
||||
import sys
|
||||
print("正在安装psutil库...")
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
||||
|
||||
app.run(host="0.0.0.0", port=12315, debug=False, access_log=True)
|
509
ai2/yolo_api.py
Normal file
509
ai2/yolo_api.py
Normal file
@ -0,0 +1,509 @@
|
||||
from sanic import Sanic, json
|
||||
from sanic.response import json as json_response
|
||||
from sanic.exceptions import Unauthorized, NotFound, SanicException
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from cv_video import startAIVideo,stopAIVideo,getIfAI
|
||||
from sanic_cors import CORS
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 服务状态标志
|
||||
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
|
||||
|
||||
# 配置类
|
||||
class Config:
|
||||
VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
||||
MAX_ACTIVE_TASKS = 10
|
||||
DEFAULT_CONFIDENCE = 0.5
|
||||
RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamRequest:
|
||||
source_url: str
|
||||
push_url: str
|
||||
model_path: str
|
||||
detect_classes: List[str]
|
||||
confidence: float = Config.DEFAULT_CONFIDENCE
|
||||
|
||||
def validate(self) -> None:
|
||||
"""验证请求参数"""
|
||||
if not self.source_url or not self.push_url:
|
||||
raise ValueError("Source URL and Push URL are required")
|
||||
|
||||
if not self.detect_classes:
|
||||
raise ValueError("At least one detection class must be specified")
|
||||
if not 0 < self.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest':
|
||||
try:
|
||||
instance = cls(
|
||||
source_url=data['source_url'],
|
||||
push_url=data['push_url'],
|
||||
model_path=data['model_path'],
|
||||
detect_classes=data['detect_classes'],
|
||||
confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE)
|
||||
)
|
||||
instance.validate()
|
||||
return instance
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {str(e)}")
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self):
|
||||
self.active_tasks: Dict[str, Dict[str, Any]] = {}
|
||||
self.task_status: Dict[str, str] = {}
|
||||
self.task_timestamps: Dict[str, datetime] = {}
|
||||
|
||||
def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None:
|
||||
"""添加新任务"""
|
||||
if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS:
|
||||
raise ValueError("Maximum number of active tasks reached")
|
||||
|
||||
self.active_tasks[task_id] = task_info
|
||||
self.task_status[task_id] = "running"
|
||||
self.task_timestamps[task_id] = datetime.now()
|
||||
logger.info(f"Task {task_id} started")
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
"""移除任务"""
|
||||
if task_id in self.active_tasks:
|
||||
del self.active_tasks[task_id]
|
||||
del self.task_status[task_id]
|
||||
del self.task_timestamps[task_id]
|
||||
logger.info(f"Task {task_id} removed")
|
||||
|
||||
def get_task_info(self, task_id: str) -> Dict[str, Any]:
|
||||
"""获取任务信息"""
|
||||
if task_id not in self.active_tasks:
|
||||
raise NotFound("Task not found")
|
||||
|
||||
return {
|
||||
"task_info": self.active_tasks[task_id],
|
||||
"status": self.task_status[task_id],
|
||||
"start_time": self.task_timestamps[task_id].isoformat()
|
||||
}
|
||||
|
||||
def check_tasks_health(self) -> Dict[str, str]:
|
||||
"""检查任务健康状态"""
|
||||
unhealthy_tasks = {}
|
||||
for task_id in list(self.active_tasks.keys()):
|
||||
# 检查任务是否还在运行(通过getIfAI()函数)
|
||||
if not getIfAI():
|
||||
unhealthy_tasks[task_id] = "stopped"
|
||||
logger.warning(f"Task {task_id} appears to be stopped unexpectedly")
|
||||
return unhealthy_tasks
|
||||
|
||||
def mark_all_tasks_as_stopped(self):
|
||||
"""标记所有任务为已停止状态"""
|
||||
for task_id in list(self.active_tasks.keys()):
|
||||
self.task_status[task_id] = "stopped"
|
||||
logger.warning("已将所有任务标记为停止状态")
|
||||
|
||||
app = Sanic("YoloStreamService")
|
||||
CORS(app)
|
||||
task_manager = TaskManager()
|
||||
|
||||
async def safe_stop_ai_video():
|
||||
"""安全地停止AI视频处理,带有错误处理和恢复机制"""
|
||||
try:
|
||||
await asyncio.to_thread(stopAIVideo)
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 标记服务状态为不健康
|
||||
service_status["is_healthy"] = False
|
||||
service_status["last_error"] = str(e)
|
||||
service_status["error_time"] = datetime.now().isoformat()
|
||||
|
||||
# 强制结束所有任务
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
|
||||
# 尝试通过其他方式杀死可能存在的进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
# 查找并终止ffmpeg子进程
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
logger.info(f"强制终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except Exception as child_e:
|
||||
logger.error(f"终止子进程出错: {str(child_e)}")
|
||||
except Exception as kill_e:
|
||||
logger.error(f"尝试清理进程时出错: {str(kill_e)}")
|
||||
|
||||
# 等待一段时间让系统恢复
|
||||
await asyncio.sleep(Config.RESTART_DELAY)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
return False
|
||||
|
||||
def verify_token(request) -> None:
|
||||
"""验证请求token"""
|
||||
token = request.headers.get('X-API-Token')
|
||||
if not token or token != Config.VALID_TOKEN:
|
||||
logger.warning("Invalid token attempt")
|
||||
raise Unauthorized("Invalid token")
|
||||
|
||||
@app.post("/ai/stream/detect")
|
||||
async def start_detection(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查服务健康状态
|
||||
if not service_status["is_healthy"]:
|
||||
logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
|
||||
# 尝试恢复服务
|
||||
service_status["is_healthy"] = True
|
||||
|
||||
# 先停止所有现有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
logger.info(f"停止现有任务 {task_id} 以启动新任务")
|
||||
try:
|
||||
success = await safe_stop_ai_video()
|
||||
if success:
|
||||
task_manager.remove_task(task_id)
|
||||
else:
|
||||
logger.warning(f"无法正常停止任务 {task_id},但仍将继续")
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
except Exception as e:
|
||||
logger.error(f"停止任务时出错: {e}")
|
||||
# 继续执行,尝试启动新任务
|
||||
|
||||
# 解析并验证请求数据
|
||||
stream_request = StreamRequest.from_dict(request.json)
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
# 修正视频源地址
|
||||
new_source_url = stream_request.source_url.replace("222.212.85.86", "192.168.10.5")
|
||||
new_push_url = stream_request.push_url.replace("222.212.85.86", "192.168.10.5")
|
||||
# 启动YOLO检测
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
startAIVideo,
|
||||
new_source_url,
|
||||
new_push_url,
|
||||
stream_request.model_path,
|
||||
stream_request.detect_classes,
|
||||
stream_request.confidence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"启动AI视频处理失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to start AI video processing: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# 记录任务信息
|
||||
task_manager.add_task(task_id, {
|
||||
"source_url": stream_request.source_url,
|
||||
"push_url": stream_request.push_url,
|
||||
"model_path": stream_request.model_path,
|
||||
"detect_classes": stream_request.detect_classes,
|
||||
"confidence": stream_request.confidence
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
"message": "Detection started successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
return json_response({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.post("/ai/stream/<task_id>")
|
||||
async def stop_detection(request, task_id: str):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查任务是否存在
|
||||
try:
|
||||
task_info = task_manager.get_task_info(task_id)
|
||||
except NotFound:
|
||||
return json_response({"status": "error", "message": "Task not found"}, status=404)
|
||||
|
||||
# 停止AI视频处理,使用安全的停止方法
|
||||
success = await safe_stop_ai_video()
|
||||
|
||||
# 即使停止失败,也要移除任务
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
logger.warning("虽然停止过程出现错误,但任务已被标记为结束")
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Task removal completed with warnings"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "Detection stopped successfully"
|
||||
})
|
||||
except NotFound as e:
|
||||
return json_response({"status": "error", "message": str(e)}, status=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True)
|
||||
# 尝试标记任务为停止状态
|
||||
try:
|
||||
if task_id in task_manager.task_status:
|
||||
task_manager.task_status[task_id] = "error_during_stop"
|
||||
except:
|
||||
pass
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/stream/<task_id>")
|
||||
async def get_task_status(request, task_id: str):
|
||||
try:
|
||||
verify_token(request)
|
||||
task_info = task_manager.get_task_info(task_id)
|
||||
|
||||
# 检查任务是否真的在运行
|
||||
if not getIfAI() and task_info["status"] == "running":
|
||||
task_info["status"] = "stopped_unexpectedly"
|
||||
logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止")
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
**task_info
|
||||
})
|
||||
except NotFound as e:
|
||||
return json_response({"status": "error", "message": str(e)}, status=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/stream/tasks")
|
||||
async def list_tasks(request):
|
||||
"""获取所有活动任务列表"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查所有任务的健康状态
|
||||
unhealthy_tasks = task_manager.check_tasks_health()
|
||||
for task_id, status in unhealthy_tasks.items():
|
||||
if task_id in task_manager.task_status:
|
||||
task_manager.task_status[task_id] = status
|
||||
|
||||
tasks = {
|
||||
task_id: task_manager.get_task_info(task_id)
|
||||
for task_id in task_manager.active_tasks.keys()
|
||||
}
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"tasks": tasks
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.post("/ai/stream/stopTasks")
|
||||
async def stop_all_detections(request):
|
||||
"""停止所有活动任务"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
if not task_manager.active_tasks:
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "No active tasks to stop"
|
||||
})
|
||||
|
||||
# 停止所有任务
|
||||
success = await safe_stop_ai_video()
|
||||
|
||||
# 无论成功与否,都移除所有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Tasks stopped with warnings"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "All detections stopped successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True)
|
||||
# 尝试标记所有任务为停止状态
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/health")
|
||||
async def health_check(request):
|
||||
"""服务健康检查端点"""
|
||||
try:
|
||||
# 不需要验证token,这个接口可以用于监控系统检查服务状态
|
||||
unhealthy_tasks = task_manager.check_tasks_health()
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"service": "running" if service_status["is_healthy"] else "degraded",
|
||||
"active_tasks": len(task_manager.active_tasks),
|
||||
"unhealthy_tasks": unhealthy_tasks,
|
||||
"last_error": service_status["last_error"],
|
||||
"error_time": service_status["error_time"],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查失败: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"service": "degraded",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, status=500)
|
||||
|
||||
@app.route("/ai/reset", methods=["POST"])
|
||||
async def reset_service(request):
|
||||
"""重置服务状态,清理所有任务和进程"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 尝试停止AI视频处理
|
||||
await safe_stop_ai_video()
|
||||
|
||||
# 清理所有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
service_status["last_error"] = None
|
||||
service_status["error_time"] = None
|
||||
|
||||
# 尝试清理可能存在的僵尸进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
zombie_count = 0
|
||||
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
if child.status() == psutil.STATUS_ZOMBIE:
|
||||
zombie_count += 1
|
||||
child.send_signal(signal.SIGKILL)
|
||||
except:
|
||||
pass
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": f"Service reset successful. Cleaned {zombie_count} zombie processes."
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"清理僵尸进程时出错: {e}")
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Service reset with warnings"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重置服务时出错: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to reset service: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
@app.route("/ai/stream/restart/<task_id>", methods=["POST"])
|
||||
async def restart_task(request, task_id: str):
|
||||
"""重启指定任务"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 获取任务信息
|
||||
try:
|
||||
task_info = task_manager.get_task_info(task_id)["task_info"]
|
||||
except NotFound:
|
||||
return json_response({"status": "error", "message": "Task not found"}, status=404)
|
||||
|
||||
# 先停止任务
|
||||
success = await safe_stop_ai_video()
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
logger.warning("停止任务出现问题,尝试继续重启")
|
||||
|
||||
# 重新启动任务
|
||||
new_task_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
startAIVideo,
|
||||
task_info["source_url"],
|
||||
task_info["push_url"],
|
||||
task_info["model_path"],
|
||||
task_info["detect_classes"],
|
||||
task_info["confidence"]
|
||||
)
|
||||
|
||||
# 记录新任务信息
|
||||
task_manager.add_task(new_task_id, task_info)
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"old_task_id": task_id,
|
||||
"new_task_id": new_task_id,
|
||||
"message": "Task restarted successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"重启任务失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to restart task: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重启任务时出错: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 保证服务启动前没有残留任务
|
||||
try:
|
||||
stopAIVideo()
|
||||
print("服务启动前清理完成")
|
||||
except:
|
||||
print("服务启动前清理失败,但仍将继续")
|
||||
|
||||
# 安装psutil库,用于进程管理
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
import subprocess
|
||||
import sys
|
||||
print("正在安装psutil库...")
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
||||
|
||||
app.run(host="0.0.0.0", port=12315, debug=False, access_log=True)
|
BIN
ai2/yolo_api.zip
Normal file
BIN
ai2/yolo_api.zip
Normal file
Binary file not shown.
540
ai2/yolo_api_HANGZOUAPI.py
Normal file
540
ai2/yolo_api_HANGZOUAPI.py
Normal file
@ -0,0 +1,540 @@
|
||||
from sanic import Sanic, json
|
||||
from sanic.response import json as json_response
|
||||
from sanic.exceptions import Unauthorized, NotFound, SanicException
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from cv_video import startAIVideo,stopAIVideo,getIfAI
|
||||
from sanic_cors import CORS
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 服务状态标志
|
||||
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
|
||||
|
||||
# 配置类
|
||||
class Config:
|
||||
VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
||||
MAX_ACTIVE_TASKS = 10
|
||||
DEFAULT_CONFIDENCE = 0.5
|
||||
RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamRequest:
|
||||
source_url: str
|
||||
push_url: str
|
||||
model_path: str
|
||||
detect_classes: List[str]
|
||||
confidence: float = Config.DEFAULT_CONFIDENCE
|
||||
|
||||
def validate(self) -> None:
|
||||
"""验证请求参数"""
|
||||
if not self.source_url or not self.push_url:
|
||||
raise ValueError("Source URL and Push URL are required")
|
||||
|
||||
if not self.detect_classes:
|
||||
raise ValueError("At least one detection class must be specified")
|
||||
if not 0 < self.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest':
|
||||
try:
|
||||
instance = cls(
|
||||
source_url=data['source_url'],
|
||||
push_url=data['push_url'],
|
||||
model_path=data['model_path'],
|
||||
detect_classes=data['detect_classes'],
|
||||
confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE)
|
||||
)
|
||||
instance.validate()
|
||||
return instance
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {str(e)}")
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self):
|
||||
self.active_tasks: Dict[str, Dict[str, Any]] = {}
|
||||
self.task_status: Dict[str, str] = {}
|
||||
self.task_timestamps: Dict[str, datetime] = {}
|
||||
|
||||
def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None:
|
||||
"""添加新任务"""
|
||||
if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS:
|
||||
raise ValueError("Maximum number of active tasks reached")
|
||||
|
||||
self.active_tasks[task_id] = task_info
|
||||
self.task_status[task_id] = "running"
|
||||
self.task_timestamps[task_id] = datetime.now()
|
||||
logger.info(f"Task {task_id} started")
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
"""移除任务"""
|
||||
if task_id in self.active_tasks:
|
||||
del self.active_tasks[task_id]
|
||||
del self.task_status[task_id]
|
||||
del self.task_timestamps[task_id]
|
||||
logger.info(f"Task {task_id} removed")
|
||||
|
||||
def get_task_info(self, task_id: str) -> Dict[str, Any]:
|
||||
"""获取任务信息"""
|
||||
if task_id not in self.active_tasks:
|
||||
raise NotFound("Task not found")
|
||||
|
||||
return {
|
||||
"task_info": self.active_tasks[task_id],
|
||||
"status": self.task_status[task_id],
|
||||
"start_time": self.task_timestamps[task_id].isoformat()
|
||||
}
|
||||
|
||||
def check_tasks_health(self) -> Dict[str, str]:
|
||||
"""检查任务健康状态"""
|
||||
unhealthy_tasks = {}
|
||||
for task_id in list(self.active_tasks.keys()):
|
||||
# 检查任务是否还在运行(通过getIfAI()函数)
|
||||
if not getIfAI():
|
||||
unhealthy_tasks[task_id] = "stopped"
|
||||
logger.warning(f"Task {task_id} appears to be stopped unexpectedly")
|
||||
return unhealthy_tasks
|
||||
|
||||
def mark_all_tasks_as_stopped(self):
|
||||
"""标记所有任务为已停止状态"""
|
||||
for task_id in list(self.active_tasks.keys()):
|
||||
self.task_status[task_id] = "stopped"
|
||||
logger.warning("已将所有任务标记为停止状态")
|
||||
|
||||
app = Sanic("YoloStreamService")
|
||||
CORS(app)
|
||||
task_manager = TaskManager()
|
||||
|
||||
async def safe_stop_ai_video():
|
||||
"""安全地停止AI视频处理,带有错误处理和恢复机制"""
|
||||
try:
|
||||
await asyncio.to_thread(stopAIVideo)
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 标记服务状态为不健康
|
||||
service_status["is_healthy"] = False
|
||||
service_status["last_error"] = str(e)
|
||||
service_status["error_time"] = datetime.now().isoformat()
|
||||
|
||||
# 强制结束所有任务
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
|
||||
# 尝试通过其他方式杀死可能存在的进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
# 查找并终止ffmpeg子进程
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
logger.info(f"强制终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except Exception as child_e:
|
||||
logger.error(f"终止子进程出错: {str(child_e)}")
|
||||
except Exception as kill_e:
|
||||
logger.error(f"尝试清理进程时出错: {str(kill_e)}")
|
||||
|
||||
# 等待一段时间让系统恢复
|
||||
await asyncio.sleep(Config.RESTART_DELAY)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
return False
|
||||
|
||||
def verify_token(request) -> None:
|
||||
"""验证请求token"""
|
||||
token = request.headers.get('X-API-Token')
|
||||
if not token or token != Config.VALID_TOKEN:
|
||||
logger.warning("Invalid token attempt")
|
||||
raise Unauthorized("Invalid token")
|
||||
|
||||
@app.post("/ai/stream/detect")
|
||||
async def start_detection(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查服务健康状态
|
||||
if not service_status["is_healthy"]:
|
||||
logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
|
||||
service_status["is_healthy"] = True # 尝试恢复
|
||||
|
||||
# 先停止所有现有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
logger.info(f"停止现有任务 {task_id} 以启动新任务")
|
||||
try:
|
||||
success = await safe_stop_ai_video()
|
||||
if success:
|
||||
task_manager.remove_task(task_id)
|
||||
else:
|
||||
logger.warning(f"无法正常停止任务 {task_id},但仍将继续")
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
except Exception as e:
|
||||
logger.error(f"停止任务时出错: {e}")
|
||||
|
||||
# 解析并验证请求数据
|
||||
stream_request = StreamRequest.from_dict(request.json)
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 如果是“yanwu.pt”模型,调用外部接口获取 liveUrl
|
||||
if stream_request.model_path == "yanwu.pt":
|
||||
try:
|
||||
import aiohttp
|
||||
import tempfile
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 创建一个临时空白文件
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
|
||||
tmpfile.write(b'')
|
||||
|
||||
with open(tmpfile.name, 'rb') as f:
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field("file", f, filename="empty.txt", content_type="text/plain")
|
||||
form_data.add_field("model", "yanwu")
|
||||
form_data.add_field("taskType", "2")
|
||||
form_data.add_field("source", stream_request.source_url)
|
||||
form_data.add_field("notifyUrl", f"{stream_request.source_url}+ai")
|
||||
|
||||
url = "https://flightcontrol.huaiying-xunjian.com/prod-api/third/api/v1/task/startTask"
|
||||
async with session.post(url, data=form_data) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"外部服务状态码异常: {resp.status}")
|
||||
result = await resp.json()
|
||||
if result.get("code") != 0 or not result.get("data", {}).get("liveUrl"):
|
||||
raise Exception(f"接口响应错误: {result}")
|
||||
live_url = result["data"]["liveUrl"]
|
||||
logger.info(f"外部接口返回推流地址: {live_url}")
|
||||
stream_request.push_url = live_url # 替换推流地址
|
||||
except Exception as ext:
|
||||
logger.error(f"调用外部直播任务接口失败: {ext}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"调用直播任务接口失败: {str(ext)}"
|
||||
}, status=500)
|
||||
else:
|
||||
# 启动YOLO检测
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
startAIVideo,
|
||||
stream_request.source_url,
|
||||
stream_request.push_url,
|
||||
stream_request.model_path,
|
||||
stream_request.detect_classes,
|
||||
stream_request.confidence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"启动AI视频处理失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to start AI video processing: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# 记录任务信息
|
||||
task_manager.add_task(task_id, {
|
||||
"source_url": stream_request.source_url,
|
||||
"push_url": stream_request.push_url,
|
||||
"model_path": stream_request.model_path,
|
||||
"detect_classes": stream_request.detect_classes,
|
||||
"confidence": stream_request.confidence
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
"message": "Detection started successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
return json_response({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
|
||||
@app.post("/ai/stream/<task_id>")
|
||||
async def stop_detection(request, task_id: str):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查任务是否存在
|
||||
try:
|
||||
task_info = task_manager.get_task_info(task_id)
|
||||
except NotFound:
|
||||
return json_response({"status": "error", "message": "Task not found"}, status=404)
|
||||
|
||||
# 停止AI视频处理,使用安全的停止方法
|
||||
success = await safe_stop_ai_video()
|
||||
|
||||
# 即使停止失败,也要移除任务
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
logger.warning("虽然停止过程出现错误,但任务已被标记为结束")
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Task removal completed with warnings"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "Detection stopped successfully"
|
||||
})
|
||||
except NotFound as e:
|
||||
return json_response({"status": "error", "message": str(e)}, status=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True)
|
||||
# 尝试标记任务为停止状态
|
||||
try:
|
||||
if task_id in task_manager.task_status:
|
||||
task_manager.task_status[task_id] = "error_during_stop"
|
||||
except:
|
||||
pass
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/stream/<task_id>")
|
||||
async def get_task_status(request, task_id: str):
|
||||
try:
|
||||
verify_token(request)
|
||||
task_info = task_manager.get_task_info(task_id)
|
||||
|
||||
# 检查任务是否真的在运行
|
||||
if not getIfAI() and task_info["status"] == "running":
|
||||
task_info["status"] = "stopped_unexpectedly"
|
||||
logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止")
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
**task_info
|
||||
})
|
||||
except NotFound as e:
|
||||
return json_response({"status": "error", "message": str(e)}, status=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/stream/tasks")
|
||||
async def list_tasks(request):
|
||||
"""获取所有活动任务列表"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查所有任务的健康状态
|
||||
unhealthy_tasks = task_manager.check_tasks_health()
|
||||
for task_id, status in unhealthy_tasks.items():
|
||||
if task_id in task_manager.task_status:
|
||||
task_manager.task_status[task_id] = status
|
||||
|
||||
tasks = {
|
||||
task_id: task_manager.get_task_info(task_id)
|
||||
for task_id in task_manager.active_tasks.keys()
|
||||
}
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"tasks": tasks
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.post("/ai/stream/stopTasks")
|
||||
async def stop_all_detections(request):
|
||||
"""停止所有活动任务"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
if not task_manager.active_tasks:
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "No active tasks to stop"
|
||||
})
|
||||
|
||||
# 停止所有任务
|
||||
success = await safe_stop_ai_video()
|
||||
|
||||
# 无论成功与否,都移除所有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Tasks stopped with warnings"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": "All detections stopped successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True)
|
||||
# 尝试标记所有任务为停止状态
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
@app.get("/ai/health")
|
||||
async def health_check(request):
|
||||
"""服务健康检查端点"""
|
||||
try:
|
||||
# 不需要验证token,这个接口可以用于监控系统检查服务状态
|
||||
unhealthy_tasks = task_manager.check_tasks_health()
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"service": "running" if service_status["is_healthy"] else "degraded",
|
||||
"active_tasks": len(task_manager.active_tasks),
|
||||
"unhealthy_tasks": unhealthy_tasks,
|
||||
"last_error": service_status["last_error"],
|
||||
"error_time": service_status["error_time"],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查失败: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"service": "degraded",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, status=500)
|
||||
|
||||
@app.route("/ai/reset", methods=["POST"])
|
||||
async def reset_service(request):
|
||||
"""重置服务状态,清理所有任务和进程"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 尝试停止AI视频处理
|
||||
await safe_stop_ai_video()
|
||||
|
||||
# 清理所有任务
|
||||
for task_id in list(task_manager.active_tasks.keys()):
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
service_status["last_error"] = None
|
||||
service_status["error_time"] = None
|
||||
|
||||
# 尝试清理可能存在的僵尸进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
zombie_count = 0
|
||||
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
if child.status() == psutil.STATUS_ZOMBIE:
|
||||
zombie_count += 1
|
||||
child.send_signal(signal.SIGKILL)
|
||||
except:
|
||||
pass
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"message": f"Service reset successful. Cleaned {zombie_count} zombie processes."
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"清理僵尸进程时出错: {e}")
|
||||
return json_response({
|
||||
"status": "warning",
|
||||
"message": "Service reset with warnings"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重置服务时出错: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to reset service: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
@app.route("/ai/stream/restart/<task_id>", methods=["POST"])
|
||||
async def restart_task(request, task_id: str):
|
||||
"""重启指定任务"""
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 获取任务信息
|
||||
try:
|
||||
task_info = task_manager.get_task_info(task_id)["task_info"]
|
||||
except NotFound:
|
||||
return json_response({"status": "error", "message": "Task not found"}, status=404)
|
||||
|
||||
# 先停止任务
|
||||
success = await safe_stop_ai_video()
|
||||
task_manager.remove_task(task_id)
|
||||
|
||||
if not success:
|
||||
logger.warning("停止任务出现问题,尝试继续重启")
|
||||
|
||||
# 重新启动任务
|
||||
new_task_id = str(uuid.uuid4())
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
startAIVideo,
|
||||
task_info["source_url"],
|
||||
task_info["push_url"],
|
||||
task_info["model_path"],
|
||||
task_info["detect_classes"],
|
||||
task_info["confidence"]
|
||||
)
|
||||
|
||||
# 记录新任务信息
|
||||
task_manager.add_task(new_task_id, task_info)
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"old_task_id": task_id,
|
||||
"new_task_id": new_task_id,
|
||||
"message": "Task restarted successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"重启任务失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to restart task: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重启任务时出错: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 保证服务启动前没有残留任务
|
||||
try:
|
||||
stopAIVideo()
|
||||
print("服务启动前清理完成")
|
||||
except:
|
||||
print("服务启动前清理失败,但仍将继续")
|
||||
|
||||
# 安装psutil库,用于进程管理
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
import subprocess
|
||||
import sys
|
||||
print("正在安装psutil库...")
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
||||
|
||||
app.run(host="0.0.0.0", port=12315, debug=False, access_log=True)
|
355
even/A_lot_of.py
Normal file
355
even/A_lot_of.py
Normal file
@ -0,0 +1,355 @@
|
||||
import threading
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import threading
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
import datetime
|
||||
#@yoooger
|
||||
#-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
# 定义颜色映射
|
||||
colors = {
|
||||
"0": (0, 255, 0), # 绿色
|
||||
"1": (0, 0, 255), # 蓝色
|
||||
"2": (255, 0, 0), # 红色
|
||||
"3": (255, 255, 0), # 黄色
|
||||
"4": (0, 255, 255), # 青色
|
||||
"5": (255, 0, 255), # 紫色
|
||||
"6": (128, 0, 0), # 紫色
|
||||
"7": (0, 128, 0), # 绿色
|
||||
"8": (0, 0, 128), # 蓝色
|
||||
"9": (128, 128, 0), # 黄色
|
||||
"10": (128, 0, 128), # 紫色
|
||||
"11": (0, 128, 128), # 青色
|
||||
}
|
||||
#-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
def draw_and_save(frame, boxes, save_folder, frame_number, model_id):
|
||||
"""绘制检测框并保存当前帧"""
|
||||
|
||||
for box in boxes:
|
||||
class_id, (x1, y1, x2, y2) = box
|
||||
color = colors[class_id] # 根据类别选择颜色
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# 保存当前帧到指定文件夹
|
||||
cv2.imwrite(os.path.join(save_folder, f"model_{model_id}_frame_{frame_number}.jpg"), frame)
|
||||
#-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
def process_images(image_folder_path, model, save_folder, model_id, classes_str, event):
|
||||
"""处理文件夹中的图片,检测物体并触发自定义事件"""
|
||||
os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在)
|
||||
classes = list(map(int, classes_str.split(','))) # 将逗号分隔的字符串转换为整数列表
|
||||
|
||||
# 回调函数用于绘制并保存帧
|
||||
def callback(frame, boxes, save_folder, frame_number):
|
||||
draw_and_save(frame, boxes, save_folder, frame_number, model_id)
|
||||
|
||||
event.subscribe(callback) # 订阅绘制和保存作业
|
||||
|
||||
# 遍历文件夹中的每个图片文件
|
||||
for frame_number, image_name in enumerate(os.listdir(image_folder_path)):
|
||||
|
||||
|
||||
image_path = os.path.join(image_folder_path, image_name)
|
||||
frame = cv2.imread(image_path) # 读取图片
|
||||
if frame is None:
|
||||
continue # 跳过无法读取的图片
|
||||
|
||||
# 使用 YOLO 进行推理
|
||||
results = model(frame) # 输入当前帧
|
||||
boxes = [] # 定义一个空列表,用于存放检测框
|
||||
|
||||
# 处理 YOLO 的输出
|
||||
for result in results:
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls) # 获取类别ID
|
||||
if class_id in classes: # 检查类别是否在指定类别数组中
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 获取框坐标
|
||||
boxes.append((class_id, (x1, y1, x2, y2)))
|
||||
|
||||
#-------------------------------------------------------------------------------------------------------------------------------------
|
||||
def process_video(video_path, model, save_folder, model_id, classes_str, event):
|
||||
"""处理视频,检测物体并触发自定义事件"""
|
||||
os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在)
|
||||
classes = list(map(int, classes_str.split(',')))# 将逗号分隔的字符串转换为整数列表
|
||||
cap = cv2.VideoCapture(video_path)# 打开视频文件,使用cv2.VideoCapture()函数,截取视频的帧
|
||||
skip_frames = 0 # 初始化跳帧数
|
||||
|
||||
# 修改callback函数以接受3个参数
|
||||
def callback(frame, boxes, save_folder):
|
||||
draw_and_save(frame, boxes, save_folder, int(cap.get(cv2.CAP_PROP_POS_FRAMES)), model_id)
|
||||
|
||||
event.subscribe(callback) # 订阅绘制和保存作业
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if skip_frames > 0: # 跳过指定帧数,设置skip_frames为0可取消跳帧
|
||||
skip_frames -= 1 # 倒数计时器
|
||||
continue
|
||||
|
||||
# 使用 YOLO 进行推理
|
||||
results = model(frame) # 输入当前帧
|
||||
boxes = [] # 定义一个空列表,用于存放检测框
|
||||
|
||||
# 处理 YOLO 的输出
|
||||
for result in results:
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls) # 获取类别ID
|
||||
if class_id in classes: # 检查类别是否在指定类别数组中
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 获取框坐标
|
||||
boxes.append((class_id, (x1, y1, x2, y2)))
|
||||
|
||||
#如果检测到物体,则跳过30帧
|
||||
if boxes:
|
||||
event.trigger(frame, boxes, save_folder)
|
||||
skip_frames = 100 # 检测到物体后跳过30帧
|
||||
|
||||
cap.release()
|
||||
|
||||
def process_video_threshold(video_path, model, save_folder, model_id, classes_str, event, threshold):
|
||||
"""
|
||||
处理视频,并在指定区域绘制检测框,并在指定阈值触发事件
|
||||
参数说明:
|
||||
video_path: 视频路径(应当改为视频流??)
|
||||
model: YOLO 模型
|
||||
save_folder: 保存帧的文件夹
|
||||
model_id: 模型ID
|
||||
classes_str: 要检测的类别(字符串形式)
|
||||
event: 自定义事件(CustomEvent0、CustomEvent1 等)
|
||||
threshold: 触发事件的阈值
|
||||
"""
|
||||
os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在)
|
||||
classes = list(map(int, classes_str.split(','))) # 将逗号分隔的字符串转换为整数列表
|
||||
cap = cv2.VideoCapture(video_path) # 打开视频文件
|
||||
skip_frames = 0 # 初始化跳帧数
|
||||
|
||||
# 订阅绘制和保存作业
|
||||
event.subscribe(lambda frame, boxes, save_folder: draw_and_save(frame, boxes, save_folder, int(cap.get(cv2.CAP_PROP_POS_FRAMES)), model_id))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if skip_frames > 0: # 跳过指定帧数
|
||||
skip_frames -= 1
|
||||
continue
|
||||
|
||||
# 使用 YOLO 进行推理
|
||||
results = model(frame,
|
||||
conf=0.4, # 置信度阈值
|
||||
iou=0.5, # 交并比阈值
|
||||
) # 输入当前帧
|
||||
boxes = [] # 定义一个空列表,用于存放检测框
|
||||
|
||||
# 处理 YOLO 的输出
|
||||
for result in results:
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls) # 获取类别ID
|
||||
if class_id in classes: # 检查类别是否在指定类别数组中
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 获取框坐标
|
||||
boxes.append((class_id, (x1, y1, x2, y2)))
|
||||
|
||||
# 计算区域内目标数量
|
||||
object_count = len(boxes)
|
||||
|
||||
# 如果目标数量超过阈值,则触发事件
|
||||
if object_count > threshold:
|
||||
event.trigger(frame, boxes, save_folder)
|
||||
skip_frames = 100 # 跳过100帧
|
||||
|
||||
cap.release()
|
||||
#------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
# 自定义事件类(用于订阅和触发作业,检测到物体时触发作业)
|
||||
class CustomEvent0:
|
||||
def __init__(self):
|
||||
self._listeners = []
|
||||
|
||||
def subscribe(self, listener):
|
||||
"""订阅一个作业,当事件触发时会执行这个作业"""
|
||||
self._listeners.append(listener)
|
||||
|
||||
def trigger(self, frame, boxes, save_folder):
|
||||
"""触发事件,执行所有订阅的作业"""
|
||||
for listener in self._listeners:
|
||||
listener(frame, boxes, save_folder)
|
||||
|
||||
#自定义事件类(用于订阅和触发作业,检测到物体大于一定数量时触发作业)
|
||||
class CustomEvent1:
|
||||
def __init__(self, threshold=5):
|
||||
"""初始化事件类,并设置触发阈值"""
|
||||
self._listeners = []
|
||||
self.threshold = threshold # 设置触发作业的物体数量阈值
|
||||
|
||||
def subscribe(self, listener):
|
||||
"""订阅一个作业,当事件触发时会执行这个作业"""
|
||||
self._listeners.append(listener)
|
||||
|
||||
def trigger(self, frame, boxes, save_folder):
|
||||
"""当检测物体数量超过阈值时,触发事件"""
|
||||
if len(boxes) >= self.threshold: # 检查是否超过阈值
|
||||
for listener in self._listeners:
|
||||
listener(frame, boxes, save_folder)
|
||||
|
||||
def in_area_trigger(self, frame, boxes, save_folder, class_id=None, x_min=100, y_min=100, x_max=500, y_max=500):
|
||||
"""检测屏幕范围内特定类别的物体数量是否超过阈值"""
|
||||
count_in_area = 0
|
||||
for box in boxes:
|
||||
obj_class_id, (x1, y1, x2, y2) = box
|
||||
# 检查类别(如果指定)和坐标范围
|
||||
if (class_id is None or obj_class_id == class_id) and x1 >= x_min and y1 >= y_min and x2 <= x_max and y2 <= y_max:
|
||||
count_in_area += 1
|
||||
|
||||
# 如果数量超过阈值,触发事件
|
||||
if count_in_area >= self.threshold:
|
||||
for listener in self._listeners:
|
||||
listener(frame, boxes, save_folder)
|
||||
|
||||
#自定义事件类(用于订阅和触发作业,检测到物体小于一定数量时触发作业)
|
||||
class CustomEvent2:
|
||||
def __init__(self, threshold=5):
|
||||
"""初始化事件类,并设置触发阈值"""
|
||||
self._listeners = []
|
||||
self.threshold = threshold # 设置触发作业的物体数量阈值
|
||||
|
||||
def subscribe(self, listener):
|
||||
"""订阅一个作业,当事件触发时会执行这个作业"""
|
||||
self._listeners.append(listener)
|
||||
|
||||
def trigger(self, frame, boxes, save_folder):
|
||||
"""当检测物体数量超过阈值时,触发事件"""
|
||||
if len(boxes) >= self.threshold: # 检查是否超过阈值
|
||||
for listener in self._listeners:
|
||||
listener(frame, boxes, save_folder)
|
||||
|
||||
#-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# 获取当前日期和时间精确到分钟作为run_name
|
||||
current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
|
||||
# 定义处理任务的函数
|
||||
def process_task(video_path, model_path, save_folder, classes_str, event_type, threshold=None, model_id=None):
|
||||
"""
|
||||
统一处理任务的函数
|
||||
"""
|
||||
model = YOLO(model_path) # 加载 YOLO 模型
|
||||
os.makedirs(save_folder, exist_ok=True) # 创建保存文件夹(如果不存在)
|
||||
|
||||
if threshold is not None:
|
||||
# 使用阈值处理视频
|
||||
threading.Thread(
|
||||
target=process_video_threshold,
|
||||
args=(video_path, model, save_folder, model_id, classes_str, event_type, threshold)
|
||||
).start()
|
||||
else:
|
||||
# 不使用阈值处理视频
|
||||
threading.Thread(
|
||||
target=process_video,
|
||||
args=(video_path, model, save_folder, model_id, classes_str, event_type)
|
||||
).start()
|
||||
|
||||
def Start_program(video_path, Start_String):
|
||||
"""
|
||||
启动程序,根据 Start_String 启动特定的模型
|
||||
:param video_path: 视频路径
|
||||
:param Start_String: 启动模型的字符串
|
||||
"""
|
||||
task_configurations = {
|
||||
|
||||
'''
|
||||
"模型名称": {
|
||||
"model_path": "模型路径",
|
||||
"save_folder": "保存文件夹路径",
|
||||
"classes": "检测物体的类别",
|
||||
"event": 自定义事件类实例, Customevent0:if have aim /1: aim number more than one number /2: aim number less than one number
|
||||
"threshold": 触发阈值(可选),
|
||||
"model_id": 模型ID(可选)
|
||||
}
|
||||
'''
|
||||
|
||||
# 出现了人
|
||||
"have_peoples": {
|
||||
"model_path": "yolo11n.pt",
|
||||
"save_folder": f"output/{current_date}_People_in_the_area",
|
||||
"classes": "0",
|
||||
"event": CustomEvent0(),
|
||||
"threshold": None,
|
||||
"model_id": 1,
|
||||
},
|
||||
#发现了人员聚集
|
||||
"many_peoples": {
|
||||
"model_path": "yolo11n.pt",
|
||||
"save_folder": f"output/{current_date}_Many_people",
|
||||
"classes": "0",
|
||||
"event": CustomEvent1(threshold=5),
|
||||
"threshold": 5,
|
||||
"model_id": 2,
|
||||
},
|
||||
#发现了无安全帽人员
|
||||
"no_helmet": {
|
||||
"model_path": "gdaq.pt",
|
||||
"save_folder": f"output/{current_date}_Have_no_helmet",
|
||||
"classes": "1",
|
||||
"event": CustomEvent0(),
|
||||
"threshold": None,
|
||||
"model_id": 3,
|
||||
},
|
||||
#无安全绳
|
||||
"no_safety_line": {
|
||||
"model_path": "gdaq.pt",
|
||||
"save_folder": f"output/{current_date}_Have_no_safety_line",
|
||||
"classes": "3",
|
||||
"event": CustomEvent0(),
|
||||
"threshold": None,
|
||||
"model_id": 4,
|
||||
},
|
||||
"smoke": {
|
||||
"model_path": "smoke.pt",
|
||||
"save_folder": f"output/{current_date}_Have_smoke",
|
||||
"classes": "1",
|
||||
"event": CustomEvent0(),
|
||||
"threshold": None,
|
||||
"model_id": 5,
|
||||
},
|
||||
#发现了烟雾
|
||||
"fire": {
|
||||
"model_path": "smoke.pt",
|
||||
"save_folder": f"output/{current_date}_Have_fire",
|
||||
"classes": "0",
|
||||
"event": CustomEvent0(),
|
||||
"threshold": None,
|
||||
"model_id": 6,
|
||||
},
|
||||
#护栏破损
|
||||
"HULAN_POSUN": {
|
||||
"model_path": "gdaq.pt",
|
||||
"save_folder": f"output/{current_date}_HULAN_POSUN",
|
||||
"classes": "6",
|
||||
"event": CustomEvent0(),
|
||||
"threshold": None,
|
||||
"model_id": 7,
|
||||
},
|
||||
}
|
||||
# 将 Start_String 拆分为列表
|
||||
start_flags = Start_String.split(",")
|
||||
|
||||
for flag in start_flags:
|
||||
if flag in task_configurations:
|
||||
config = task_configurations[flag]
|
||||
process_task(
|
||||
video_path,
|
||||
config["model_path"],
|
||||
config["save_folder"],
|
||||
config["classes"],
|
||||
config["event"],
|
||||
config["threshold"],
|
||||
config["model_id"], #we difine model_id to identify the model
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 使用线程处理视频
|
||||
video_path = r"D:\work\DJI_20241122093213_0043_V.mp4" # 替换为你的视频路径
|
||||
String = "have_peoples,many_peoples,no_helmet,no_safety_line" # 设置要启动的模型
|
||||
Start_program(video_path, String) # 启动程序
|
BIN
even/__pycache__/A_lot_of.cpython-312.pyc
Normal file
BIN
even/__pycache__/A_lot_of.cpython-312.pyc
Normal file
Binary file not shown.
BIN
even/__pycache__/colour_map.cpython-312.pyc
Normal file
BIN
even/__pycache__/colour_map.cpython-312.pyc
Normal file
Binary file not shown.
BIN
even/__pycache__/even_program_video.cpython-312.pyc
Normal file
BIN
even/__pycache__/even_program_video.cpython-312.pyc
Normal file
Binary file not shown.
BIN
even/__pycache__/even_program_video_back.cpython-312.pyc
Normal file
BIN
even/__pycache__/even_program_video_back.cpython-312.pyc
Normal file
Binary file not shown.
BIN
even/__pycache__/even_rule.cpython-312.pyc
Normal file
BIN
even/__pycache__/even_rule.cpython-312.pyc
Normal file
Binary file not shown.
BIN
even/gdaq.pt
Normal file
BIN
even/gdaq.pt
Normal file
Binary file not shown.
BIN
even/smoke.pt
Normal file
BIN
even/smoke.pt
Normal file
Binary file not shown.
BIN
even/yolo11n.pt
Normal file
BIN
even/yolo11n.pt
Normal file
Binary file not shown.
167
even/zhuanhuan1.py
Normal file
167
even/zhuanhuan1.py
Normal file
@ -0,0 +1,167 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
from pyproj import Transformer
|
||||
import math
|
||||
from PIL import Image
|
||||
from pylab import *
|
||||
from dataclasses import dataclass
|
||||
|
||||
def rotate_points_around_axis(points, axis, angle):
|
||||
# 将轴转换为单位向量
|
||||
axis =axis.squeeze()
|
||||
axis = axis / np.linalg.norm(axis)
|
||||
|
||||
# 将角度转换为弧度
|
||||
theta = np.radians(angle)
|
||||
|
||||
# 计算旋转矩阵
|
||||
K = np.array([[0, -axis[2], axis[1]],
|
||||
[axis[2], 0, -axis[0]],
|
||||
[-axis[1], axis[0], 0]])
|
||||
|
||||
I = np.eye(3) # 单位矩阵
|
||||
R = I + np.sin(theta) * K + (1 - np.cos(theta)) * np.dot(K, K)
|
||||
|
||||
# 将点平移到旋转轴的起点
|
||||
|
||||
# 旋转多个点
|
||||
rotated_points = np.dot(R.T,points.T) # 矩阵乘法
|
||||
# 再平移回去
|
||||
|
||||
return rotated_points
|
||||
|
||||
def rotate_points(points, center, theta):
|
||||
"""
|
||||
批量旋转二维点
|
||||
:param points: 输入的二维点矩阵,形状为 (N, 2)
|
||||
:param center: 旋转中心 (x_c, y_c)
|
||||
:param theta: 旋转角度(弧度)
|
||||
:return: 旋转后的点矩阵,形状为 (N, 2)
|
||||
"""
|
||||
# 平移到旋转中心
|
||||
translated = points - center # 平移,使得中心在原点
|
||||
rotation_matrix = np.array([[np.cos(theta), np.sin(theta)],
|
||||
[-np.sin(theta), np.cos(theta)]])
|
||||
|
||||
# 应用旋转
|
||||
rotated = translated @ rotation_matrix.T # 矩阵乘法
|
||||
return rotated + center # 将旋转后的点移回原位置
|
||||
def compute_rotation_matrix(pitch, roll1, yaw):
|
||||
c2w_x = np.array([[1, 0, 0],
|
||||
[0, math.cos(pitch), -math.sin(pitch)],
|
||||
[0, math.sin(pitch), math.cos(pitch)]], dtype=np.float32)
|
||||
c2w_y = np.array([[math.cos(roll1), 0, math.sin(roll1)],
|
||||
[0, 1, 0],
|
||||
[-math.sin(roll1), 0, math.cos(roll1)]], dtype=np.float32)
|
||||
c2w_z = np.array([[math.cos(yaw), -math.sin(yaw), 0],
|
||||
[math.sin(yaw), math.cos(yaw), 0],
|
||||
[0, 0, 1]], dtype=np.float32)
|
||||
c2w = c2w_x @ c2w_y @ c2w_z
|
||||
# 按 Z-Y-X 顺序组合旋转矩阵
|
||||
y=np.array([[0],[0],[0]])
|
||||
c2w1 = np.hstack([c2w,y])
|
||||
new_row = np.array([[0, 0, 0, 1]])
|
||||
c2w1 = np.vstack([c2w1, new_row])
|
||||
return c2w, c2w1
|
||||
def create_transformation_matrix(tvec):
|
||||
transform_matrix = np.eye(4) # 创建单位矩阵
|
||||
transform_matrix[:3, 3] = tvec.flatten() # 平移向量
|
||||
return transform_matrix
|
||||
|
||||
def dms2dd(x,y,z):
|
||||
return x+y/60+z/3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProjectTrans:
|
||||
'''
|
||||
计算坐标
|
||||
'''
|
||||
def __init__(self, f_mm, sensor_width_mm, sensor_height_mm, image_width_px, image_height_px, k1, k2, k3, k4, k5):
|
||||
self.image_points = None
|
||||
self.tvec = None
|
||||
self.rvec = None
|
||||
self.object_point = None
|
||||
self.f_mm = f_mm # 物理焦距 (mm)
|
||||
self.sensor_width_mm= sensor_width_mm # 传感器宽度 (mm)
|
||||
self.sensor_height_mm=sensor_height_mm # 传感器高度 (mm)
|
||||
self.image_width_px= image_width_px # 图像宽度 (px)
|
||||
self.image_height_px= image_height_px # 图像高度 (px)
|
||||
self.Ji = np.array([[k1], [k2], [k3], [k4],[k5]], dtype=np.float32)
|
||||
self.camera_matrix = None
|
||||
|
||||
def apply_transformation(self, c2w, transform_matrix):
|
||||
# 齐次化
|
||||
homogeneous_points = np.hstack([self.object_point, np.ones((self.object_point.shape[0], 1))])
|
||||
# 旋转平移
|
||||
transformed_points = transform_matrix @ homogeneous_points.T
|
||||
transformed_points = c2w.T @ transformed_points
|
||||
return transformed_points[:3, :] # 返回前3列(x, y, z坐标)
|
||||
|
||||
def compute_camera_matrix(self): #返回内参矩阵
|
||||
"""计算并返回相机内参矩阵."""
|
||||
fx = floor((self.f_mm * self.image_width_px) / self.sensor_width_mm)
|
||||
fy = floor((self.f_mm * self.image_height_px) / self.sensor_height_mm)
|
||||
cx = self.image_width_px / 2
|
||||
cy = self.image_height_px / 2
|
||||
fx=3725
|
||||
fy=3725
|
||||
cx= 2640
|
||||
cy=1978
|
||||
# 构造内参矩阵
|
||||
self.camera_matrix = np.array([
|
||||
[fx, 0, cx],
|
||||
[0, fy, cy],
|
||||
[0, 0, 1]
|
||||
], dtype=np.float32)
|
||||
# def transform_Y(self):
|
||||
# num_rows = self.image_points.shape[0]
|
||||
# new_column = np.ones((num_rows, 1))
|
||||
# self.image_points[:,0] +=80
|
||||
# self.image_points[:,1] +=100
|
||||
# self.image_points = np.hstack([self.image_points, new_column])
|
||||
# transform = np.array([[math.cos(radians(-84.8)), -math.sin(radians(-84.8)), 2728*(1-math.cos(radians(-84.8)))+1980.*math.sin(radians(-84.8))],
|
||||
# [math.sin(radians(-84.8)), math.cos(radians(-84.8)), 1980*(1-math.cos(radians(-84.8)))-2728*math.sin(radians(-84.8))],
|
||||
# [0, 0, 1]], dtype=np.float32)
|
||||
# self.image_points = (transform @ self.image_points.T).T
|
||||
def projectPoints(self, object_point, tvec, rvec):
|
||||
self.object_point = object_point
|
||||
self.tvec=tvec
|
||||
self.rvec=np.radians(rvec)
|
||||
# transformer = Transformer.from_crs("epsg:4326", "epsg:4544")
|
||||
# for index, point in enumerate(object_point):
|
||||
# lat,lon,z = point
|
||||
# self.object_poin t[index,1],self.object_point[index,0] = transformer.transform(lat, lon)
|
||||
# self.tvec[1] ,self.tvec[0] = transformer.transform(tvec[0], tvec[1])
|
||||
self.tvec = -self.tvec
|
||||
# c2w = compute_rotation_matrix(np.pi - 0.1 / 180 * np.pi, 0.85 / 180 * np.pi, 0) # 计算旋转矩阵
|
||||
# transform_matrix = create_transformation_matrix(self.tvec) # 创建平移矩阵
|
||||
# self.object_point = self.apply_transformation(c2w, transform_matrix) # 旋转平移
|
||||
# self.object_point = rotate_points_around_axis(self.object_point, np.array([0.0148148, -0.00174533, 0.999901]), -rvec[1])
|
||||
c2w1, c2w = compute_rotation_matrix(np.pi-0.1/180*np.pi, -1/180*np.pi, 1/180*np.pi) #计算旋转矩阵
|
||||
transform_matrix = create_transformation_matrix(self.tvec) #创建平移矩阵
|
||||
self.object_point = self.apply_transformation(c2w, transform_matrix) #旋转平移
|
||||
axis = c2w1.T @ np.array([[0],[0],[1]],dtype=np.float32)
|
||||
self.object_point = rotate_points_around_axis(self.object_point.T,axis.T,- rvec[1])
|
||||
self.object_point = rotate_points_around_axis(self.object_point.T, np.array((0.9999,-0.00174533,-0.0148)), (rvec[0]+90))
|
||||
self.object_point = self.object_point.T
|
||||
self.compute_camera_matrix() #计算相机内参矩
|
||||
BUP=np.array([0,0,0],dtype=np.float32)
|
||||
self.rvec = np.array([0,0,0],dtype=np.float32)
|
||||
self.image_points, _ = cv2.projectPoints(self.object_point, self.rvec, BUP, self.camera_matrix, self.Ji)
|
||||
self.image_points = self.image_points.squeeze()
|
||||
# self.transform_Y()
|
||||
# self.image_points = rotate_points(self.image_points, np.array([2640, 1978]), rvec[1]/180*np.pi)
|
||||
return self.image_points
|
||||
|
||||
def plot(self):
|
||||
x_coords=self.image_points[:,0]
|
||||
y_coords=self.image_points[:,1]
|
||||
print(self.image_points,self.rvec)
|
||||
# 打开图像并绘制投影点
|
||||
im = Image.open('DJI_正.jpeg')
|
||||
imshow(im)
|
||||
# plot(x_coords, y_coords, 'r*')
|
||||
plot(2647.02,1969.38,'b*')# 绘制红色星标
|
||||
show()
|
||||
cv2.imwrite('result1.jp eg', im)
|
Loading…
x
Reference in New Issue
Block a user