506 lines
17 KiB
Python
506 lines
17 KiB
Python
from sanic import Sanic, json, Blueprint,response
|
||
from sanic.exceptions import Unauthorized, NotFound
|
||
from sanic.response import json as json_response
|
||
from sanic_cors import CORS
|
||
from datetime import datetime
|
||
import logging
|
||
import uuid
|
||
import os
|
||
import asyncio
|
||
from minio import Minio
|
||
from ai_image import process_images # 你实现的图片处理函数
|
||
from queue import Queue
|
||
import gdal2tiles as gdal2tiles
|
||
from map_find import map_process_images
|
||
from yolo_train import auto_train
|
||
from map_cut import process_tiling
|
||
from cv_video_counter import start_video_session,switch_model_session,stop_video_session,stream_sessions
|
||
import torch
|
||
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
|
||
|
||
# Logging configuration: one INFO-level root setup shared by the whole service.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
|
||
|
||
###################################################################################验证中间件和管理件##############################################################################################
|
||
|
||
async def token_and_resource_check(request):
    """Authenticate a request and refuse it while any GPU is nearly out of memory.

    Used by the app-wide request middleware, so it runs for every request.

    Args:
        request: the incoming Sanic request.

    Returns:
        ``None`` to let the request continue, or a 503 JSON response when the
        used-memory fraction of some GPU exceeds ``MAX_GPU_USAGE`` (default 0.9).

    Raises:
        Unauthorized: when the ``X-API-Token`` header is missing or does not
            match ``app.config["VALID_TOKEN"]``.
    """
    import hmac  # constant-time comparison for the shared secret

    # --- Token check ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    if (
        not token
        or not expected_token
        # compare_digest avoids leaking how many leading characters matched.
        # str() tolerates config values Sanic may have type-cast from env vars.
        or not hmac.compare_digest(token.encode("utf-8"), str(expected_token).encode("utf-8"))
    ):
        # Do not log the presented token: it may be a near-miss of the real secret.
        logger.warning("Unauthorized request: missing or invalid X-API-Token")
        raise Unauthorized("Invalid token")

    # --- GPU memory check ---
    try:
        if torch.cuda.is_available():
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default 90%

            for i in range(torch.cuda.device_count()):
                # mem_get_info returns device-wide (free, total) bytes, so memory held
                # by other processes counts too. Dividing memory_reserved by
                # max_memory_reserved only compares this process's current reservation
                # with its own peak, which is ~100% right after any allocation.
                free, total = torch.cuda.mem_get_info(i)
                ratio = (total - free) / total if total else 0

                logger.info(f"GPU {i} Usage: {ratio:.2%}")

                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        # Best effort: a failing probe must not take the whole service down.
        logger.error(f"GPU check failed: {e}")

    return None  # allow the request to continue
|
||
|
||
##################################################################################################################################################################################################
|
||
# Create the Sanic application and enable cross-origin requests.
app = Sanic("ai_Service_v2")
CORS(app)

# In-memory progress/result records of /yolo/process_images background tasks,
# keyed by task id (not shared between worker processes).
task_progress = dict()
|
||
|
||
@app.middleware("request")
async def global_middleware(request):
    """Gate every request through the token and GPU-resource check.

    Returning a response from a request middleware short-circuits the handler;
    returning ``None`` lets the request proceed.
    """
    if (blocked := await token_and_resource_check(request)):
        return blocked
|
||
|
||
# API token and maximum GPU memory usage ratio.
# Values already present in app.config take precedence. This lets operators
# supply them through the environment (Sanic loads SANIC_-prefixed variables,
# e.g. SANIC_VALID_TOKEN / SANIC_MAX_GPU_USAGE, into app.config) instead of
# having them overwritten by the literals below.
# NOTE(review): the fallback token is a secret committed to source control;
# rotate it and drop the literal once deployments set SANIC_VALID_TOKEN.
app.config.update({
    "VALID_TOKEN": app.config.get(
        "VALID_TOKEN", "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a"
    ),
    # Reject requests when a GPU's used-memory fraction exceeds this value.
    "MAX_GPU_USAGE": app.config.get("MAX_GPU_USAGE", 0.9),
})
|
||
|
||
######################################################################地图切割相关的API########################################################################################################
|
||
# Blueprint for map tiling / map recognition endpoints, mounted under /map/.
# NOTE(review): the blueprint is registered before its routes are declared below;
# this relies on the installed Sanic version re-registering blueprints lazily.
map_tile_blueprint = Blueprint(name='map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)
|
||
|
||
|
||
@map_tile_blueprint.post("/tile")
async def map_tile_api(request):
    """Cut a GeoTIFF into map tiles.

    Request JSON:
        tif_url (str, required): URL of the GeoTIFF.
        prj_url (str, required): URL of the matching projection file.
        zoom_levels (str, optional): zoom range passed to ``process_tiling``,
            default ``"1-18"``.

    Returns:
        200 ``{"status": "success", "data": <process_tiling result>}``;
        400 when the body is missing, not valid JSON, not an object, or lacks
        a URL; 500 when tiling fails.
    """
    # 1. Parse the body. Sanic raises when the body is not valid JSON; report
    #    that as a client error rather than a 500.
    try:
        tile_data = request.json
    except Exception as parse_error:
        logger.warning(f"Invalid JSON body: {parse_error}")
        return json_response(
            {"status": "error", "message": "Request body must be valid JSON"},
            status=400
        )

    if not tile_data:
        return json_response(
            {"status": "error", "message": "Request body is required"},
            status=400
        )
    if not isinstance(tile_data, dict):
        return json_response(
            {"status": "error", "message": "Request body must be a JSON object"},
            status=400
        )

    # 2. Required fields.
    tif_url = tile_data.get("tif_url")
    prj_url = tile_data.get("prj_url")
    if not tif_url or not prj_url:
        return json_response(
            {"status": "error", "message": "Both tif_url and prj_url are required"},
            status=400
        )

    zoom_levels = tile_data.get("zoom_levels", "1-18")

    # 3. Tiling. process_tiling is awaited, so it must return an awaitable; if it is
    #    ever made synchronous, run it with asyncio.to_thread instead.
    try:
        result = await process_tiling(tif_url, prj_url, zoom_levels)
        return json_response({
            "status": "success",
            "data": result
        })
    except Exception as processing_error:
        logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
        return json_response(
            {"status": "error", "message": f"Processing error: {str(processing_error)}"},
            status=500
        )
|
||
|
||
|
||
# Semantic segmentation
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """Segment UAV images with a checkpoint model and return the processing result.

    Endpoint: POST /map/uav

    Request JSON::

        {
            "urls": ["http://example.com/img1.jpg", "http://example.com/img2.jpg"],
            "yaml_name": "112.44.103.230",
            "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
            "bucket_directory": "2025/seg",
            "model_path": "deeplabv3plus_best.pth"
        }

    ``model_path`` names a file under ``map/checkpoints``; values that resolve
    outside that directory are rejected.

    A successful response body is whatever ``map_process_images`` returns;
    the originally documented example is::

        {
            "code": 200,
            "msg": "success",
            "data": [
                "http://minio.example.com/uav-results/2025/seg/result1.png",
                "http://minio.example.com/uav-results/2025/seg/result2.png"
            ]
        }

    Errors are returned with HTTP 200 and ``{"code": 400|500, "msg": ...}``.
    """
    try:
        body = request.json
        if not isinstance(body, dict):
            return json({"code": 400, "msg": "Request body must be a JSON object"})

        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        model_name = body.get("model_path")

        # Validate parameters.
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory]):
            return json({"code": 400, "msg": "Missing required parameters"})
        if not model_name or not isinstance(model_name, str):
            return json({"code": 400, "msg": "Missing or invalid 'model_path'"})

        # model_path comes from the client: make sure it cannot escape the
        # checkpoint directory (e.g. "../../secret" or an absolute path).
        checkpoint_dir = os.path.join("map", "checkpoints")
        model_path = os.path.join(checkpoint_dir, model_name)
        real_dir = os.path.realpath(checkpoint_dir)
        if os.path.commonpath([real_dir, os.path.realpath(model_path)]) != real_dir:
            return json({"code": 400, "msg": "Invalid 'model_path'"})

        # map_process_images is synchronous (its result is used directly), so run it
        # in a worker thread to keep the event loop responsive during inference.
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json(result)

    except Exception as e:
        logger.error(f"/map/uav processing failed: {e}", exc_info=True)
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
|
||
|
||
######################################################################yolo相关的API########################################################################################################
|
||
# Blueprint for YOLO detection / segmentation / training endpoints, mounted under /yolo/.
yolo_tile_blueprint = Blueprint(name='yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)
|
||
|
||
# YOLO URL API
# Progress and results are stored in the in-memory `task_progress` dict
# (could be persisted in Redis or a database instead).

# Strong references to fire-and-forget background tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
    """Start a background image-processing job and return its task id.

    NOTE: this handler's name rebinds the module-level ``process_images``
    imported from ``ai_image``; code that needs the ai_image function must
    import it under a different name.

    Request JSON::

        {
            "urls": [
                "http://example.com/image1.jpg",
                "http://example.com/image2.jpg",
                "http://example.com/image3.jpg"
            ],
            "yaml_name": "your_minio_config",
            "bucket_name": "my-bucket",
            "bucket_directory": "2025/uav-results",
            "uav_model_path": "deeplabv3plus_best.pth"
        }

    ``urls``, ``yaml_name``, ``bucket_name`` and ``uav_model_path`` are required.

    Returns ``{"code": 200, "msg": "Task started", "task_id": ...}``; poll
    ``task_status/<task_id>`` on this blueprint for progress. Missing parameters
    or a non-object body yield HTTP 400.
    """
    data = request.json
    if not isinstance(data, dict):
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)

    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path")

    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)

    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}

    # Start the background job and keep a reference until it finishes.
    task = asyncio.create_task(
        run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path)
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
|
||
|
||
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Return the progress record of a background task, or 404 if it is unknown."""
    record = task_progress.get(task_id)
    if record:
        return response.json({"code": 200, "msg": "Task status", "data": record})
    return response.json({"code": 404, "msg": "Task not found"}, status=404)
|
||
|
||
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Run one image-processing job in a worker thread, recording progress.

    ``task_progress[task_id]`` must already exist. While running, ``status`` and
    ``progress`` are updated through the callback handed to
    ``map_process_images_with_progress``. On success ``status`` becomes
    ``"completed"`` and ``result`` holds the returned value; on failure
    ``status`` becomes ``"failed"`` and ``result`` holds the error message.
    ``progress`` ends at 100 either way.
    """
    entry = task_progress[task_id]

    try:
        entry["status"] = "running"
        entry["progress"] = 10  # initial progress

        # Presumably invoked by map_process_images_with_progress for each stage
        # (download / inference / upload) from inside the worker thread.
        def progress_callback(stage, percent):
            entry["status"] = stage
            entry["progress"] = percent

        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
        )

    except Exception as e:
        # Keep the traceback in the server log; the client only sees str(e).
        logger.error(f"Task {task_id} failed: {e}", exc_info=True)
        entry["status"] = "failed"
        entry["progress"] = 100
        entry["result"] = str(e)

    else:
        entry["status"] = "completed"
        entry["progress"] = 100
        entry["result"] = result
|
||
|
||
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection on a list of images and return the processing result.

    Request JSON:
        image_list (required): images to process.
        yolo_model (optional): model file, default ``"best.pt"``.
        class (optional): class filter passed through to the processor.
        minio (required): MinIO connection information.

    Returns the processor's result as JSON, 400 for a missing/invalid body or
    missing fields, 500 on processing errors.
    """
    # The module-level name `process_images` is rebound by the
    # /yolo/process_images route handler defined in this file, so calling it here
    # would invoke that Sanic handler and fail with a TypeError. Import the
    # ai_image implementation under an unambiguous alias instead.
    from ai_image import process_images as ai_process_images

    try:
        detect_data = request.json
        if not isinstance(detect_data, dict):
            return json_response({"status": "error", "message": "Request body must be a JSON object"}, status=400)

        # Parse fields.
        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)

        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)

        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)

        # Per-request scratch folders.
        # NOTE(review): nothing here deletes them; confirm ai_image.process_images cleans up.
        input_folder = f"./temp_input_{str(uuid.uuid4())}"
        output_folder = f"./temp_output_{str(uuid.uuid4())}"

        # The processor is synchronous; run it off the event loop.
        result = await asyncio.to_thread(
            ai_process_images,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info
        )

        return json_response(result)

    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
|
||
|
||
# YOLO automatic training
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """Train a model automatically; the request waits until training returns.

    The whole JSON body is forwarded unchanged to ``auto_train``. Documented
    fields::

        {
            "db_host": str,
            "db_database": str,
            "db_user": str,
            "db_password": str,
            "db_port": int,
            "model_id": int,
            "img_path": str,
            "label_path": str,
            "new_path": str,
            "split_list": List[float],
            "class_names": Optional[List[str]],
            "project_name": str
        }

    The response is ``auto_train``'s return value as JSON, documented as::

        {
            "base_metrics": Dict[str, float],
            "best_model_path": str,
            "final_metrics": Dict[str, float]
        }

    An empty body yields 400; any exception yields 500.
    """
    try:
        # request.json is a property, not a method.
        payload = request.json
        if not payload:
            return json_response({"status": "error", "message": "data is required"}, status=400)

        # auto_train runs in a worker thread so it does not block the event loop.
        training_result = await asyncio.to_thread(auto_train, payload)
        return json_response(training_result)

    except Exception as err:
        logger.error(f"Error occurred while processing request: {str(err)}", exc_info=True)
        failure_body = {
            "status": "error",
            "message": f"Internal server error: {str(err)}",
        }
        return json_response(failure_body, status=500)
|
||
|
||
###########################################################################################视频流相关的API#######################################################################################################
|
||
|
||
# Blueprint for video-stream session endpoints, mounted under /stream_test/.
stream_tile_blueprint = Blueprint(name='stream', url_prefix='/stream_test/')
app.blueprint(stream_tile_blueprint)
|
||
|
||
# Stream task manager
class StreamTaskManager:
    """In-memory registry of video-stream tasks.

    At most ``max_tasks`` (default 10) live tasks are tracked. Registering a
    task when that limit is reached stops the oldest live task's video session
    and removes it.
    """

    def __init__(self, max_tasks: int = 10):
        self.active_tasks = {}  # task_id -> task info dict (carries "session_id")
        self.task_status = {}  # task_id -> status string
        self.task_timestamps = {}  # task_id -> datetime the task was registered
        # Task ids in registration order, used to choose the eviction victim.
        # remove_task() cannot delete from the middle of a Queue, so ids of
        # removed tasks linger here until _make_room() discards them.
        self.task_queue = Queue(maxsize=max_tasks)

    def add_task(self, task_id: str, task_info: dict) -> None:
        """Register ``task_id`` as running, evicting the oldest live task if full."""
        if self.task_queue.full():
            self._make_room()
        self.active_tasks[task_id] = task_info
        self.task_status[task_id] = "running"
        self.task_timestamps[task_id] = datetime.now()
        self.task_queue.put(task_id)
        logger.info(f"Task {task_id} started")

    def _make_room(self) -> None:
        """Free at least one slot in ``task_queue``.

        Stale ids (tasks already removed) are dropped first. The oldest live
        task is stopped and removed only if every queued id is still live.
        """
        live_ids = []
        while not self.task_queue.empty():
            queued_id = self.task_queue.get_nowait()
            if queued_id in self.active_tasks:
                live_ids.append(queued_id)

        if len(live_ids) >= self.task_queue.maxsize:
            oldest_task_id = live_ids.pop(0)
            # Read the session id before remove_task() deletes the task info.
            session_id = self.active_tasks[oldest_task_id].get("session_id")
            try:
                stop_video_session(session_id)
            except Exception:
                # Best effort: still drop the entry so the new task can register.
                logger.exception(
                    f"Failed to stop session {session_id} of evicted task {oldest_task_id}"
                )
            self.remove_task(oldest_task_id)

        for queued_id in live_ids:
            self.task_queue.put_nowait(queued_id)

    def remove_task(self, task_id: str) -> None:
        """Forget ``task_id``; unknown ids are ignored. Does not stop its session."""
        if task_id in self.active_tasks:
            del self.active_tasks[task_id]
            del self.task_status[task_id]
            del self.task_timestamps[task_id]
            logger.info(f"Task {task_id} removed")

    def get_task_info(self, task_id: str) -> dict:
        """Return the info, status and ISO start time of ``task_id``.

        Raises:
            NotFound: if the task is not registered.
        """
        if task_id not in self.active_tasks:
            raise NotFound("Task not found")
        return {
            "task_info": self.active_tasks[task_id],
            "status": self.task_status[task_id],
            "start_time": self.task_timestamps[task_id].isoformat()
        }


task_manager = StreamTaskManager()
|
||
|
||
# ---------- API Endpoints ----------
|
||
|
||
@stream_tile_blueprint.post("/start")
async def api_start(request):
    """Start a video-stream session and register it with the task manager.

    Request JSON::

        {
            "video_path": str,
            "output_url": str,
            "model_path": str,
            "cls": List[int],
            "confidence": float,          # default 0.5
            "cls2": Optional[List[int]],  # default []
            "push": bool                  # default False
        }

    Response JSON::

        {"session_id": str, "task_id": str, "message": "started"}

    A missing or non-object body yields HTTP 400.
    """
    data = request.json
    if not isinstance(data, dict):
        return json({"message": "request body must be a JSON object"}, status=400)

    task_id = str(uuid.uuid4())

    # Start the video-processing session. task_id is not passed to it; the two
    # ids are linked only through the task-manager entry registered below.
    session_id = start_video_session(
        video_path = data.get("video_path"),
        output_url = data.get("output_url"),
        model_path = data.get("model_path"),
        cls = data.get("cls"),
        confidence = data.get("confidence", 0.5),
        cls2 = data.get("cls2", []),
        push = data.get("push", False),
    )

    # Register with the task manager (it may evict the oldest task when full).
    task_manager.add_task(task_id, {
        "session_id": session_id,
        "video_path": data.get("video_path"),
        "output_url": data.get("output_url"),
        "model_path": data.get("model_path"),
        "class_filter": data.get("cls", []),
        "push": data.get("push", False),
        "start_time": datetime.now().isoformat()
    })

    return json({"session_id": session_id, "task_id": task_id, "message": "started"})
|
||
|
||
|
||
@stream_tile_blueprint.post("/stop")
async def api_stop(request):
    """Stop a session and drop its task from the task manager.

    Input JSON: ``{"session_id": str}``
    Output JSON: ``{"session_id": str, "message": "stopped"}``

    Returns HTTP 400 when the body is not an object or ``session_id`` is missing.
    """
    data = request.json
    session_id = data.get("session_id") if isinstance(data, dict) else None
    if not session_id:
        return json({"message": "session_id is required"}, status=400)

    stop_video_session(session_id)

    # Remove the task registered for this session, if any.
    matching_task_id = next(
        (tid for tid, info in task_manager.active_tasks.items()
         if info.get("session_id") == session_id),
        None,
    )
    if matching_task_id is not None:
        task_manager.remove_task(matching_task_id)

    return json({"session_id": session_id, "message": "stopped"})
|
||
|
||
|
||
@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
    """Switch the model used by a running session.

    Input JSON: ``{"session_id": str, "new_model_path": str}``
    Output JSON: ``{"session_id": str, "new_model_path": str, "message": "model switched"}``

    Returns HTTP 400 when the body is not an object or a field is missing.
    """
    data = request.json
    if not isinstance(data, dict):
        return json({"message": "request body must be a JSON object"}, status=400)

    session_id = data.get("session_id")
    new_model = data.get("new_model_path")
    if not session_id or not new_model:
        return json({"message": "session_id and new_model_path are required"}, status=400)

    switch_model_session(session_id, new_model)
    return json({"session_id": session_id, "new_model_path": new_model, "message": "model switched"})
|
||
|
||
|
||
@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
    """List every session currently held in ``stream_sessions``.

    Output JSON: ``{"sessions": [{"session_id": str, "status": "running"}, ...]}``
    Every listed session is reported as ``"running"``.
    """
    session_ids = stream_sessions.keys()
    payload = {
        "sessions": [dict(session_id=sid, status="running") for sid in session_ids]
    }
    return json(payload)
|
||
|
||
|
||
# Unified task query endpoint (video-stream tasks)
@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
    """List every registered task with its info, status and start time."""
    task_entries = [
        {"task_id": tid, **task_manager.get_task_info(tid)}
        for tid in task_manager.active_tasks
    ]
    return json({"tasks": task_entries})
|
||
|
||
|
||
##################################################################################################################################################################################################
|
||
if __name__ == '__main__':
    # Keep a single worker: task_progress and task_manager live in process
    # memory and would not be shared between workers.
    # NOTE(review): debug=True on 0.0.0.0 should not be used in production.
    app.run(
        host="0.0.0.0",
        port=12366,
        workers=1,
        debug=True,
    )
|