# AI_python_yoooger/Ai_tottle/ai_tottle_api.py
# 2025-08-01 14:50:49 +08:00
#
# 296 lines
# 10 KiB
# Python

import asyncio
import hmac
import logging
import os
import traceback
import uuid
from queue import Queue

import numpy as np
import torch
from sanic import Sanic, json, Blueprint, response
from sanic.exceptions import Unauthorized
from sanic.response import json as json_response
from sanic_cors import CORS

from ai_image import process_images  # project image-processing function (name is rebound below by a route handler)
from ai_image import process_images as ai_process_images  # unshadowed alias of the ai_image implementation
from map_find import map_process_images
from yolo_train import auto_train
from yolo_photo import map_process_images_with_progress  # processing function that reports progress via a callback
# Logging configuration
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Authentication and resource-guard middleware
# ---------------------------------------------------------------------------
async def token_and_resource_check(request):
    """Validate the API token and refuse work when a GPU is nearly full.

    Called for every request by ``global_middleware``.

    Args:
        request: Incoming Sanic request. The token is read from the
            ``X-API-Token`` header and compared with
            ``app.config["VALID_TOKEN"]``.

    Returns:
        ``None`` to let the request proceed, or a 503 JSON response when the
        memory reserved by this process's PyTorch caching allocator on any
        visible GPU exceeds ``app.config["MAX_GPU_USAGE"]`` (a fraction of
        the device's total memory, default 0.9). Memory used by other
        processes on the same GPU is not counted.

    Raises:
        Unauthorized: when the token header is missing or does not match.
    """
    # --- Token check ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    # hmac.compare_digest runs in constant time, so response timing does not
    # reveal how much of a guessed token was correct. Bytes are compared
    # because compare_digest rejects non-ASCII str arguments.
    if (
        not token
        or not expected_token
        or not hmac.compare_digest(
            token.encode("utf-8"), str(expected_token).encode("utf-8")
        )
    ):
        # Never log the presented token itself: it may be a real (mistyped
        # or leaked) credential.
        logger.warning(
            "Unauthorized request: %s X-API-Token", "invalid" if token else "missing"
        )
        raise Unauthorized("Invalid token")

    # --- GPU memory check ---
    try:
        if torch.cuda.is_available():
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)
            for i in range(torch.cuda.device_count()):
                # Compare current reservation with the device's physical
                # capacity. max_memory_reserved() is this process's *peak*
                # reservation; dividing by it gives ~100% whenever the
                # caching allocator sits at its peak, which would turn every
                # later request into a 503.
                used = torch.cuda.memory_reserved(i)
                total = torch.cuda.get_device_properties(i).total_memory
                ratio = used / total if total else 0
                logger.info(f"GPU {i} Usage: {ratio:.2%}")
                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        # Best effort: a failing GPU probe must not block the API.
        logger.error(f"GPU check failed: {e}")
    return None  # let the request continue
# ---------------------------------------------------------------------------
# Application setup
# ---------------------------------------------------------------------------
# In-memory registry of /yolo/process_images background jobs, keyed by
# task_id (a uuid4 string). It is a plain module-level dict, so it is not
# shared between worker processes and is lost on restart.
task_progress = dict()

app = Sanic("ai_Service_v2")
CORS(app)  # allow cross-origin requests
@app.middleware("request")
async def global_middleware(request):
    """Run token/GPU checks before every request.

    Returning a response short-circuits the request; returning ``None`` lets
    it reach the route handler.
    """
    return (await token_and_resource_check(request)) or None
# Default API token and maximum GPU memory usage ratio. These are applied only
# when the key is not already in app.config, so values Sanic loaded from the
# environment (by default, variables prefixed SANIC_, e.g. SANIC_VALID_TOKEN /
# SANIC_MAX_GPU_USAGE) are no longer overwritten.
# NOTE(review): a real token is committed here as the fallback; move it to the
# environment or a secret store and rotate it.
_CONFIG_DEFAULTS = {
    "VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
    "MAX_GPU_USAGE": 0.9,
}
app.config.update(
    {key: value for key, value in _CONFIG_DEFAULTS.items() if key not in app.config}
)
# ---------------------------------------------------------------------------
# Map tiling / segmentation API
# ---------------------------------------------------------------------------
# Blueprint that groups the /map routes.
map_tile_blueprint = Blueprint(name="map", url_prefix="/map/")
# NOTE(review): the blueprint is registered before its routes are attached
# further down; this relies on Sanic applying routes added after
# app.blueprint() -- confirm for the Sanic version in use.
app.blueprint(map_tile_blueprint)
# Semantic segmentation
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """Run semantic segmentation on a batch of UAV images.

    Endpoint: ``POST /map/uav``

    Request JSON::

        {
            "urls": [
                "http://example.com/img1.jpg",
                "http://example.com/img2.jpg"
            ],
            "yaml_name": "config",
            "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
            "bucket_directory": "2025/seg",
            "model_path": "deeplabv3plus_best.pth"
        }

    ``model_path`` is resolved relative to ``map/checkpoints`` and must not
    point outside that directory.

    Response JSON: whatever ``map_process_images`` returns, e.g.::

        {
            "code": 200,
            "msg": "success",
            "data": [
                "http://minio.example.com/uav-results/2025/seg/result1.png",
                "http://minio.example.com/uav-results/2025/seg/result2.png"
            ]
        }

    Validation and server errors are reported in the body as
    ``{"code": 400 | 500, "msg": ...}`` (HTTP status stays 200).
    """
    try:
        body = request.json
        if not isinstance(body, dict):
            return json({"code": 400, "msg": "Request body must be a JSON object"})
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        model_name = body.get("model_path")
        # Validate parameters
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory, model_name]):
            return json({"code": 400, "msg": "Missing required parameters"})
        if not isinstance(model_name, str):
            return json({"code": 400, "msg": "Invalid 'model_path'"})
        checkpoint_dir = os.path.join("map", "checkpoints")
        model_path = os.path.join(checkpoint_dir, model_name)
        # model_path comes from the client: os.path.join drops the prefix for
        # absolute paths and ".." can climb out, so confine it to the
        # checkpoint directory.
        real_dir = os.path.realpath(checkpoint_dir)
        if os.path.commonpath([real_dir, os.path.realpath(model_path)]) != real_dir:
            return json({"code": 400, "msg": "Invalid 'model_path'"})
        # map_process_images is synchronous (presumably download, inference
        # and upload); run it in a worker thread so the event loop keeps
        # serving other requests.
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json(result)
    except Exception as e:
        logger.exception("POST /map/uav failed")
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
# ---------------------------------------------------------------------------
# YOLO-related API
# ---------------------------------------------------------------------------
# Blueprint that groups the /yolo routes.
yolo_tile_blueprint = Blueprint(name="yolo", url_prefix="/yolo/")
app.blueprint(yolo_tile_blueprint)
# YOLO URL API (runs as a background task).
# Task progress and results live in the in-memory ``task_progress`` dict;
# persist them in Redis or a database if they must survive restarts.

# Strong references to running background tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_background_tasks = set()


# NOTE: this handler's name rebinds the module-level ``process_images`` that
# was imported from ai_image.
@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
    """Start a background image-processing job and return its task id.

    Request JSON::

        {
            "urls": [
                "http://example.com/image1.jpg",
                "http://example.com/image2.jpg",
                "http://example.com/image3.jpg"
            ],
            "yaml_name": "your_minio_config",
            "bucket_name": "my-bucket",
            "bucket_directory": "2025/uav-results",
            "uav_model_path": "deeplabv3plus_best.pth"
        }

    ``urls``, ``yaml_name``, ``bucket_name`` and ``uav_model_path`` are
    required; ``bucket_directory`` is not validated. ``uav_model_path`` is
    passed as-is to ``map_process_images_with_progress``.

    Returns ``{"code": 200, "msg": "Task started", "task_id": ...}``; poll
    ``GET /yolo/task_status/<task_id>`` for progress. Responds 400 when
    parameters are missing or the body is not a JSON object.
    """
    data = request.json
    if not isinstance(data, dict):
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path")
    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
    # Start the background job and keep a reference until it completes.
    task = asyncio.create_task(
        run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path)
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Report the progress record of a background task.

    Responds with ``{"code": 200, "msg": "Task status", "data": record}``,
    where ``record`` holds the task's ``status``, ``progress`` and ``result``,
    or with 404 when ``task_id`` is unknown.
    """
    record = task_progress.get(task_id)
    if record:
        return response.json({"code": 200, "msg": "Task status", "data": record})
    return response.json({"code": 404, "msg": "Task not found"}, status=404)
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Execute one /yolo/process_images job and record its outcome.

    Mutates ``task_progress[task_id]`` in place:

    - while running: ``status`` is "running" or a stage name, and ``progress``
      is the percent reported by ``map_process_images_with_progress``;
    - on success: ``status="completed"``, ``progress=100``, and ``result`` set
      to the function's return value;
    - on failure: ``status="failed"``, ``progress=100``, and ``result`` set to
      the error message. The traceback is logged and the exception is not
      re-raised.
    """
    record = task_progress[task_id]

    def progress_callback(stage, percent):
        # Presumably invoked from the worker thread that runs
        # map_process_images_with_progress (download/inference/upload stages).
        record["status"] = stage
        record["progress"] = percent

    try:
        record["status"] = "running"
        record["progress"] = 10  # initial progress
        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback,
        )
    except Exception as e:
        # Without this, the only trace of a failure is str(e), which can be
        # empty or uninformative.
        logger.exception("Background task %s failed", task_id)
        record["status"] = "failed"
        record["progress"] = 100
        record["result"] = str(e)
    else:
        record["status"] = "completed"
        record["progress"] = 100
        record["result"] = result
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection on a list of images and return the result.

    Request JSON fields:
        image_list (required): images to process; passed through unchanged.
        yolo_model (optional): model file name, default ``"best.pt"``.
        class (optional): class filter, default ``None``.
        minio (required): MinIO connection info; passed through unchanged.

    Returns:
        JSON of whatever ``ai_image.process_images`` returns. Responds 400
        when ``image_list`` or ``minio`` is missing (including an empty or
        non-object body), and 500 on any other error.
    """
    try:
        detect_data = request.json
        # An empty or non-object body falls through to the 400 checks below
        # instead of crashing on .get().
        if not isinstance(detect_data, dict):
            detect_data = {}
        # Required and optional fields
        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)
        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)
        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
        # Per-request scratch folders with unique names so concurrent
        # requests do not collide.
        # NOTE(review): nothing here removes them; confirm that
        # ai_image.process_images cleans them up.
        input_folder = f"./temp_input_{uuid.uuid4()}"
        output_folder = f"./temp_output_{uuid.uuid4()}"
        # Call the ai_image implementation through its alias. In this module
        # the bare name ``process_images`` is rebound to the
        # /yolo/process_images route handler, whose only parameter is
        # ``request``, so it cannot accept these keyword arguments.
        result = await asyncio.to_thread(
            ai_process_images,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info,
        )
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
# YOLO automatic training
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """Train a model automatically and return its metrics.

    The whole JSON body is passed unchanged to ``auto_train``. Expected
    fields::

        {
            "db_host": str,
            "db_database": str,
            "db_user": str,
            "db_password": str,
            "db_port": int,
            "model_id": int,
            "img_path": str,
            "label_path": str,
            "new_path": str,
            "split_list": List[float],
            "class_names": Optional[List[str]],
            "project_name": str
        }

    Response: whatever ``auto_train`` returns, documented as::

        {
            "base_metrics": Dict[str, float],
            "best_model_path": str,
            "final_metrics": Dict[str, float]
        }

    Responds 400 for an empty body and 500 on any error.

    NOTE(review): the HTTP request stays open for the whole training run;
    long runs may exceed the server's response timeout. Consider the
    background-task pattern used by /yolo/process_images.
    """
    try:
        payload = request.json
        if payload:
            # auto_train is synchronous; run it off the event loop.
            metrics = await asyncio.to_thread(auto_train, payload)
            return json_response(metrics)
        return json_response({"status": "error", "message": "data is required"}, status=400)
    except Exception as exc:
        logger.error(f"Error occurred while processing request: {str(exc)}", exc_info=True)
        error_body = {
            "status": "error",
            "message": f"Internal server error: {str(exc)}",
        }
        return json_response(error_body, status=500)
if __name__ == '__main__':
    # workers=1: task_progress is a plain in-process dict, so separate worker
    # processes would not see each other's tasks (presumably why a single
    # worker is used).
    # NOTE(review): debug=True while listening on 0.0.0.0 exposes debug
    # output to the network.
    server_options = dict(host="0.0.0.0", port=12366, debug=True, workers=1)
    app.run(**server_options)