# ai_project_v1/ai_tottle_api.py

from sanic import Sanic, json, Blueprint, response
from sanic.exceptions import Unauthorized, NotFound
from sanic.response import json as json_response
from sanic_cors import CORS
from datetime import datetime
import logging
import uuid
import os
import asyncio
from minio import Minio
from ai_image import process_images  # image-processing function implemented elsewhere
from queue import Queue
import gdal2tiles
from map_find import map_process_images
from yolo_train import auto_train
from map_cut import process_tiling
from cv_video_counter import start_video_session, switch_model_session, stop_video_session, stream_sessions
import torch
from yolo_photo import map_process_images_with_progress  # processing function with progress reporting
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
############################################################################ Authentication middleware and resource management ############################################################################
async def token_and_resource_check(request):
    # --- Token validation ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    if not token or token != expected_token:
        logger.warning(f"Unauthorized request with token: {token}")
        raise Unauthorized("Invalid token")
    # --- GPU usage check ---
    try:
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default: 90%
            for i in range(num_gpus):
                used = torch.cuda.memory_reserved(i)
                total = torch.cuda.get_device_properties(i).total_memory  # total device memory
                ratio = used / total if total else 0
                logger.info(f"GPU {i} Usage: {ratio:.2%}")
                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        logger.error(f"GPU check failed: {e}")
    return None  # allow the request to continue
##################################################################################################################################################################################################
# Create the Sanic application
app = Sanic("ai_Service_v2")
CORS(app)  # allow cross-origin requests
task_progress = {}


@app.middleware("request")
async def global_middleware(request):
    result = await token_and_resource_check(request)
    if result:
        return result


# Configure the API token and the maximum allowed GPU usage ratio
app.config.update({
    "VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
    "MAX_GPU_USAGE": 0.9
})
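

# A minimal client-side sketch (never called by the service) showing how a caller
# attaches the X-API-Token header checked by global_middleware above. It assumes
# the service is reachable at localhost on the port used by app.run() at the
# bottom of this file, and uses only the standard library.
def _example_authorized_request(path="/stream_test/sessions"):
    import json as _json
    import urllib.request
    req = urllib.request.Request(
        f"http://127.0.0.1:12366{path}",
        headers={"X-API-Token": app.config["VALID_TOKEN"]},
    )
    with urllib.request.urlopen(req) as resp:
        return _json.loads(resp.read().decode("utf-8"))
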
############################################################################ Map tiling APIs ############################################################################
# Blueprint for map-related endpoints
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)


@map_tile_blueprint.post("/tile")
async def map_tile_api(request):
    try:
        # 1. Check the request body
        if not request.json:
            return json_response(
                {"status": "error", "message": "Request body is required"},
                status=400
            )
        # 2. Parse the required fields
        tile_data = request.json
        tif_url = tile_data.get("tif_url")
        prj_url = tile_data.get("prj_url")
        if not tif_url or not prj_url:
            return json_response(
                {"status": "error", "message": "Both tif_url and prj_url are required"},
                status=400
            )
        # 3. Run the business logic (await the coroutine directly; do not use asyncio.run)
        zoom_levels = tile_data.get("zoom_levels", "1-18")
        try:
            # process_tiling is assumed to be a coroutine (async def)
            result = await process_tiling(tif_url, prj_url, zoom_levels)
            # If process_tiling were a plain function, wrap it with asyncio.to_thread:
            # result = await asyncio.to_thread(process_tiling, tif_url, prj_url, zoom_levels)
            return json_response({
                "status": "success",
                "data": result
            })
        except Exception as processing_error:
            logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
            return json_response(
                {"status": "error", "message": f"Processing error: {str(processing_error)}"},
                status=500
            )
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response(
            {"status": "error", "message": str(e)},  # return the message as a plain string, not a set
            status=500
        )
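

# A hedged, illustrative request body for /map/tile above; the URLs are
# placeholders, and zoom_levels mirrors the default the handler falls back to.
def _example_map_tile_payload():
    return {
        "tif_url": "http://example.com/scene.tif",  # placeholder GeoTIFF URL
        "prj_url": "http://example.com/scene.prj",  # placeholder projection file URL
        "zoom_levels": "1-18",                      # optional; same as the handler default
    }
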
# Semantic segmentation
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """
    Endpoint: /map/uav
    Input JSON:
    {
        "urls": [
            "http://example.com/img1.jpg",
            "http://example.com/img2.jpg"
        ],
        "yaml_name": "112.44.103.230",
        "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
        "bucket_directory": "2025/seg",
        "model_path": "deeplabv3plus_best.pth"
    }
    Output JSON:
    {
        "code": 200,
        "msg": "success",
        "data": [
            "http://minio.example.com/uav-results/2025/seg/result1.png",
            "http://minio.example.com/uav-results/2025/seg/result2.png"
        ]
    }
    """
    try:
        body = request.json
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        model_name = body.get("model_path")
        # Validate parameters before building the model path
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory, model_name]):
            return json({"code": 400, "msg": "Missing required parameters"})
        model_path = os.path.join("map", "checkpoints", model_name)
        # Run the (blocking) image-processing function off the event loop
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json(result)
    except Exception as e:
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
############################################################################ YOLO APIs ############################################################################
# Blueprint for YOLO-related endpoints
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)
# YOLO URL API
# Task progress and results live in the in-memory task_progress dict defined above
# (swap in Redis or a database for persistence).
# The handler is named process_images_api so it does not shadow the process_images
# imported from ai_image, which the /yolo/picture handler below relies on.
@yolo_tile_blueprint.post("/process_images")
async def process_images_api(request):
    """
    {
        "urls": [
            "http://example.com/image1.jpg",
            "http://example.com/image2.jpg",
            "http://example.com/image3.jpg"
        ],
        "yaml_name": "your_minio_config",
        "bucket_name": "my-bucket",
        "bucket_directory": "2025/uav-results",
        "uav_model_path": "deeplabv3plus_best.pth"
    }
    """
    data = request.json
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path")
    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
    # Kick off the background task
    asyncio.create_task(run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path))
    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    progress = task_progress.get(task_id)
    if not progress:
        return response.json({"code": 404, "msg": "Task not found"}, status=404)
    return response.json({"code": 200, "msg": "Task status", "data": progress})


async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    try:
        task_progress[task_id]["status"] = "running"
        task_progress[task_id]["progress"] = 10  # initial progress
        # The callback updates progress for the download, inference and upload stages
        def progress_callback(stage, percent):
            task_progress[task_id]["status"] = stage
            task_progress[task_id]["progress"] = percent
        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
        )
        task_progress[task_id]["status"] = "completed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = result
    except Exception as e:
        task_progress[task_id]["status"] = "failed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = str(e)
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    try:
        detect_data = request.json
        # Parse the required fields
        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)
        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)
        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
        # Create per-request temporary folders
        input_folder = f"./temp_input_{str(uuid.uuid4())}"
        output_folder = f"./temp_output_{str(uuid.uuid4())}"
        # Run the image processing off the event loop
        result = await asyncio.to_thread(
            process_images,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info
        )
        # Return the processing result
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
# YOLO automatic training
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """
    Automatically train a model.
    Input JSON:
    {
        "db_host": str,
        "db_database": str,
        "db_user": str,
        "db_password": str,
        "db_port": int,
        "model_id": int,
        "img_path": str,
        "label_path": str,
        "new_path": str,
        "split_list": List[float],
        "class_names": Optional[List[str]],
        "project_name": str
    }
    Output JSON:
    {
        "base_metrics": Dict[str, float],
        "best_model_path": str,
        "final_metrics": Dict[str, float]
    }
    """
    try:
        # Access request.json as a property rather than calling it
        data = request.json
        if not data:
            return json_response({"status": "error", "message": "data is required"}, status=400)
        # Run the training off the event loop
        result = await asyncio.to_thread(
            auto_train,
            data
        )
        # Return the training result
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
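

# A hedged, illustrative payload for /yolo/train above; every value is a
# placeholder chosen only to match the types listed in the docstring, not real
# credentials, paths, or class names.
def _example_train_payload():
    return {
        "db_host": "127.0.0.1",
        "db_database": "yolo_db",
        "db_user": "user",
        "db_password": "password",
        "db_port": 5432,
        "model_id": 1,
        "img_path": "/data/images",
        "label_path": "/data/labels",
        "new_path": "/data/dataset",
        "split_list": [0.8, 0.1, 0.1],
        "class_names": ["class_a", "class_b"],
        "project_name": "demo_train",
    }
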
############################################################################ Video-stream APIs ############################################################################
# Blueprint for video-stream endpoints
stream_tile_blueprint = Blueprint('stream', url_prefix='/stream_test/')
app.blueprint(stream_tile_blueprint)


# Task manager: keeps at most 10 concurrent stream tasks and evicts the oldest one
class StreamTaskManager:
    def __init__(self):
        self.active_tasks = {}
        self.task_status = {}
        self.task_timestamps = {}
        self.task_queue = Queue(maxsize=10)

    def add_task(self, task_id: str, task_info: dict) -> None:
        if self.task_queue.full():
            oldest_task_id = self.task_queue.get()
            # Stop the oldest session before dropping its bookkeeping; after
            # remove_task() the session_id would no longer be reachable.
            if oldest_task_id in self.active_tasks:
                stop_video_session(self.active_tasks[oldest_task_id]["session_id"])
            self.remove_task(oldest_task_id)
        self.active_tasks[task_id] = task_info
        self.task_status[task_id] = "running"
        self.task_timestamps[task_id] = datetime.now()
        self.task_queue.put(task_id)
        logger.info(f"Task {task_id} started")

    def remove_task(self, task_id: str) -> None:
        if task_id in self.active_tasks:
            del self.active_tasks[task_id]
            del self.task_status[task_id]
            del self.task_timestamps[task_id]
            logger.info(f"Task {task_id} removed")

    def get_task_info(self, task_id: str) -> dict:
        if task_id not in self.active_tasks:
            raise NotFound("Task not found")
        return {
            "task_info": self.active_tasks[task_id],
            "status": self.task_status[task_id],
            "start_time": self.task_timestamps[task_id].isoformat()
        }


task_manager = StreamTaskManager()
# ---------- API Endpoints ----------
@stream_tile_blueprint.post("/start")
async def api_start(request):
    """
    Start a video-stream session.
    Input JSON:
    {
        "video_path": str,
        "output_url": str,
        "model_path": str,
        "cls": List[int],
        "confidence": float,
        "cls2": Optional[List[int]],
        "push": bool
    }
    Output JSON:
    {
        "session_id": str,
        "task_id": str,
        "message": "started"
    }
    """
    data = request.json
    task_id = str(uuid.uuid4())
    # Start the video-processing session
    session_id = start_video_session(
        video_path=data.get("video_path"),
        output_url=data.get("output_url"),
        model_path=data.get("model_path"),
        cls=data.get("cls"),
        confidence=data.get("confidence", 0.5),
        cls2=data.get("cls2", []),
        push=data.get("push", False),
    )
    # Register the task with the task manager
    task_manager.add_task(task_id, {
        "session_id": session_id,
        "video_path": data.get("video_path"),
        "output_url": data.get("output_url"),
        "model_path": data.get("model_path"),
        "class_filter": data.get("cls", []),
        "push": data.get("push", False),
        "start_time": datetime.now().isoformat()
    })
    return json({"session_id": session_id, "task_id": task_id, "message": "started"})
@stream_tile_blueprint.post("/stop")
async def api_stop(request):
    """
    Stop the given session.
    Input JSON: { "session_id": str }
    Output JSON: { "session_id": str, "message": "stopped" }
    """
    session_id = request.json.get("session_id")
    stop_video_session(session_id)
    # Remove the matching task from the task manager as well
    for tid, info in list(task_manager.active_tasks.items()):
        if info.get("session_id") == session_id:
            task_manager.remove_task(tid)
            break
    return json({"session_id": session_id, "message": "stopped"})


@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
    """
    Switch the model used by a session.
    Input JSON: { "session_id": str, "new_model_path": str }
    Output JSON: { "session_id": str, "new_model_path": str, "message": "model switched" }
    """
    data = request.json
    session_id = data.get("session_id")
    new_model = data.get("new_model_path")
    switch_model_session(session_id, new_model)
    return json({"session_id": session_id, "new_model_path": new_model, "message": "model switched"})


@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
    """
    List all current sessions.
    Output JSON: { "sessions": [{"session_id": str, "status": "running"}, ...] }
    """
    sessions = [
        {"session_id": sid, "status": "running"}
        for sid in stream_sessions.keys()
    ]
    return json({"sessions": sessions})


# Unified task query endpoint (includes video streams)
@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
    """
    List all tasks (status, start time, details).
    """
    tasks = []
    for tid in task_manager.active_tasks:
        info = task_manager.get_task_info(tid)
        tasks.append({"task_id": tid, **info})
    return json({"tasks": tasks})
##################################################################################################################################################################################################
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=12366, debug=True, workers=1)