# (file-viewer residue, not part of the program: 259 lines, 8.7 KiB, Python)
import cv2
|
||
import subprocess
|
||
from threading import Thread, Lock, Event
|
||
import time
|
||
import queue
|
||
import numpy as np
|
||
from ultralytics import YOLO
|
||
import datetime
|
||
import os
|
||
import uuid
|
||
|
||
# ---------- Session manager ----------
# Registry of live sessions, keyed by the session_id string returned from
# start_video_session. Each value is a dict with the keys
# 'thread', 'stop_event' and 'model_container'. Entries are added by
# start_video_session and removed by stop_video_session.
stream_sessions = dict()


def start_video_session(video_path, output_url, model_path, cls, confidence, cls2=None, push=False):
    """Start a background video-processing session and return its id.

    Loads a YOLO model from ``model_path`` and runs ``stream_worker`` in a
    daemon thread. The session is recorded in ``stream_sessions`` so it can
    later be stopped with ``stop_video_session`` or have its weights replaced
    with ``switch_model_session``.

    Args:
        video_path: Capture source handed to ``stream_worker``.
        output_url: FFmpeg output target, used only when ``push`` is true.
        model_path: YOLO weights for the main tracking loop.
        cls: Class ids forwarded to ``track(classes=...)``.
        confidence: Threshold forwarded to ``track(conf=...)``.
        cls2: Optional class list. When non-empty, the worker also starts the
            snapshot-saving thread. ``None`` (the default) disables it.
        push: Whether annotated frames are pushed to ``output_url``.

    Returns:
        str: The new session id (a uuid4 string).
    """
    # Avoid a shared mutable default argument; the worker only tests truthiness.
    cls2 = list(cls2) if cls2 else []

    session_id = str(uuid.uuid4())
    stop_event = Event()
    # Bounded queues (5 items each) so frames cannot pile up without limit.
    frame_q = queue.Queue(maxsize=5)
    processed_q = queue.Queue(maxsize=5)
    # The lock serialises inference with hot-swapping in switch_model_session.
    model_container = {
        'model': YOLO(model_path),
        'lock': Lock(),
    }

    # Start the background worker thread.
    thread = Thread(
        target=stream_worker,
        args=(video_path, output_url, cls, confidence, stop_event,
              frame_q, processed_q, model_container, cls2, push),
        name=f"stream-{session_id[:8]}",
        daemon=True,
    )
    thread.start()

    stream_sessions[session_id] = {
        'thread': thread,
        'stop_event': stop_event,
        'model_container': model_container,
    }
    print(f"会话 {session_id} 已启动,视频源: {video_path}")
    return session_id


def stop_video_session(session_id, timeout=None):
    """Stop the session ``session_id`` and remove it from ``stream_sessions``.

    Sets the session's stop event and waits for its worker thread to finish.
    Unknown ids are reported and otherwise ignored.

    Args:
        session_id: Id returned by ``start_video_session``.
        timeout: Optional maximum number of seconds to wait for the worker
            thread. ``None`` (the default) waits indefinitely, as before.
    """
    # pop() instead of get() + del: a repeated or concurrent stop call can no
    # longer raise KeyError, because only one caller ever obtains the entry.
    sess = stream_sessions.pop(session_id, None)
    if not sess:
        print(f"会话 {session_id} 不存在。")
        return

    sess['stop_event'].set()
    sess['thread'].join(timeout)
    if sess['thread'].is_alive():
        # Only reachable when a timeout was given; the daemon thread has been
        # signalled and is left to finish on its own.
        print(f"会话 {session_id} 停止超时,后台线程仍在运行。")
        return
    print(f"会话 {session_id} 已停止。")


def switch_model_session(session_id, new_model_path):
    """Hot-swap the YOLO weights used by a running session.

    The new model is constructed before the session's model lock is taken.
    Inference in the processing thread is therefore blocked only for the
    reference swap, not for the whole (potentially slow) weight load. If
    loading raises, the exception propagates and the old model stays in use.

    NOTE(review): per-model tracking state (``track(persist=True)``)
    presumably does not carry over to the new model, so track ids restart.

    Args:
        session_id: Id returned by ``start_video_session``.
        new_model_path: Path of the weights to load.
    """
    sess = stream_sessions.get(session_id)
    if not sess:
        print(f"会话 {session_id} 不存在,无法切换模型。")
        return

    new_model = YOLO(new_model_path)  # load outside the lock
    container = sess['model_container']
    with container['lock']:
        container['model'] = new_model
    print(f"会话 {session_id} 模型已切换为: {new_model_path}")


def _read_frames(cap, frame_q, stop_event):
|
||
""" 读取帧,保持30fps,丢弃延迟帧,若连续2秒无有效帧则触发停止 """
|
||
last = time.time()
|
||
target = 1.0 / 30.0
|
||
no_frame_start = None # 开始无法读取帧的时间点
|
||
|
||
while not stop_event.is_set():
|
||
now = time.time()
|
||
if now - last < target:
|
||
time.sleep(target - (now - last))
|
||
continue
|
||
if frame_q.qsize() < frame_q.maxsize:
|
||
ret, frm = cap.read()
|
||
if not ret:
|
||
if no_frame_start is None:
|
||
no_frame_start = time.time()
|
||
elif time.time() - no_frame_start >= 2:
|
||
print("⚠️ 超过2秒未获取到帧,自动停止会话。")
|
||
stop_event.set()
|
||
break
|
||
continue
|
||
else:
|
||
no_frame_start = None # 成功读取帧,重置计时
|
||
frame_q.put((frm, time.time()))
|
||
last = time.time()
|
||
else:
|
||
time.sleep(0.001)
|
||
|
||
|
||
|
||
def _process_frames(frame_q, processed_q, cls, confidence, model_container, stop_event):
|
||
""" 检测并绘制 """
|
||
skip_counter = 0
|
||
while not stop_event.is_set():
|
||
try:
|
||
frm, ts = frame_q.get(timeout=1)
|
||
if time.time() - ts > 0.5:
|
||
continue
|
||
skip_counter = (skip_counter + 1) % 2
|
||
if skip_counter != 0:
|
||
continue
|
||
|
||
with model_container['lock']:
|
||
results = model_container['model'].track(frm, persist=True, classes=cls, conf=confidence)
|
||
counts = {}
|
||
for r in results:
|
||
if hasattr(r, 'boxes'):
|
||
for b in r.boxes:
|
||
cid = int(b.cls[0])
|
||
counts[cid] = counts.get(cid, 0) + 1
|
||
ann = results[0].plot(conf=False, line_width=1, font_size=1.5)
|
||
y_off = 30
|
||
for cid, cnt in counts.items():
|
||
txt = f"Class {cid}: {cnt}"
|
||
cv2.putText(ann, txt, (10, y_off), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
|
||
y_off += 25
|
||
processed_q.put((ann, ts))
|
||
except queue.Empty:
|
||
continue
|
||
except Exception as e:
|
||
print(f"_process_frames 错误: {e}")
|
||
|
||
|
||
def _write_frames(processed_q, pipe, size, stop_event):
|
||
""" 推流写入 """
|
||
last = time.time()
|
||
target = 1.0 / 30.0
|
||
while not stop_event.is_set():
|
||
try:
|
||
frm, ts = processed_q.get(timeout=1)
|
||
now = time.time()
|
||
if now - last < target:
|
||
time.sleep(target - (now - last))
|
||
img = cv2.resize(frm, size, interpolation=cv2.INTER_LINEAR)
|
||
pipe.stdin.write(img.tobytes())
|
||
last = time.time()
|
||
except queue.Empty:
|
||
continue
|
||
except Exception as e:
|
||
print(f"_write_frames 错误: {e}")
|
||
break
|
||
|
||
|
||
def _cls2_find(video_path, confidence, save_dir, stop_event, model_path="gdaq.pt", classes=(2, 4)):
    """Save one snapshot per newly tracked object of the selected classes.

    Opens its own capture on ``video_path`` and runs ``model.track`` on every
    frame. The first time a track id of a class in ``classes`` is seen, the
    full frame is written to
    ``<save_dir>/track_<id>_frame_<pos>.jpg``. Later detections of the same
    track id are ignored.

    Args:
        video_path: Capture source, opened independently of the main stream.
        confidence: Threshold forwarded to ``track(conf=...)``.
        save_dir: Existing directory for the snapshots.
        stop_event: ``threading.Event`` that ends the loop when set.
        model_path: YOLO weights to load. The default is the previously
            hard-coded ``"gdaq.pt"``.
        classes: Class ids to watch. The default is the previously
            hard-coded ``[2, 4]``.
    """
    model = YOLO(model_path)
    cls2 = list(classes)
    cap = cv2.VideoCapture(video_path)
    seen_tracks = set()  # track ids that already produced a snapshot

    try:
        while cap.isOpened() and not stop_event.is_set():
            ret, frame = cap.read()
            if not ret:
                break

            # track() (rather than predict()) is used to get stable track ids.
            results = model.track(frame, persist=True, classes=cls2, conf=confidence, show=False)
            for res in results:
                boxes = getattr(res, 'boxes', None)
                if boxes is None:
                    continue
                for box in boxes:
                    # Per the Ultralytics Boxes API, ``id`` is None for boxes
                    # the tracker has not assigned an identity to;
                    # int(box.id[0]) would raise there.
                    if box.id is None:
                        continue
                    tid = int(box.id[0])
                    cls_id = int(box.cls[0])
                    if cls_id not in cls2 or tid in seen_tracks:
                        continue
                    seen_tracks.add(tid)
                    fn = os.path.join(
                        save_dir,
                        f"track_{tid}_frame_{int(cap.get(cv2.CAP_PROP_POS_FRAMES))}.jpg"
                    )
                    # cv2.imwrite reports failure via its return value, not an exception.
                    if cv2.imwrite(fn, frame):
                        print(f"保存: {fn}")
                    else:
                        print(f"保存失败: {fn}")
                    # TODO: upload the snapshot and write the database record here.

            # Brief sleep to keep this loop from monopolising the CPU.
            time.sleep(0.01)
    finally:
        # Release the capture even if tracking or saving raised.
        cap.release()


# ---------- Core processing logic ----------


def _build_push_command(size, output_url):
    """Return the FFmpeg argv that reads raw BGR frames of ``size`` from stdin and pushes H.264/FLV to ``output_url``."""
    size_str = f"{size[0]}x{size[1]}"
    return [
        'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
        '-pix_fmt', 'bgr24', '-s', size_str, '-r', '30', '-probesize', '32', '-analyzeduration', '0',
        '-i', '-', '-c:v', 'h264', '-pix_fmt', 'yuv420p',
        '-preset', 'ultrafast', '-tune', 'zerolatency', '-f', 'flv',
        '-g', '30', '-bufsize', '1000k', '-maxrate', '2000k',
        '-x264opts', 'no-scenecut:keyint=30:min-keyint=30',
        '-flvflags', 'no_duration_filesize', output_url,
    ]


def _close_push_pipe(pipe, timeout=5):
    """Close FFmpeg's stdin and wait for it to exit, killing it after ``timeout`` seconds."""
    if pipe is None:
        return
    try:
        pipe.stdin.close()
    except (BrokenPipeError, OSError, ValueError):
        pass  # FFmpeg already exited; flushing the remaining buffer failed
    try:
        pipe.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        pipe.kill()
        pipe.wait()


def stream_worker(video_path, output_url, cls, confidence, stop_event, frame_q, processed_q, model_container, cls2=None, push=False):
    """Session main function: open the source, warm up the model, run the helper threads, clean up.

    Runs in the thread created by ``start_video_session``. It starts
    ``_read_frames`` (capture -> ``frame_q``) and ``_process_frames``
    (``frame_q`` -> tracking/annotation -> ``processed_q``). When ``push`` is
    true it also starts FFmpeg and ``_write_frames`` (``processed_q`` ->
    FFmpeg stdin). When ``cls2`` is non-empty it also starts ``_cls2_find``.
    It then waits until ``stop_event`` is set.

    On exit, normal or via an exception, it sets ``stop_event`` and gives the
    helper threads up to 5 s in total to finish. Only then does it release
    the capture, so the reader is not inside ``cap.read()`` during the
    release. Finally it shuts FFmpeg down and empties both queues.
    """
    size = (1920, 1080)
    threads = []
    pipe = None

    # Snapshot directory for the cls2 feature, created only when it is used.
    save_dir = None
    if cls2:
        date_str = datetime.datetime.now().strftime("%Y%m%d")
        save_dir = f"AIResults_{date_str}"
        os.makedirs(save_dir, exist_ok=True)

    # Open the video source.
    cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 2)
    cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, size[0])
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, size[1])

    try:
        # Warm the model up on a blank frame so the first real frame is not slow.
        print(f"[{video_path}] 预加载YOLO模型...")
        with model_container['lock']:
            dummy = np.zeros((size[1], size[0], 3), dtype=np.uint8)
            model_container['model'].track(dummy, classes=cls, conf=confidence, show=False)

        threads.append(Thread(target=_read_frames, args=(cap, frame_q, stop_event), daemon=True))
        threads.append(Thread(target=_process_frames,
                              args=(frame_q, processed_q, cls, confidence, model_container, stop_event),
                              daemon=True))
        if push:
            # FFmpeg is only spawned when its output is actually wanted.
            pipe = subprocess.Popen(_build_push_command(size, output_url), stdin=subprocess.PIPE)
            threads.append(Thread(target=_write_frames, args=(processed_q, pipe, size, stop_event), daemon=True))
        if cls2:
            threads.append(Thread(target=_cls2_find, args=(video_path, confidence, save_dir, stop_event), daemon=True))

        for t in threads:
            t.start()
        print(f"[{video_path}] 开始处理流...")

        # Wait for the stop signal.
        while not stop_event.is_set() and cap.isOpened():
            time.sleep(0.1)
    finally:
        print(f"[{video_path}] 停止中,正在清理...")
        # Make sure every helper sees the stop signal, including when we got
        # here because the capture closed or an exception was raised.
        stop_event.set()
        deadline = time.monotonic() + 5
        for t in threads:
            if t.is_alive():  # False for threads that were never started
                t.join(max(0.0, deadline - time.monotonic()))

        cap.release()
        _close_push_pipe(pipe)

        for q in (frame_q, processed_q):
            while not q.empty():
                try:
                    q.get_nowait()
                except queue.Empty:
                    break
        print(f"[{video_path}] 会话结束。")