AI_python_yoooger/Ai_tottle/ai_tottle_api.py
2025-09-02 11:08:02 +08:00

321 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import hmac
import logging
import os
import traceback
import uuid
from queue import Queue

import numpy as np
import torch
from sanic import Sanic, json, Blueprint, response
from sanic.exceptions import Unauthorized
from sanic.response import json as json_response
from sanic_cors import CORS

from ai_image import process_images  # project-provided image-processing function
from map_find import map_process_images
from yolo_train import auto_train, query_progress
from yolo_photo import map_process_images_with_progress  # project-provided processing function with progress reporting
# Logging configuration: timestamped, INFO-level records for the whole process.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module logger shared by every handler below
# ============================== Authentication & GPU resource check ==============================
async def token_and_resource_check(request):
    """Authenticate a request and refuse it while a GPU is nearly full.

    Called for every request by the app-level request middleware.

    Args:
        request: Sanic request. Reads the ``X-API-Token`` header, ``request.path``
            (for logging only) and the app config keys ``VALID_TOKEN`` and
            ``MAX_GPU_USAGE`` (a fraction, default 0.9).

    Returns:
        A 503 JSON response when some visible GPU's reserved-memory ratio exceeds
        ``MAX_GPU_USAGE``; otherwise ``None`` so the request continues.

    Raises:
        Unauthorized: if the token header is missing, no token is configured,
            or the token does not match.
    """
    # --- Token check ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    # compare_digest runs in time independent of where the inputs differ, so
    # response timing does not leak how much of a guessed token was correct.
    # Comparing UTF-8 bytes avoids the TypeError it raises for non-ASCII str.
    token_ok = bool(token) and bool(expected_token) and hmac.compare_digest(
        token.encode("utf-8"), str(expected_token).encode("utf-8")
    )
    if not token_ok:
        # Do not log the submitted value itself: a near-miss can reveal most of the secret.
        logger.warning(
            "Unauthorized request to %s (token %s)",
            request.path,
            "present" if token else "missing",
        )
        raise Unauthorized("Invalid token")

    # --- GPU memory check (best effort: failures are logged and the request proceeds) ---
    try:
        if torch.cuda.is_available():
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)
            for i in range(torch.cuda.device_count()):
                # Memory currently reserved by this process's caching allocator,
                # measured against the device's physical capacity.
                # max_memory_reserved() is this process's *peak* reservation, not
                # the device size, so it must not be used as the denominator
                # (the ratio would read ~100% whenever usage equals the peak).
                # Memory held by other processes is not counted here.
                used = torch.cuda.memory_reserved(i)
                total = torch.cuda.get_device_properties(i).total_memory
                ratio = used / total if total else 0
                logger.info(f"GPU {i} Usage: {ratio:.2%}")
                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        logger.error(f"GPU check failed: {e}")
    return None  # allow the request to continue
# ==================================================================================================
# Build the Sanic application object.
app = Sanic("ai_Service_v2")
# Enable cross-origin requests on every route (sanic_cors default settings).
CORS(app)
# In-memory registry of background-task state, keyed by task id.
# It lives only in this process (lost on restart, not shared between worker
# processes), and nothing in this file removes entries from it.
task_progress: dict = {}
@app.middleware("request")
async def global_middleware(request):
    """Run the token and GPU checks before every request.

    If the check produces a response (e.g. a 503), return it to short-circuit
    the request; otherwise fall through so the route handler runs.
    NOTE(review): this also runs for CORS preflight (OPTIONS) requests, which
    browsers send without custom headers; confirm sanic_cors answers them
    before this middleware, otherwise they are rejected as unauthorized.
    """
    early_response = await token_and_resource_check(request)
    if early_response:
        return early_response
# Auth token and maximum GPU memory usage ratio. Both can be supplied through
# the environment (AI_TOTTLE_API_TOKEN, AI_TOTTLE_MAX_GPU_USAGE) so the real
# secret does not have to be committed; the literals below are used only when
# the variables are unset.
# TODO: remove the hard-coded fallback token once deployments set the variable.
app.config.update({
    "VALID_TOKEN": os.environ.get(
        "AI_TOTTLE_API_TOKEN", "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a"
    ),
    "MAX_GPU_USAGE": float(os.environ.get("AI_TOTTLE_MAX_GPU_USAGE", "0.9")),
})
# ====================================== Map segmentation API ======================================
# Blueprint grouping the map routes under the /map/ prefix.
# NOTE(review): routes are attached to this blueprint after it is registered
# with the app; confirm the installed Sanic version picks up late-added
# blueprint routes (or move the registration below the route definitions).
map_tile_blueprint = Blueprint(name='map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)
def _resolve_checkpoint_path(base_dir, name):
    """Join *name* onto *base_dir*, refusing paths that escape *base_dir*.

    Returns ``os.path.join(base_dir, name)``, or ``None`` when *name* is not a
    non-empty string, is absolute, or uses ``..`` to leave *base_dir*.
    """
    if not isinstance(name, str) or not name:
        return None
    candidate = os.path.join(base_dir, name)
    base_abs = os.path.abspath(base_dir)
    candidate_abs = os.path.abspath(candidate)
    try:
        inside = os.path.commonpath([base_abs, candidate_abs]) == base_abs
    except ValueError:  # e.g. paths on different drives (Windows)
        return None
    if not inside or candidate_abs == base_abs:
        return None
    return candidate


# Semantic segmentation
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """
    Endpoint: POST /map/uav - semantic segmentation of a batch of images.

    Input JSON:
    {
        "urls": [
            "http://example.com/img1.jpg",
            "http://example.com/img2.jpg"
        ],
        "yaml_name": "config",
        "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
        "bucket_directory": "2025/seg",
        "model_path": "deeplabv3plus_best.pth"
    }
    ``model_path`` is resolved relative to ``map/checkpoints`` and must stay
    inside that directory.

    Output JSON: the value returned by ``map_process_images``, e.g.
    {
        "code": 200,
        "msg": "success",
        "data": [
            "http://minio.example.com/uav-results/2025/seg/result1.png",
            "http://minio.example.com/uav-results/2025/seg/result2.png"
        ]
    }
    Validation problems return {"code": 400, "msg": ...} and unexpected
    failures {"code": 500, "msg": ...}; the HTTP status stays 200 in both cases.
    """
    try:
        body = request.json
        if not isinstance(body, dict):
            return json({"code": 400, "msg": "Request body must be a JSON object"})
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        requested_model = body.get("model_path")
        # Validate parameters
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory, requested_model]):
            return json({"code": 400, "msg": "Missing required parameters"})
        # model_path comes from the client: only allow files under map/checkpoints.
        model_path = _resolve_checkpoint_path(os.path.join("map", "checkpoints"), requested_model)
        if model_path is None:
            return json({"code": 400, "msg": "Invalid 'model_path'"})
        # Image processing is blocking work; run it in a thread so the event
        # loop (and the other endpoints) stay responsive meanwhile.
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json(result)
    except Exception as e:
        logger.exception("Error while processing /map/uav request")
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
# ============================================ YOLO API ============================================
# Blueprint grouping the YOLO routes under the /yolo/ prefix.
yolo_tile_blueprint = Blueprint(name='yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)
# YOLO URL API.
# Task progress and results are kept in memory (task_progress); Redis or a
# database could be used to persist them instead.
# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a task nobody references can be garbage
# collected before it finishes; each task removes itself when done.
_background_tasks = set()


@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
    """Start an asynchronous UAV image-processing task and return its id.

    NOTE: defining this handler rebinds the module-level name
    ``process_images`` that was imported from ``ai_image``; code in this module
    that needs the ai_image function must import it under another name.

    Input JSON:
    {
        "urls": [
            "http://example.com/image1.jpg",
            "http://example.com/image2.jpg",
            "http://example.com/image3.jpg"
        ],
        "yaml_name": "your_minio_config",
        "bucket_name": "my-bucket",
        "bucket_directory": "2025/uav-results",
        "uav_model_path": "deeplabv3plus_best.pth"
    }
    ``model_path`` is accepted as an alias for ``uav_model_path``.

    Returns {"code": 200, "msg": "Task started", "task_id": ...}; poll
    GET /yolo/task_status/<task_id> for progress. Missing parameters yield
    HTTP 400.
    """
    data = request.json
    if not isinstance(data, dict):
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    # Older clients (and the documented example) send "model_path".
    uav_model_path = data.get("uav_model_path") or data.get("model_path")
    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
    # Launch the background task and keep a reference until it completes.
    task = asyncio.create_task(
        run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path)
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Report the stored state of a background task, or 404 for an unknown id."""
    entry = task_progress.get(task_id)
    if entry:
        return response.json({"code": 200, "msg": "Task status", "data": entry})
    return response.json({"code": 404, "msg": "Task not found"}, status=404)
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Background job behind /yolo/process_images; records state in ``task_progress[task_id]``.

    Runs ``map_process_images_with_progress`` in a worker thread so the event
    loop stays free. State written here: "running" at 10%, then whatever stage
    names the worker reports through ``progress_callback``, then either
    "completed" at 100% with ``result`` set to the return value, or "failed" at
    100% with ``result`` set to the error message.
    """
    try:
        task_progress[task_id]["status"] = "running"
        task_progress[task_id]["progress"] = 10  # initial progress

        def progress_callback(stage, percent):
            # Reports the download / inference / upload stages. It is handed to
            # the worker, so presumably it is called from the worker thread.
            task_progress[task_id]["status"] = stage
            task_progress[task_id]["progress"] = percent

        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback,
        )
        task_progress[task_id]["status"] = "completed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = result
    except Exception as e:
        # A background task has no HTTP response to carry the traceback, so log
        # it here; clients only see the message through /yolo/task_status.
        logger.exception("Image processing task %s failed", task_id)
        task_progress[task_id]["status"] = "failed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = str(e)
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection on a list of images and return the processing result.

    Input JSON keys (as read below):
        image_list (required): images to process, passed through unchanged.
        yolo_model (optional, default "best.pt"): model weights name.
        class (optional): passed through as ``class_filter``.
        minio (required): MinIO connection info, passed as ``minio_info``.

    Returns the detection function's result as JSON, HTTP 400 for a missing
    body/field, or HTTP 500 with an error message on unexpected failure.
    """
    try:
        # Import under a distinct local name: the module-level name
        # ``process_images`` is rebound in this module by the
        # /yolo/process_images route handler, so a bare reference would resolve
        # to that async handler rather than to the ai_image function.
        from ai_image import process_images as run_yolo_detection

        detect_data = request.json
        if not isinstance(detect_data, dict):
            return json_response({"status": "error", "message": "Request body must be a JSON object"}, status=400)
        # Parse the required fields
        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)
        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)
        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
        # Per-request scratch folders.
        # NOTE(review): they are not removed here; confirm ai_image.process_images cleans them up.
        input_folder = f"./temp_input_{uuid.uuid4()}"
        output_folder = f"./temp_output_{uuid.uuid4()}"
        # Detection is blocking work; run it off the event loop.
        result = await asyncio.to_thread(
            run_yolo_detection,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info,
        )
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
# YOLO automatic training
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """
    Train a model automatically by handing the request body to ``auto_train``.

    Input JSON:
    {
        "db_host": str,
        "db_database": str,
        "db_user": str,
        "db_password": str,
        "db_port": int,
        "model_id": int,
        "img_path": str,
        "label_path": str,
        "new_path": str,
        "split_list": List[float],
        "class_names": Optional[List[str]],
        "project_name": str
    }
    Output JSON (whatever auto_train returns), e.g.:
    {
        "status": "success",
        "message": "Train finished",
        "project_name": project_name,
        "label_count": label_count,
        "base_metrics": base_metrics,
        "final_metrics": final_metrics
    }
    An empty body yields HTTP 400; any exception yields HTTP 500.
    """
    try:
        payload = request.json  # a property in Sanic, not a method
        if not payload:
            return json_response({"status": "error", "message": "data is required"}, status=400)
        # auto_train is blocking (its documented output includes final metrics),
        # so run it in a worker thread and send back its result.
        return json_response(await asyncio.to_thread(auto_train, payload))
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response(
            {"status": "error", "message": f"Internal server error: {str(e)}"},
            status=500,
        )
# Training progress query endpoint
@yolo_tile_blueprint.get("/progress/<project_name>")
async def yolo_train_progress(request, project_name):
    '''
    Query the training progress of a project.

    Usage:
        Latest run:        GET /yolo/progress/my_project
        A specific run:    GET /yolo/progress/my_project?run_time=20250902_1012
    Output JSON (whatever query_progress returns), e.g.:
    {
        "status": "ok",
        "run_time": "20250902_1012",
        "progress": {
            "epoch": 12,
            "precision": 0.72,
            "recall": 0.64,
            "mAP50": 0.68,
            "mAP50-95": 0.42
        }
    }
    Unexpected failures return HTTP 500 with {"status": "error", "message": ...},
    matching the other YOLO endpoints.
    '''
    run_time = request.args.get("run_time")  # optional query parameter
    try:
        # query_progress may touch the filesystem; keep it off the event loop.
        result = await asyncio.to_thread(query_progress, project_name, run_time)
    except Exception as e:
        logger.error(f"Error occurred while querying training progress: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
    return json_response(result)
if __name__ == '__main__':
    # Sanic's debug mode exposes exception details in error responses and is
    # meant for development, so it is opt-in (AI_TOTTLE_DEBUG=1) rather than
    # always on for a server listening on all interfaces.
    debug_mode = os.environ.get("AI_TOTTLE_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(
        host=os.environ.get("AI_TOTTLE_HOST", "0.0.0.0"),
        port=int(os.environ.get("AI_TOTTLE_PORT", "12366")),
        debug=debug_mode,
        # Keep a single worker: task_progress is an in-process dict, so with
        # several workers a status poll could reach a worker that never saw the task.
        workers=1,
    )