"""AI inference service (Sanic).

Three groups of endpoints, all guarded by the token / GPU-load request middleware:

* ``/map/...``         - semantic segmentation of UAV map images
* ``/yolo/...``        - YOLO detection, background batch processing, training
* ``/stream_test/...`` - live video-stream sessions
"""
import asyncio
import hmac
import logging
import os
import shutil
import uuid
from datetime import datetime
from queue import Queue

import torch
from sanic import Blueprint, Sanic, json, response
from sanic.exceptions import NotFound, Unauthorized
from sanic.response import json as json_response
from sanic_cors import CORS

# Imported under an alias: the /yolo/process_images HTTP handler below is also
# named ``process_images`` and would otherwise shadow this function, making
# yolo_detect_api hand the handler (not the image processor) to the worker thread.
from ai_image import process_images as ai_process_images
from map_find import map_process_images
from yolo_train import auto_train
from cv_video_counter import (
    start_video_session,
    switch_model_session,
    stop_video_session,
    stream_sessions,
)
from yolo_photo import map_process_images_with_progress  # background batch processor

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# ============================ Authentication / resource middleware ============================

# Paths that start no GPU work and must stay reachable while the GPU is busy
# (stopping a stream is how an operator frees the GPU). Compared without a trailing slash.
GPU_CHECK_EXEMPT_PATHS = frozenset({"/stream_test/stop"})


async def token_and_resource_check(request):
    """Authenticate the request and refuse new work while a GPU is nearly full.

    Returns:
        None to let the request continue, or a 503 JSON response when some GPU's
        reserved memory ratio exceeds ``app.config["MAX_GPU_USAGE"]`` (default 0.9).

    Raises:
        Unauthorized: the ``X-API-Token`` header is missing or does not match
        ``app.config["VALID_TOKEN"]``.
    """
    # CORS preflight requests never carry custom headers such as X-API-Token,
    # so rejecting them would break every browser client.
    if request.method == "OPTIONS":
        return None

    # --- Token verification ---
    token = request.headers.get("X-API-Token")
    expected_token = request.app.config.get("VALID_TOKEN")
    # compare_digest avoids leaking the token through response timing; comparing
    # bytes keeps it from raising TypeError on non-ASCII header values.
    if not token or not expected_token or not hmac.compare_digest(
        token.encode("utf-8"), expected_token.encode("utf-8")
    ):
        # Log where the request came from, not the (possibly near-correct) token itself.
        logger.warning(f"Unauthorized request to {request.path} from {request.ip}")
        raise Unauthorized("Invalid token")

    # --- GPU load check ---
    # Read-only requests (task polling, listings) and exempt paths start no GPU
    # work; refusing them while a job runs would stop clients from watching or
    # stopping that very job.
    if request.method in ("GET", "HEAD") or request.path.rstrip("/") in GPU_CHECK_EXEMPT_PATHS:
        return None

    try:
        if torch.cuda.is_available():
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default 90%
            for i in range(torch.cuda.device_count()):
                # Memory reserved by this process's caching allocator, measured against
                # the card's physical capacity. (max_memory_reserved is this process's
                # own peak; dividing by it reports ~100% whenever current usage equals
                # the peak.) Other processes on the same GPU are not counted here.
                used = torch.cuda.memory_reserved(i)
                total = torch.cuda.get_device_properties(i).total_memory
                ratio = used / total if total else 0
                logger.debug(f"GPU {i} Usage: {ratio:.2%}")
                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later.",
                    }, status=503)
    except Exception as e:
        # Best effort: a failing probe must not take the whole API down.
        logger.error(f"GPU check failed: {e}")
    return None  # let the request continue


# ==================================== Application setup ====================================

app = Sanic("ai_Service_v2")
CORS(app)  # allow cross-origin requests

# task_id -> {"status", "progress", "result"} for /yolo/process_images jobs.
# In-memory only: lost on restart and never pruned (swap for Redis/DB if needed).
task_progress = {}

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so a task nobody holds can be garbage-collected before it finishes.
_background_tasks = set()


@app.middleware("request")
async def global_middleware(request):
    """Run token_and_resource_check on every request; a returned response short-circuits it."""
    result = await token_and_resource_check(request)
    if result is not None:
        return result


# Token and maximum GPU usage ratio. The token can be supplied through the
# AI_SERVICE_TOKEN environment variable so it need not live in source control;
# the literal is the fallback so existing clients keep working.
app.config.update({
    "VALID_TOKEN": os.environ.get(
        "AI_SERVICE_TOKEN", "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a"
    ),
    "MAX_GPU_USAGE": 0.9,
})


def _json_body(request):
    """Return the request's parsed JSON body if it is an object (dict), else None."""
    body = request.json
    return body if isinstance(body, dict) else None


def _remove_dirs(*paths):
    """Delete each directory in *paths*; missing directories and errors are ignored."""
    for path in paths:
        if path:
            shutil.rmtree(path, ignore_errors=True)


# ======================================== Map API ========================================

map_tile_blueprint = Blueprint("map", url_prefix="/map/")

MAP_CHECKPOINT_DIR = os.path.join("map", "checkpoints")


def _resolve_map_checkpoint(model_name):
    """Return the path of checkpoint *model_name* inside MAP_CHECKPOINT_DIR.

    Returns None when the name is missing, not a string, or would resolve outside
    that directory (absolute path, ``..`` segments, another drive), so clients
    cannot make the service load arbitrary files.
    """
    if not isinstance(model_name, str) or not model_name or os.path.isabs(model_name):
        return None
    base = os.path.normpath(MAP_CHECKPOINT_DIR)
    candidate = os.path.normpath(os.path.join(base, model_name))
    try:
        inside = os.path.commonpath([base, candidate]) == base and candidate != base
    except ValueError:  # e.g. a path on a different drive (Windows)
        return None
    return candidate if inside else None


@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """Semantic segmentation of UAV images.

    Endpoint: POST /map/uav

    Request JSON:
    {
        "urls": ["http://example.com/img1.jpg", "http://example.com/img2.jpg"],
        "yaml_name": "config",
        "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
        "bucket_directory": "2025/seg",
        "model_path": "deeplabv3plus_best.pth"
    }
    ``model_path`` is a file name under map/checkpoints.

    Response JSON (whatever map_process_images returns), e.g.:
    {
        "code": 200,
        "msg": "success",
        "data": [
            "http://minio.example.com/uav-results/2025/seg/result1.png",
            "http://minio.example.com/uav-results/2025/seg/result2.png"
        ]
    }
    Validation and server errors are reported in the body ("code": 400 / 500)
    with HTTP status 200.
    """
    try:
        body = _json_body(request) or {}
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")

        # Validate parameters
        if not urls or not isinstance(urls, list):
            return json_response({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory]):
            return json_response({"code": 400, "msg": "Missing required parameters"})
        model_path = _resolve_map_checkpoint(body.get("model_path"))
        if model_path is None:
            return json_response({"code": 400, "msg": "Missing or invalid 'model_path'"})

        # map_process_images is a blocking call; run it in a worker thread so the
        # event loop keeps serving other requests in the meantime.
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json_response(result)
    except Exception as e:
        logger.error(f"Error in /map/uav: {e}", exc_info=True)
        return json_response({"code": 500, "msg": f"Server error: {str(e)}"})


# ======================================= YOLO API =======================================

yolo_tile_blueprint = Blueprint("yolo", url_prefix="/yolo/")


@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
    """Start a background image-processing job and return its task id.

    Request JSON:
    {
        "urls": [
            "http://example.com/image1.jpg",
            "http://example.com/image2.jpg",
            "http://example.com/image3.jpg"
        ],
        "yaml_name": "your_minio_config",
        "bucket_name": "my-bucket",
        "bucket_directory": "2025/uav-results",
        "uav_model_path": "deeplabv3plus_best.pth"
    }
    ``model_path`` is accepted as an alias for ``uav_model_path``.

    Response JSON: {"code": 200, "msg": "Task started", "task_id": str}.
    Poll GET /yolo/task_status/<task_id> for progress and the result.
    """
    data = _json_body(request) or {}
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path") or data.get("model_path")

    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return json_response({"code": 400, "msg": "Missing parameters"}, status=400)

    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}

    # Launch the job in the background and hold a reference until it finishes.
    task = asyncio.create_task(
        run_image_processing(
            task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path
        )
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return json_response({"code": 200, "msg": "Task started", "task_id": task_id})


@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Return the progress record of a /yolo/process_images job, or 404 if unknown."""
    progress = task_progress.get(task_id)
    if not progress:
        return json_response({"code": 404, "msg": "Task not found"}, status=404)
    return json_response({"code": 200, "msg": "Task status", "data": progress})


async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory,
                               uav_model_path):
    """Run map_process_images_with_progress in a worker thread, tracking task_progress[task_id].

    Ends with status "completed" (result = the function's return value) or "failed"
    (result = the exception message); progress is 100 in both cases.
    """
    record = task_progress[task_id]
    try:
        record["status"] = "running"
        record["progress"] = 10  # initial progress

        # Invoked from the worker thread as each stage (download, inference,
        # upload) reports progress.
        def progress_callback(stage, percent):
            record["status"] = stage
            record["progress"] = percent

        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback,
        )
        record["status"] = "completed"
        record["progress"] = 100
        record["result"] = result
    except Exception as e:
        logger.error(f"Task {task_id} failed: {e}", exc_info=True)
        record["status"] = "failed"
        record["progress"] = 100
        record["result"] = str(e)


@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection on a list of images.

    Request JSON keys: ``image_list`` (required), ``minio`` (required, MinIO
    connection info), ``yolo_model`` (default "best.pt"), ``class`` (optional
    class filter). Responds with whatever ai_image.process_images returns.
    """
    input_folder = output_folder = None
    try:
        detect_data = _json_body(request) or {}

        # Parse the expected fields
        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)

        if not image_list:
            return json_response(
                {"status": "error", "message": "image_list is required"}, status=400
            )
        if not minio_info:
            return json_response(
                {"status": "error", "message": "MinIO information is required"}, status=400
            )

        # Per-request scratch folders (unique names so concurrent requests never collide)
        input_folder = f"./temp_input_{uuid.uuid4()}"
        output_folder = f"./temp_output_{uuid.uuid4()}"

        result = await asyncio.to_thread(
            ai_process_images,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info,
        )
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}",
        }, status=500)
    finally:
        # Nothing in this service reads the scratch folders after the call returns,
        # so remove them to keep them from piling up on disk (no-op if they were
        # never created or were already cleaned up by the processor).
        await asyncio.to_thread(_remove_dirs, input_folder, output_folder)


@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """Automatically train a YOLO model.

    Request JSON:
    {
        "db_host": str,
        "db_database": str,
        "db_user": str,
        "db_password": str,
        "db_port": int,
        "model_id": int,
        "img_path": str,
        "label_path": str,
        "new_path": str,
        "split_list": List[float],
        "class_names": Optional[List[str]],
        "project_name": str
    }

    Response JSON (returned by auto_train):
    {
        "base_metrics": Dict[str, float],
        "best_model_path": str,
        "final_metrics": Dict[str, float]
    }
    """
    try:
        # request.json is a property, not a method
        data = _json_body(request)
        if not data:
            return json_response({"status": "error", "message": "data is required"}, status=400)

        # Training is blocking; run it in a worker thread.
        result = await asyncio.to_thread(auto_train, data)
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}",
        }, status=500)


# ==================================== Video-stream API ====================================

stream_tile_blueprint = Blueprint("stream", url_prefix="/stream_test/")


class StreamTaskManager:
    """In-memory registry of running video-stream tasks.

    At most ``max_tasks`` tasks are tracked (``max_tasks <= 0`` means unlimited).
    Registering one more evicts the oldest task and stops its video session.
    """

    def __init__(self, max_tasks: int = 10):
        self.max_tasks = max_tasks
        # task_id -> info dict given to add_task. Dicts keep insertion order, so
        # the first key is always the oldest task still registered.
        self.active_tasks = {}
        self.task_status = {}      # task_id -> status string
        self.task_timestamps = {}  # task_id -> datetime when the task was registered

    def add_task(self, task_id: str, task_info: dict) -> None:
        """Register *task_id*, evicting the oldest tasks first if the cap is reached."""
        if self.max_tasks > 0:
            while len(self.active_tasks) >= self.max_tasks:
                self._evict_oldest()
        self.active_tasks[task_id] = task_info
        self.task_status[task_id] = "running"
        self.task_timestamps[task_id] = datetime.now()
        logger.info(f"Task {task_id} started")

    def _evict_oldest(self) -> None:
        """Remove the oldest registered task and stop its video session."""
        oldest_task_id = next(iter(self.active_tasks))
        # Read the session id before remove_task deletes the record.
        session_id = self.active_tasks[oldest_task_id].get("session_id")
        self.remove_task(oldest_task_id)
        try:
            stop_video_session(session_id)
        except Exception:
            # api_start registers a task only after its session has started, so a
            # failed stop of the evicted session must not abort that registration.
            logger.exception(
                f"Failed to stop session {session_id} of evicted task {oldest_task_id}"
            )

    def remove_task(self, task_id: str) -> None:
        """Forget *task_id* if registered (does not stop its video session)."""
        if task_id in self.active_tasks:
            del self.active_tasks[task_id]
            self.task_status.pop(task_id, None)
            self.task_timestamps.pop(task_id, None)
            logger.info(f"Task {task_id} removed")

    def get_task_info(self, task_id: str) -> dict:
        """Return the info, status and ISO start time of *task_id*.

        Raises:
            NotFound: the task is not registered.
        """
        if task_id not in self.active_tasks:
            raise NotFound("Task not found")
        return {
            "task_info": self.active_tasks[task_id],
            "status": self.task_status[task_id],
            "start_time": self.task_timestamps[task_id].isoformat(),
        }


task_manager = StreamTaskManager()


@stream_tile_blueprint.post("/start")
async def api_start(request):
    """Start a video-stream session.

    Request JSON:
    {
        "video_path": str,
        "output_url": str,
        "model_path": str,
        "cls": List[int],
        "confidence": float,
        "cls2": Optional[List[int]],
        "push": bool
    }

    Response JSON: {"session_id": str, "task_id": str, "message": "started"}
    """
    data = _json_body(request)
    if data is None:
        return json_response(
            {"status": "error", "message": "JSON body is required"}, status=400
        )

    task_id = str(uuid.uuid4())

    # Start the video-processing session
    session_id = start_video_session(
        video_path=data.get("video_path"),
        output_url=data.get("output_url"),
        model_path=data.get("model_path"),
        cls=data.get("cls"),
        confidence=data.get("confidence", 0.5),
        cls2=data.get("cls2", []),
        push=data.get("push", False),
    )

    # Register with the task manager (may evict and stop the oldest session)
    task_manager.add_task(task_id, {
        "session_id": session_id,
        "video_path": data.get("video_path"),
        "output_url": data.get("output_url"),
        "model_path": data.get("model_path"),
        "class_filter": data.get("cls", []),
        "push": data.get("push", False),
        "start_time": datetime.now().isoformat(),
    })

    return json_response({"session_id": session_id, "task_id": task_id, "message": "started"})


@stream_tile_blueprint.post("/stop")
async def api_stop(request):
    """Stop a session.

    Request JSON:  {"session_id": str}
    Response JSON: {"session_id": str, "message": "stopped"}
    """
    data = _json_body(request) or {}
    session_id = data.get("session_id")
    if session_id is None:
        return json_response(
            {"status": "error", "message": "session_id is required"}, status=400
        )

    stop_video_session(session_id)

    # Drop the matching task record too
    for tid, info in list(task_manager.active_tasks.items()):
        if info.get("session_id") == session_id:
            task_manager.remove_task(tid)
            break

    return json_response({"session_id": session_id, "message": "stopped"})


@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
    """Switch the model used by a session.

    Request JSON:  {"session_id": str, "new_model_path": str}
    Response JSON: {"session_id": str, "new_model_path": str, "message": "model switched"}
    """
    data = _json_body(request) or {}
    session_id = data.get("session_id")
    new_model = data.get("new_model_path")
    if session_id is None or not new_model:
        return json_response(
            {"status": "error", "message": "session_id and new_model_path are required"},
            status=400,
        )

    switch_model_session(session_id, new_model)
    return json_response(
        {"session_id": session_id, "new_model_path": new_model, "message": "model switched"}
    )


@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
    """List all current sessions.

    Response JSON: {"sessions": [{"session_id": str, "status": "running"}, ...]}
    """
    # Snapshot the keys first so a session starting/stopping meanwhile cannot
    # change the mapping mid-iteration.
    sessions = [
        {"session_id": sid, "status": "running"}
        for sid in list(stream_sessions.keys())
    ]
    return json_response({"sessions": sessions})


@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
    """List all registered stream tasks with status, start time and details."""
    tasks = [
        {"task_id": tid, **task_manager.get_task_info(tid)}
        for tid in list(task_manager.active_tasks)
    ]
    return json_response({"tasks": tasks})


# ================================ Blueprint registration ================================

# Registered after all routes are attached so every route is picked up,
# regardless of whether this Sanic version re-registers late-added routes.
app.blueprint(map_tile_blueprint)
app.blueprint(yolo_tile_blueprint)
app.blueprint(stream_tile_blueprint)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=12366, debug=True, workers=1)