This commit is contained in:
yooooger 2025-08-01 14:50:49 +08:00
parent 8e2c1b8654
commit 0ea0d1a81a
13 changed files with 1304 additions and 4 deletions

View File

@ -0,0 +1 @@
[]

5
Ai_tottle/.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": []
}

View File

@ -13,7 +13,6 @@ from map_find import map_process_images
from yolo_train import auto_train from yolo_train import auto_train
import torch import torch
from yolo_photo import map_process_images_with_progress # 引入你的处理函数 from yolo_photo import map_process_images_with_progress # 引入你的处理函数
from tiles import TilesetProcessor
# 日志配置 # 日志配置
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -30,7 +30,7 @@ def get_user_power(user_id: str, yaml_name: str) -> int | None:
conn.autocommit = True # 只读查询推荐设置 conn.autocommit = True # 只读查询推荐设置
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("SELECT power FROM outusers WHERE user_id = %s", (user_id,)) cur.execute(" WHERE user_id = %s", (user_id,))
row = cur.fetchone() row = cur.fetchone()
if row: if row:
return row[0] return row[0]

View File

@ -139,7 +139,7 @@ class TaskManager:
logger.warning("已将所有任务标记为停止状态") logger.warning("已将所有任务标记为停止状态")
app = Sanic("YoloStreamService") app = Sanic("YoloStreamServiceOut")
CORS(app) CORS(app)
task_manager = TaskManager() task_manager = TaskManager()
@ -613,4 +613,4 @@ if __name__ == "__main__":
print("正在安装psutil库...") print("正在安装psutil库...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"]) subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
app.run(host="0.0.0.0", port=12315, debug=False, access_log=True) app.run(host="0.0.0.0", port=12317, debug=False, access_log=True)

BIN
ai3/best.pt Normal file

Binary file not shown.

BIN
ai3/build.pt Normal file

Binary file not shown.

13
ai3/config.yaml Normal file
View File

@ -0,0 +1,13 @@
minio:
endpoint: "222.212.85.86:9000"
access_key: "adminjdskfj"
secret_key: "123456ksldjfal@Y"
secure: false
web: "http://222.212.85.86"
sql:
host: '222.212.85.86'
port: 5432
dbname: 'postgres'
user: 'postgres'
password: 'root'

562
ai3/cv_video.py Normal file
View File

@ -0,0 +1,562 @@
from threading import Thread, Lock, Event
import time
import queue
from ultralytics import YOLO # 导入 Ultralytics YOLO 模型
import os, cv2, torch, time, queue, subprocess
import numpy as np
# 全局变量
ifAI = {'status': False}
deskLock = Lock()
frame_queue = queue.Queue(maxsize=60) # 增加帧缓冲队列大小
processed_frame_queue = queue.Queue(maxsize=30) # 处理后的帧队列
stop_event = Event()
def setIfAI(pb1):
deskLock.acquire()
ifAI['status'] = pb1
deskLock.release()
def getIfAI():
return ifAI['status']
def stopAIVideo():
print("正在停止AI视频处理...")
setIfAI(False)
stop_event.set()
# 等待足够长的时间确保资源释放
wait_count = 0
max_wait = 5 # 减少最大等待时间到5秒
while stop_event.is_set() and wait_count < max_wait:
time.sleep(0.5)
wait_count += 1
if wait_count >= max_wait:
print("警告: 停止AI视频处理超时强制终止")
# 不使用_thread._interrupt_main(),改用其他方式强制终止
try:
# 尝试终止可能运行的进程
import os
import signal
import psutil
# 查找并终止可能的FFmpeg进程
current_process = psutil.Process(os.getpid())
for child in current_process.children(recursive=True):
try:
child_name = child.name().lower()
if 'ffmpeg' in child_name:
print(f"正在终止子进程: {child.pid} ({child_name})")
child.send_signal(signal.SIGTERM)
except:
pass
except:
pass
else:
print("AI视频处理已停止")
def startAIVideo(video_path, output_url, m1, cls, confidence):
if ifAI['status']:
stopAIVideo()
time.sleep(1)
stop_event.clear()
thread = Thread(target=startAIVideo2,
args=(video_path, output_url, m1, cls, confidence))
# cls2_thread = Thread(target=cls2_find, args=(video_path,m1, cls, confidence))
# cls2_thread.daemon = True # 守护线程,主程序退出时线程也会退出
thread.daemon = True # 守护线程,主程序退出时线程也会退出
thread.start()
# cls2_thread.start()
def read_frames(cap, frame_queue):
"""优化的帧读取线程"""
frame_count = 0
last_time = time.time()
last_fps_time = time.time()
# 减小目标帧间隔时间,提高读取帧率
target_time_per_frame = 1.0 / 60.0 # 目标帧间隔时间提高到60fps
# 添加连接断开检测
connection_error_count = 0
max_connection_errors = 10 # 最多允许连续10次连接错误
last_successful_read = time.time()
max_read_wait = 30.0 # 30秒无法读取则认为连接断开
# 预先丢弃几帧,确保从新帧开始处理
for _ in range(5):
cap.grab()
while not stop_event.is_set():
current_time = time.time()
elapsed_time = current_time - last_time
# 检查是否长时间无法读取帧
if current_time - last_successful_read > max_read_wait:
print(f"警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开")
stop_event.set()
break
# 帧率控制,但更积极地读取
if elapsed_time < target_time_per_frame:
time.sleep(target_time_per_frame - elapsed_time)
continue
# 当队列快满时,跳过一些帧以避免延迟累积
if frame_queue.qsize() > frame_queue.maxsize * 0.8:
# 跳过一些帧
cap.grab()
last_time = time.time()
continue
ret, frame = cap.read()
if not ret:
print("拉流错误:无法读取帧")
connection_error_count += 1
if connection_error_count >= max_connection_errors:
print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...")
stop_event.set()
break
time.sleep(0.5) # 短暂等待后重试
continue
# 成功读取了帧,重置错误计数
connection_error_count = 0
last_successful_read = time.time()
frame_count += 1
if frame_count % 60 == 0: # 每60帧计算一次FPS
current_fps_time = time.time()
fps = 60 / (current_fps_time - last_fps_time)
print(f"拉流FPS: {fps:.2f}")
last_fps_time = current_fps_time
last_time = time.time()
frame_queue.put((frame, time.time())) # 添加时间戳
def process_frames(frame_queue, processed_frame_queue, ov_model, cls, confidence):
"""处理帧的线程,添加帧率控制"""
error_count = 0 # 添加错误计数器
max_errors = 5 # 最大容许错误次数
frame_count = 0
process_times = [] # 用于计算平均处理时间
# 设置YOLO模型配置提高性能
ov_model.conf = confidence # 设置置信度阈值
# 优化推理性能
try:
# 导入torch库
import torch
# 尝试启用ONNX Runtime加速
ov_model.to('cuda:0' if torch.cuda.is_available() else 'cpu')
# 调整批处理大小为1减少内存占用
if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'):
ov_model.args.batch = 1
# 使用half精度提高性能
if torch.cuda.is_available() and hasattr(ov_model, 'model'):
try:
ov_model.model = ov_model.model.half()
except Exception as half_err:
print(f"半精度转换失败: {half_err}")
except Exception as e:
print(f"模型优化配置警告: {e}")
# 缓存先前的检测结果,用于提高稳定性
last_results = None
skip_counter = 0
max_skip = 2 # 最多跳过几帧不处理
while not stop_event.is_set():
try:
if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8:
# 如果输出队列接近满,等待一段时间
time.sleep(0.01)
continue
frame, timestamp = frame_queue.get(timeout=0.2)
# 处理延迟过大的帧
if time.time() - timestamp > 0.3: # 减少延迟阈值
continue
frame_count += 1
# 间隔采样每n帧处理一次减少计算量
if skip_counter > 0 and last_results is not None:
skip_counter -= 1
# 使用上次的检测结果
annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
processed_frame_queue.put((annotated_frame, timestamp))
continue
process_start = time.time()
# 动态调整处理尺寸,根据队列积压情况
resize_scale = 1.0
if frame_queue.qsize() > frame_queue.maxsize * 0.7:
resize_scale = 0.4 # 高负载时大幅降低分辨率
elif frame_queue.qsize() > frame_queue.maxsize * 0.5:
resize_scale = 0.6 # 中等负载时适度降低分辨率
elif frame_queue.qsize() > frame_queue.maxsize * 0.3:
resize_scale = 0.8 # 轻微负载时轻微降低分辨率
# 调整图像尺寸以加快处理
if resize_scale < 1.0:
process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale)
else:
process_frame = frame
# 执行推理
try:
results = ov_model(process_frame, classes=cls, show=False)
last_results = results[0] # 保存检测结果用于后续帧
# 如果尺寸调整过,需要将结果转换回原始尺寸
if resize_scale < 1.0:
# 绘制检测框
annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5),
(frame.shape[1], frame.shape[0]))
else:
annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5)
# 在负载高时启用跳帧处理
if frame_queue.qsize() > frame_queue.maxsize * 0.5:
skip_counter = max_skip
except Exception as infer_err:
print(f"推理错误: {infer_err}")
if last_results is not None:
# 使用上次的结果
annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5)
else:
# 如果没有上次的结果,简单返回原始帧
annotated_frame = frame.copy()
process_end = time.time()
process_times.append(process_end - process_start)
if len(process_times) > 30:
process_times.pop(0)
if frame_count % 30 == 0:
avg_process_time = sum(process_times) / len(process_times)
fps = 1.0 / avg_process_time if avg_process_time > 0 else 0
print(f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time*1000:.2f}ms, 队列大小: {frame_queue.qsize()}, 缩放比例: {resize_scale:.2f}")
processed_frame_queue.put((annotated_frame, timestamp))
error_count = 0 # 成功处理后重置错误计数
except queue.Empty:
continue
except Exception as e:
error_count += 1
print(f"处理帧错误: {e}")
if error_count >= max_errors:
print(f"连续处理错误达到{max_errors}次,正在停止处理...")
stop_event.set()
break
def write_frames(processed_frame_queue, pipe, size):
"""写入帧的线程,添加平滑处理"""
last_write_time = time.time()
target_time_per_frame = 1.0 / 30.0 # 30fps
pipe_error_count = 0 # 添加错误计数
max_pipe_errors = 3 # 最大容忍错误数
frame_count = 0
last_fps_time = time.time()
skipped_frames = 0
# 使用队列存储最近几帧,用于在需要时进行插值
recent_frames = []
max_recent_frames = 5 # 增加缓存帧数量,提高平滑性
# 使用双缓冲机制提高写入速度
buffer1 = bytearray(size[0] * size[1] * 3)
buffer2 = bytearray(size[0] * size[1] * 3)
current_buffer = buffer1
# 帧率控制参数
min_frame_interval = target_time_per_frame * 0.5 # 允许的最小帧间隔
max_frame_interval = target_time_per_frame * 2.0 # 允许的最大帧间隔
while not stop_event.is_set():
try:
# 获取处理后的帧,超时时间较短以便更平滑地处理
frame, timestamp = processed_frame_queue.get(timeout=0.05)
# 存储最近的帧用于插值
recent_frames.append(frame)
if len(recent_frames) > max_recent_frames:
recent_frames.pop(0)
current_time = time.time()
elapsed = current_time - last_write_time
# 如果两帧间隔太短,考虑合并或跳过
if elapsed < min_frame_interval and len(recent_frames) >= 2:
skipped_frames += 1
continue
# 如果两帧间隔太长,进行插值平滑
if elapsed > max_frame_interval and len(recent_frames) >= 2:
# 计算需要插入的帧数
frames_to_insert = min(int(elapsed / target_time_per_frame), 3)
for i in range(frames_to_insert):
# 创建插值帧
weight = (i + 1) / (frames_to_insert + 1)
interpolated_frame = cv2.addWeighted(recent_frames[-2], 1-weight, recent_frames[-1], weight, 0)
# 切换双缓冲
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
# 高效调整大小并写入
interpolated_resized = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR)
img_bytes = interpolated_resized.tobytes()
# 写入管道
pipe.stdin.write(img_bytes)
pipe.stdin.flush()
# 切换双缓冲
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
# 正常写入当前帧 - 使用高效的调整大小方法
resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR)
img_bytes = resized_frame.tobytes()
pipe.stdin.write(img_bytes)
pipe.stdin.flush()
frame_count += 1
if frame_count % 30 == 0:
current_fps_time = time.time()
fps = 30 / (current_fps_time - last_fps_time)
print(f"推流FPS: {fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}")
last_fps_time = current_fps_time
skipped_frames = 0
last_write_time = time.time()
pipe_error_count = 0 # 成功写入后重置错误计数
except queue.Empty:
# 队列为空且有足够的最近帧时,考虑生成插值帧以保持流畅
if len(recent_frames) >= 2 and time.time() - last_write_time > target_time_per_frame:
try:
# 创建插值帧
interpolated_frame = cv2.addWeighted(recent_frames[-2], 0.5, recent_frames[-1], 0.5, 0)
# 切换双缓冲
current_buffer = buffer2 if current_buffer is buffer1 else buffer1
resized_frame = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR)
img_bytes = resized_frame.tobytes()
pipe.stdin.write(img_bytes)
pipe.stdin.flush()
last_write_time = time.time()
except Exception:
pass
continue
except Exception as e:
print(f"写入帧错误: {e}")
pipe_error_count += 1
if pipe_error_count >= max_pipe_errors:
print("FFmpeg管道错误过多正在终止进程...")
stop_event.set() # 主动结束所有线程
break
def startAIVideo2(video_path, output_url, m1, cls, confidence):
rtmp = output_url
setIfAI(True)
cap = None
pipe = None
read_thread = None
process_thread = None
write_thread = None
ov_model = None
try:
global frame_queue, processed_frame_queue, stop_event
stop_event = Event()
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
torch.cuda.empty_cache()
print("预加载YOLO模型...")
model_params = {}
try:
test_model = YOLO(m1)
if hasattr(test_model, "task"):
model_params["task"] = "detect"
if torch.cuda.is_available():
model_params["half"] = True
import inspect
if "verbose" in inspect.signature(YOLO.__init__).parameters:
model_params["verbose"] = False
except Exception as e:
print(f"参数检测失败,将使用默认参数: {e}")
model_params = {}
retry_count = 0
while retry_count < 3:
try:
ov_model = YOLO(m1, **model_params)
dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
for _ in range(3):
ov_model(dummy_frame, classes=cls, conf=confidence, show=False)
print("YOLO模型加载成功并预热完成")
break
except Exception as e:
retry_count += 1
print(f"YOLO模型加载失败(尝试 {retry_count}/3): {e}")
if "unexpected keyword" in str(e):
param = str(e).split("'")[-2]
if param in model_params:
print(f"移除不支持的参数: {param}")
del model_params[param]
time.sleep(2)
if ov_model is None:
raise Exception("无法加载YOLO模型")
ov_model.conf = confidence
print(f"正在连接视频流: {video_path}")
cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 5)
cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
if not cap.isOpened():
raise Exception(f"无法打开视频流: {video_path}")
try:
cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0)
except Exception as e:
print(f"无法设置自动曝光参数: {e}")
frame_queue = queue.Queue(maxsize=80)
processed_frame_queue = queue.Queue(maxsize=40)
size = (1920, 1080)
sizeStr = f"{size[0]}x{size[1]}"
command = [
'ffmpeg', '-y',
'-f', 'rawvideo', '-vcodec', 'rawvideo',
'-pix_fmt', 'bgr24',
'-s', sizeStr,
'-r', '30',
'-i', '-',
'-c:v', 'libx264',
'-preset', 'ultrafast',
'-tune', 'zerolatency',
'-x264-params', 'sei=0',
'-pix_fmt', 'yuv420p',
'-f', 'flv',
'-g', '30',
'-keyint_min', '30',
'-sc_threshold', '0',
'-b:v', '2500k',
'-maxrate', '3000k',
'-bufsize', '3000k',
'-threads', '4',
'-vsync', '1',
rtmp
]
print(f"启动FFmpeg推流到: {rtmp}")
pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
bufsize=10*1024*1024)
def monitor_ffmpeg_output(pipe):
while not stop_event.is_set():
line = pipe.stderr.readline().decode('utf-8', errors='ignore')
if line and ("error" in line.lower()):
print(f"FFmpeg错误: {line.strip()}")
if "Cannot open connection" in line:
stop_event.set()
break
Thread(target=monitor_ffmpeg_output, args=(pipe,), daemon=True).start()
read_thread = Thread(target=read_frames, args=(cap, frame_queue), daemon=True, name="ReadThread")
process_thread = Thread(target=process_frames, args=(frame_queue, processed_frame_queue, ov_model, cls, confidence), daemon=True, name="ProcessThread")
write_thread = Thread(target=write_frames, args=(processed_frame_queue, pipe, size), daemon=True, name="WriteThread")
print("开始推流处理...")
read_thread.start()
process_thread.start()
write_thread.start()
last_check = time.time()
while getIfAI() and not stop_event.is_set():
if not all([t.is_alive() for t in [read_thread, process_thread, write_thread]]):
print("检测到线程停止,退出")
stop_event.set()
break
if pipe.poll() is not None:
print("FFmpeg退出")
stop_event.set()
break
if time.time() - last_check > 30:
print(f"输入队列: {frame_queue.qsize()}/{frame_queue.maxsize} | 输出队列: {processed_frame_queue.qsize()}/{processed_frame_queue.maxsize}")
last_check = time.time()
time.sleep(0.1)
except Exception as e:
print(f"错误: {e}")
finally:
print("清理资源...")
stop_event.set()
setIfAI(False)
for t in [read_thread, process_thread, write_thread]:
if t and t.is_alive():
t.join(timeout=2)
try:
if cap: cap.release()
if pipe:
try:
import signal
os.kill(pipe.pid, signal.SIGTERM)
except: pass
pipe.stdin.close()
pipe.terminate()
try:
pipe.wait(timeout=2)
except:
pipe.kill()
except Exception as e:
print(f"释放资源时出错: {e}")
try:
cv2.destroyAllWindows()
except:
pass
print("资源释放完毕")
if __name__ == '__main__':
sn = "1581F6QAD243C00BP71E"
video_path = f"rtmp://222.212.85.86:1935/live/{sn}"
# FFmpeg 推流地址
rtmp = f"rtmp://222.212.85.86:1935/live/{sn}ai"
try:
startAIVideo2(video_path, rtmp, "best.pt", [0, 1, 2, 3, 4],0.4)
except KeyboardInterrupt:
print("程序被用户中断")
stopAIVideo()
except Exception as e:
print(f"程序异常: {e}")

48
ai3/minio_helper.py Normal file
View File

@ -0,0 +1,48 @@
import os
from minio import Minio
from minio.error import S3Error
bucket="300bdf2b-a150-406e-be63-d28bd29b409f"
# 替换为你的MinIO服务器地址、访问密钥和秘密密钥
def getClient():
minio_client = Minio(
"222.212.85.86:9000",
access_key="WuRenJi",
secure=False,
secret_key="WRJ@2024",)
return minio_client
def getPath2(object):
#dir="C:/sy/movies/"
dir=os.getcwd()+"/"
baseName=object
s1=baseName.rfind("/")
dir2=(dir+baseName[0:s1+1]).replace("/","\\")
fName=baseName[s1+1:int(len(baseName))]
os.makedirs(dir2, exist_ok=True)
file_path = os.path.join(dir2, fName)
return file_path
def upLoad(obj,path):
try:
minio_client=getClient()
minio_client.fput_object(bucket, obj, path)
return True
except S3Error as e:
return False
def downLoad(obj):
path=getPath2(obj)
if os.path.exists(path):
return path
# 从MinIO的存储桶和对象名称下载
try:
minio_client=getClient()
minio_client.fget_object(bucket, obj, path)
return path
except S3Error as e:
return ""
if __name__ == '__main__':
upLoad("aaa/yolo_api.py","yolo_api.py")

50
ai3/sqlhelp.py Normal file
View File

@ -0,0 +1,50 @@
import yaml
import psycopg2
def read_sql_config(yaml_name):
"""
读取 SQL 配置
"""
yaml_path = f"{yaml_name}.yaml"
with open(yaml_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
sql_config = config.get('sql')
if not sql_config:
raise ValueError("未找到 'sql' 配置块")
return sql_config
def get_user_power(user_id: str, yaml_name: str) -> int | None:
"""
根据 user_id 查询数据库中对应的 power找不到返回 None
"""
conn = None
try:
sql_config = read_sql_config(yaml_name)
conn = psycopg2.connect(
dbname=sql_config['dbname'],
user=sql_config['user'],
password=sql_config['password'],
host=sql_config['host'],
port=sql_config['port']
)
conn.autocommit = True # 只读查询推荐设置
with conn.cursor() as cur:
cur.execute("SELECT * FROM public.threduser where user_id=%s;", (user_id,))
row = cur.fetchone()
if row:
return row[3]
else:
return None
except Exception as e:
print(f"数据库操作异常: {e}")
return None
finally:
if conn:
conn.close()
if __name__ == '__main__':
user_id = '20250801'
yaml_name = 'config'
power = get_user_power(user_id, yaml_name)
print(power)

BIN
ai3/yolo11n.pt Normal file

Binary file not shown.

622
ai3/yolo_api.py Normal file
View File

@ -0,0 +1,622 @@
from sanic import Sanic, json
from sanic.response import json as json_response
from sanic.exceptions import Unauthorized, NotFound, SanicException
from dataclasses import dataclass
from typing import List, Dict, Any, Callable
import uuid
import logging
import asyncio
import traceback
from datetime import datetime
from cv_video import startAIVideo,stopAIVideo,getIfAI
from sanic_cors import CORS
from sqlhelp import get_user_power
import os
import signal
import psutil
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 服务状态标志
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
# 配置类
class Config:
VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
MAX_ACTIVE_TASKS = 10
DEFAULT_CONFIDENCE = 0.5
RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
@dataclass
class StreamRequest:
source_url: str
push_url: str
model_path: str
detect_classes: List[str]
user_id: str = None
confidence: float = Config.DEFAULT_CONFIDENCE
def validate(self) -> None:
"""验证请求参数"""
if not self.source_url or not self.push_url:
raise ValueError("Source URL and Push URL are required")
if not self.detect_classes:
raise ValueError("At least one detection class must be specified")
if not 0 < self.confidence < 1:
raise ValueError("Confidence must be between 0 and 1")
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest':
try:
instance = cls(
source_url=data['source_url'],
push_url=data['push_url'],
model_path=data['model_path'],
detect_classes=data['detect_classes'],
user_id=data['user_id'],
confidence=data.get('confidence', Config.DEFAULT_CONFIDENCE)
)
instance.validate()
return instance
except KeyError as e:
raise ValueError(f"Missing required field: {str(e)}")
class TaskManager:
def __init__(self):
self.active_tasks: Dict[str, Dict[str, Any]] = {}
self.task_status: Dict[str, str] = {}
self.task_timestamps: Dict[str, datetime] = {}
self.stop_handles: Dict[str, Callable] = {} # 新增:每个任务的停止函数
self.task_callbacks: Dict[str, Callable] = {} # ✅ 添加这个属性,用于回调任务完成信息
def add_task(self, task_id: str, task_info: Dict[str, Any], stop_handle: Callable) -> None:
"""添加新任务"""
if len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS:
raise ValueError("Maximum number of active tasks reached")
self.active_tasks[task_id] = task_info
self.task_status[task_id] = "running"
self.task_timestamps[task_id] = datetime.now()
self.stop_handles[task_id] = stop_handle # 注册停止函数
logger.info(f"Task {task_id} started")
def remove_task(self, task_id: str) -> None:
"""移除任务"""
if task_id in self.active_tasks:
del self.active_tasks[task_id]
del self.task_status[task_id]
del self.task_timestamps[task_id]
self.stop_handles.pop(task_id, None) # 同时移除停止句柄
logger.info(f"Task {task_id} removed")
def get_task_info(self, task_id: str) -> Dict[str, Any]:
"""获取任务信息"""
if task_id not in self.active_tasks:
raise NotFound("Task not found")
return {
"task_info": self.active_tasks[task_id],
"status": self.task_status[task_id],
"start_time": self.task_timestamps[task_id].isoformat()
}
def stop_task(self, task_id: str) -> bool:
"""只停止一个任务"""
if task_id not in self.active_tasks:
logger.warning(f"Task {task_id} 不存在")
return False
stop_handle = self.stop_handles.get(task_id)
if not stop_handle:
logger.warning(f"Task {task_id} 无法停止(未注册停止函数)")
return False
try:
stop_handle() # 执行停止函数(你需传入能关闭 FFmpeg 或线程的回调)
self.task_status[task_id] = "stopped"
logger.info(f"Task {task_id} 停止成功")
return True
except Exception as e:
logger.error(f"停止任务 {task_id} 出错:{e}")
self.task_status[task_id] = "error"
return False
def check_tasks_health(self) -> Dict[str, str]:
"""检查任务健康状态"""
unhealthy_tasks = {}
for task_id in list(self.active_tasks.keys()):
# 检查任务是否还在运行通过getIfAI()函数)
if not getIfAI():
unhealthy_tasks[task_id] = "stopped"
logger.warning(f"Task {task_id} appears to be stopped unexpectedly")
return unhealthy_tasks
def mark_all_tasks_as_stopped(self):
"""标记所有任务为已停止状态"""
for task_id in list(self.active_tasks.keys()):
self.task_status[task_id] = "stopped"
logger.warning("已将所有任务标记为停止状态")
app = Sanic("YoloStreamServiceOut")
CORS(app)
task_manager = TaskManager()
async def safe_stop_ai_video():
"""安全地停止AI视频处理带有错误处理和恢复机制"""
try:
await asyncio.to_thread(stopAIVideo)
return True
except Exception as e:
error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
# 标记服务状态为不健康
service_status["is_healthy"] = False
service_status["last_error"] = str(e)
service_status["error_time"] = datetime.now().isoformat()
# 强制结束所有任务
task_manager.mark_all_tasks_as_stopped()
# 尝试通过其他方式杀死可能存在的进程
try:
current_process = psutil.Process(os.getpid())
# 查找并终止ffmpeg子进程
for child in current_process.children(recursive=True):
try:
child_name = child.name().lower()
if 'ffmpeg' in child_name:
logger.info(f"强制终止子进程: {child.pid} ({child_name})")
child.send_signal(signal.SIGTERM)
except Exception as child_e:
logger.error(f"终止子进程出错: {str(child_e)}")
except Exception as kill_e:
logger.error(f"尝试清理进程时出错: {str(kill_e)}")
# 等待一段时间让系统恢复
await asyncio.sleep(Config.RESTART_DELAY)
# 重置服务状态
service_status["is_healthy"] = True
return False
def verify_token(request) -> None:
"""验证请求token"""
token = request.headers.get('X-API-Token')
if not token or token != Config.VALID_TOKEN:
logger.warning("Invalid token attempt")
raise Unauthorized("Invalid token")
async def verify_userid(request) -> None:
"""验证请求userid"""
# 解析并验证请求数据
stream_request = StreamRequest.from_dict(request.json)
userid = stream_request.user_id
print(userid)
if not userid:
logger.warning("userid not define")
raise Unauthorized("Invalid userid")
if get_user_power(userid,"config") < 1:
logger.warning("user have not power")
raise Unauthorized("user have not power")
async def detection(request):
# 检查服务健康状态
if not service_status["is_healthy"]:
logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
service_status["is_healthy"] = True
# 停止所有现有任务
for task_id in list(task_manager.active_tasks.keys()):
logger.info(f"停止现有任务 {task_id} 以启动新任务")
try:
stop_cb = task_manager.task_callbacks.get(task_id, safe_stop_ai_video)
success = await stop_cb()
task_manager.remove_task(task_id)
except Exception as e:
logger.error(f"停止任务时出错: {e}")
task_manager.mark_all_tasks_as_stopped()
# 解析请求参数
stream_request = StreamRequest.from_dict(request.json)
task_id = str(uuid.uuid4())
# 替换地址
# new_source_url = stream_request.source_url.replace("222.212.85.86", "192.168.10.5")
# new_push_url = stream_request.push_url.replace("222.212.85.86", "192.168.10.5")
try:
await asyncio.to_thread(
startAIVideo,
stream_request.source_url,
stream_request.push_url,
stream_request.model_path,
stream_request.detect_classes,
stream_request.confidence
)
except Exception as e:
logger.error(f"启动AI视频处理失败: {e}")
return json_response({
"status": "error",
"message": f"Failed to start AI video processing: {str(e)}"
}, status=500)
# ✅ 注册 stop_callback如你使用的为通用函数
task_manager.add_task(
task_id,
{
"source_url": stream_request.source_url,
"push_url": stream_request.push_url,
"model_path": stream_request.model_path,
"detect_classes": stream_request.detect_classes,
"confidence": stream_request.confidence
},
stop_handle=safe_stop_ai_video # 修正参数名
)
return json_response({
"status": "success",
"task_id": task_id,
"message": "Detection started successfully"
})
@app.post("/ai/stream/detect1")
async def start_detection1(request):
try:
await verify_userid(request)
try:
verify_token(request)
return await detection(request)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
except ValueError as e:
logger.error(f"Validation error: {str(e)}")
return json_response({"status": "error", "message": str(e)}, status=400)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.post("/ai/stream/detect")
async def start_detection(request):
try:
verify_token(request)
detection(request)
except ValueError as e:
logger.error(f"Validation error: {str(e)}")
return json_response({"status": "error", "message": str(e)}, status=400)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.post("/ai/stream/<task_id>")
async def stop_detection(request, task_id: str):
try:
verify_token(request)
# 检查任务是否存在
try:
task_info = task_manager.get_task_info(task_id)
except NotFound:
return json_response({"status": "error", "message": "Task not found"}, status=404)
# 调用 task_callbacks 中的停止函数(如果有)
stop_callback = task_manager.task_callbacks.get(task_id)
if stop_callback:
success = await stop_callback()
else:
logger.warning(f"Task {task_id} has no stop callback, using default safe_stop_ai_video")
success = await safe_stop_ai_video()
# 无论成功与否都要移除任务
task_manager.remove_task(task_id)
if not success:
logger.warning(f"停止任务 {task_id} 失败,但已移除任务记录")
return json_response({
"status": "warning",
"message": "Task removal completed with warnings"
})
return json_response({
"status": "success",
"message": f"Detection for task {task_id} stopped successfully"
})
except NotFound as e:
return json_response({"status": "error", "message": str(e)}, status=404)
except Exception as e:
logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True)
# 尝试标记任务为停止状态
try:
if task_id in task_manager.task_status:
task_manager.task_status[task_id] = "error_during_stop"
except:
pass
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.post("/ai/stream/stopTask1")
async def stopTask1(request):
try:
verify_userid(request)
try:
verify_token(request)
jsondata = await request.json()
task_id = jsondata.get("task_id")
if not task_id:
return json_response({"status": "error", "message": "task_id is required"}, status=400)
try:
task_info = task_manager.get_task_info(task_id)
logger.info(f"Stopping task: {task_id} -> {task_info}")
# 调用对应任务的停止回调
stop_callback = task_manager.task_callbacks.get(task_id)
if stop_callback:
success = await stop_callback()
else:
logger.warning(f"No stop callback found for task {task_id}")
success = False
task_manager.remove_task(task_id)
if not success:
return json_response({
"status": "warning",
"message": "Task removal completed, but stop failed"
})
return json_response({
"status": "success",
"message": "Task stopped successfully"
})
except NotFound:
return json_response({"status": "error", "message": "Task not found"}, status=404)
except Exception as e:
logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True)
try:
if task_id in task_manager.task_status:
task_manager.task_status[task_id] = "error_during_stop"
except:
pass
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
except:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.get("/ai/stream/<task_id>")
async def get_task_status(request, task_id: str):
try:
verify_token(request)
task_info = task_manager.get_task_info(task_id)
# 检查任务是否真的在运行
if not getIfAI() and task_info["status"] == "running":
task_info["status"] = "stopped_unexpectedly"
logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止")
return json_response({
"status": "success",
"task_id": task_id,
**task_info
})
except NotFound as e:
return json_response({"status": "error", "message": str(e)}, status=404)
except Exception as e:
logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.get("/ai/stream/tasks")
async def list_tasks(request):
"""获取所有活动任务列表"""
try:
verify_token(request)
# 检查所有任务的健康状态
unhealthy_tasks = task_manager.check_tasks_health()
for task_id, status in unhealthy_tasks.items():
if task_id in task_manager.task_status:
task_manager.task_status[task_id] = status
tasks = {
task_id: task_manager.get_task_info(task_id)
for task_id in task_manager.active_tasks.keys()
}
return json_response({
"status": "success",
"tasks": tasks
})
except Exception as e:
logger.error(f"Error listing tasks: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.post("/ai/stream/stopTasks")
async def stop_all_detections(request):
"""停止所有活动任务"""
try:
verify_token(request)
if not task_manager.active_tasks:
return json_response({
"status": "success",
"message": "No active tasks to stop"
})
# 停止所有任务
success = await safe_stop_ai_video()
# 无论成功与否,都移除所有任务
for task_id in list(task_manager.active_tasks.keys()):
task_manager.remove_task(task_id)
if not success:
return json_response({
"status": "warning",
"message": "Tasks stopped with warnings"
})
return json_response({
"status": "success",
"message": "All detections stopped successfully"
})
except Exception as e:
logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True)
# 尝试标记所有任务为停止状态
task_manager.mark_all_tasks_as_stopped()
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.get("/ai/health")
async def health_check(request):
"""服务健康检查端点"""
try:
# 不需要验证token这个接口可以用于监控系统检查服务状态
unhealthy_tasks = task_manager.check_tasks_health()
return json_response({
"status": "success",
"service": "running" if service_status["is_healthy"] else "degraded",
"active_tasks": len(task_manager.active_tasks),
"unhealthy_tasks": unhealthy_tasks,
"last_error": service_status["last_error"],
"error_time": service_status["error_time"],
"timestamp": datetime.now().isoformat()
})
except Exception as e:
logger.error(f"健康检查失败: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"service": "degraded",
"error": str(e),
"timestamp": datetime.now().isoformat()
}, status=500)
@app.route("/ai/reset", methods=["POST"])
async def reset_service(request):
"""重置服务状态,清理所有任务和进程"""
try:
verify_token(request)
# 尝试停止AI视频处理
await safe_stop_ai_video()
# 清理所有任务
for task_id in list(task_manager.active_tasks.keys()):
task_manager.remove_task(task_id)
# 重置服务状态
service_status["is_healthy"] = True
service_status["last_error"] = None
service_status["error_time"] = None
# 尝试清理可能存在的僵尸进程
try:
import os
import signal
import psutil
current_process = psutil.Process(os.getpid())
zombie_count = 0
for child in current_process.children(recursive=True):
try:
if child.status() == psutil.STATUS_ZOMBIE:
zombie_count += 1
child.send_signal(signal.SIGKILL)
except:
pass
return json_response({
"status": "success",
"message": f"Service reset successful. Cleaned {zombie_count} zombie processes."
})
except Exception as e:
logger.error(f"清理僵尸进程时出错: {e}")
return json_response({
"status": "warning",
"message": "Service reset with warnings"
})
except Exception as e:
logger.error(f"重置服务时出错: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"message": f"Failed to reset service: {str(e)}"
}, status=500)
@app.route("/ai/stream/restart/<task_id>", methods=["POST"])
async def restart_task(request, task_id: str):
"""重启指定任务"""
try:
verify_token(request)
# 获取任务信息
try:
task_info = task_manager.get_task_info(task_id)["task_info"]
except NotFound:
return json_response({"status": "error", "message": "Task not found"}, status=404)
# 先停止任务
success = await safe_stop_ai_video()
task_manager.remove_task(task_id)
if not success:
logger.warning("停止任务出现问题,尝试继续重启")
# 重新启动任务
new_task_id = str(uuid.uuid4())
try:
await asyncio.to_thread(
startAIVideo,
task_info["source_url"],
task_info["push_url"],
task_info["model_path"],
task_info["detect_classes"],
task_info["confidence"]
)
# 记录新任务信息
task_manager.add_task(new_task_id, task_info)
return json_response({
"status": "success",
"old_task_id": task_id,
"new_task_id": new_task_id,
"message": "Task restarted successfully"
})
except Exception as e:
logger.error(f"重启任务失败: {e}")
return json_response({
"status": "error",
"message": f"Failed to restart task: {str(e)}"
}, status=500)
except Exception as e:
logger.error(f"重启任务时出错: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
if __name__ == "__main__":
# 保证服务启动前没有残留任务
try:
stopAIVideo()
print("服务启动前清理完成")
except:
print("服务启动前清理失败,但仍将继续")
# 安装psutil库用于进程管理
try:
import psutil
except ImportError:
import subprocess
import sys
print("正在安装psutil库...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
app.run(host="0.0.0.0", port=12317, debug=False, access_log=True)