# File-viewer header captured with the paste: 380 lines, 13 KiB, Python.
import logging,os,uuid,asyncio,torch
|
|
# sanic imports
|
|
from sanic import Sanic, json, Blueprint,response
|
|
from sanic.exceptions import Unauthorized
|
|
from sanic.response import json as json_response
|
|
from sanic_cors import CORS
|
|
# ourself imports
|
|
from ai_image import process_images
|
|
from map_find import map_process_images
|
|
from yolo_train import train_main
|
|
from yolo_photo import map_process_images_with_progress
|
|
from pydantic import BaseModel, ValidationError
|
|
from typing import List, Dict
|
|
import threading
|
|
import torch
|
|
import uuid
|
|
from queue import Queue
|
|
|
|
|
|
# Configure root logging once for the whole service, then grab this module's logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
###################################################################################验证中间件和管理件##############################################################################################
|
|
|
|
async def token_and_resource_check(request):
|
|
# --- Token 验证 ---
|
|
token = request.headers.get('X-API-Token')
|
|
expected_token = request.app.config.get("VALID_TOKEN")
|
|
if not token or token != expected_token:
|
|
logger.warning(f"Unauthorized request with token: {token}")
|
|
raise Unauthorized("Invalid token")
|
|
|
|
# --- GPU 使用率检查 ---
|
|
try:
|
|
if torch.cuda.is_available():
|
|
num_gpus = torch.cuda.device_count()
|
|
max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9) # 默认90%
|
|
|
|
for i in range(num_gpus):
|
|
used = torch.cuda.memory_reserved(i)
|
|
total = torch.cuda.max_memory_reserved(i)
|
|
ratio = used / total if total else 0
|
|
|
|
logger.info(f"GPU {i} Usage: {ratio:.2%}")
|
|
|
|
if ratio > max_usage_ratio:
|
|
logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
|
|
return json_response({
|
|
"status": "error",
|
|
"message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
|
|
}, status=503)
|
|
except Exception as e:
|
|
logger.error(f"GPU check failed: {e}")
|
|
|
|
return None # 允许请求继续
|
|
|
|
################################################################# set up app and blueprints ########################################################################################################
|
|
# create app and cors
|
|
app = Sanic("ai_Service_v2")
|
|
CORS(app) # 允许跨域请求
|
|
task_progress = {}
|
|
|
|
@app.middleware("request")
async def global_middleware(request):
    """Apply the token / GPU gate before every request.

    When token_and_resource_check produces a response, that response is
    returned from the request middleware, which ends the request early.
    Otherwise nothing is returned and the route handler runs.
    """
    early_response = await token_and_resource_check(request)
    return early_response if early_response else None
|
|
|
|
# API token and maximum allowed GPU memory usage (fraction of device memory).
# Values already in app.config take precedence. Sanic loads SANIC_-prefixed
# environment variables (e.g. SANIC_VALID_TOKEN, SANIC_MAX_GPU_USAGE) into the
# config when the app is created, and the literals below are only fallbacks.
# Previously these literals unconditionally overwrote any such values.
# TODO: move the token out of source control entirely.
app.config.update({
    "VALID_TOKEN": app.config.get("VALID_TOKEN", "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a"),
    "MAX_GPU_USAGE": app.config.get("MAX_GPU_USAGE", 0.9),
})
|
|
|
|
############################################################ Map segmentation API ############################################################
# Blueprint for the map endpoints, mounted under /map/.
map_tile_blueprint = Blueprint(name='map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)
|
|
|
|
# Semantic segmentation of UAV imagery
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """Run semantic segmentation on a batch of images via map_process_images.

    Endpoint: POST /map/uav

    Request JSON::

        {
            "urls": [
                "http://example.com/img1.jpg",
                "http://example.com/img2.jpg"
            ],
            "yaml_name": "config",
            "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
            "bucket_directory": "2025/seg",
            "model_path": "deeplabv3plus_best.pth"
        }

    ``model_path`` is resolved relative to ``map/checkpoints`` and must stay
    inside that directory.

    Response JSON: whatever ``map_process_images`` returns. Example from the
    original documentation::

        {
            "code": 200,
            "msg": "success",
            "data": [
                "http://minio.example.com/uav-results/2025/seg/result1.png",
                "http://minio.example.com/uav-results/2025/seg/result2.png"
            ]
        }

    Validation failures return ``{"code": 400, "msg": ...}`` and unexpected
    errors return ``{"code": 500, "msg": ...}``. Both are sent with HTTP
    status 200; only the body carries the code.
    """
    try:
        body = request.json
        if not isinstance(body, dict):
            return json({"code": 400, "msg": "Request body must be a JSON object"})

        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        model_name = body.get("model_path")

        # Parameter validation
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory]):
            return json({"code": 400, "msg": "Missing required parameters"})
        # Previously a missing model_path made os.path.join raise a TypeError,
        # which surfaced as a 500 "Server error".
        if not model_name or not isinstance(model_name, str):
            return json({"code": 400, "msg": "Missing or invalid 'model_path'"})

        checkpoint_dir = os.path.abspath(os.path.join("map", "checkpoints"))
        model_path = os.path.join("map", "checkpoints", model_name)
        # model_path comes from the client. os.path.join discards the prefix
        # for absolute names, and '..' segments can climb out, so require the
        # resolved path to stay inside the checkpoint directory.
        if os.path.commonpath([checkpoint_dir, os.path.abspath(model_path)]) != checkpoint_dir:
            return json({"code": 400, "msg": "Invalid 'model_path'"})

        # map_process_images is a blocking call (the original invoked it
        # without await). Run it in a worker thread so the event loop keeps
        # serving other requests, as the /yolo handlers already do.
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json(result)

    except Exception as e:
        logger.error(f"Error occurred in /map/uav: {e}", exc_info=True)
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
|
|
|
|
############################################################ YOLO API ############################################################
# Blueprint for the YOLO detection / training endpoints, mounted under /yolo/.
# NOTE(review): routes are attached to this blueprint after it is registered
# with the app; this relies on Sanic applying routes added after registration.
# Confirm for the Sanic version in use.
yolo_tile_blueprint = Blueprint(name='yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)
|
|
|
|
# Asynchronous, URL-based image-processing API.
# Job progress is kept in memory in ``task_progress``.

# Strong references to in-flight background jobs. The event loop keeps only
# weak references to tasks, so a task nobody references can be
# garbage-collected before it finishes.
_background_tasks = set()


@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
    """Start a background image-processing job and return its task id.

    NOTE: this handler's name rebinds the module-level ``process_images`` that
    is imported from ``ai_image`` at the top of the module.

    Request JSON::

        {
            "urls": [
                "http://example.com/image1.jpg",
                "http://example.com/image2.jpg",
                "http://example.com/image3.jpg"
            ],
            "yaml_name": "your_minio_config",
            "bucket_name": "my-bucket",
            "bucket_directory": "2025/uav-results",
            "uav_model_path": "deeplabv3plus_best.pth"
        }

    ``model_path`` is accepted as an alias for ``uav_model_path``.

    Returns ``{"code": 200, "msg": "Task started", "task_id": ...}``. Progress
    is recorded in ``task_progress[task_id]``. Missing or invalid parameters
    give HTTP 400.
    """
    data = request.json
    if not isinstance(data, dict):
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)

    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    # The original documentation named this field "model_path" while the code
    # read "uav_model_path"; accept either so both kinds of client work.
    uav_model_path = data.get("uav_model_path") or data.get("model_path")

    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    if not isinstance(urls, list):
        return response.json({"code": 400, "msg": "Invalid 'urls'"}, status=400)

    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}

    # Start the background job and hold a reference until it completes.
    job = asyncio.create_task(
        run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path)
    )
    _background_tasks.add(job)
    job.add_done_callback(_background_tasks.discard)

    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
|
|
|
|
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Report the progress record of a /yolo/process_images background job.

    Responds with HTTP 404 for an unknown id. Otherwise the stored record is
    returned under ``data``.
    """
    record = task_progress.get(task_id)
    if record:
        return response.json({"code": 200, "msg": "Task status", "data": record})
    return response.json({"code": 404, "msg": "Task not found"}, status=404)
|
|
|
|
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Background job: run map_process_images_with_progress and record progress.

    The blocking worker function runs in a thread via ``asyncio.to_thread``.
    ``task_progress[task_id]`` must already exist and is updated as follows:

    * at start: ``status`` = "running", ``progress`` = 10;
    * on each ``progress_callback(stage, percent)``: ``status`` = stage,
      ``progress`` = percent;
    * on success: ``status`` = "completed", ``progress`` = 100,
      ``result`` = the function's return value;
    * on error: ``status`` = "failed", ``progress`` = 100,
      ``result`` = the error message.
    """
    try:
        task_progress[task_id]["status"] = "running"
        task_progress[task_id]["progress"] = 10  # initial progress

        # Passed to the worker function, which is expected to report each
        # stage (download, inference, upload) together with a percentage.
        def progress_callback(stage, percent):
            task_progress[task_id]["status"] = stage
            task_progress[task_id]["progress"] = percent

        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback,
        )

        task_progress[task_id]["status"] = "completed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = result

    except Exception as e:
        # Nothing awaits this task, so log the traceback here. Otherwise the
        # failure would only be visible as a bare message in the status record.
        logger.exception(f"Image-processing task {task_id} failed")
        task_progress[task_id]["status"] = "failed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = str(e)
|
|
|
|
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection on a list of images and return the result.

    Request JSON keys:
        image_list (required): images to process.
        yolo_model: model weights name, default ``"best.pt"``.
        class: optional class filter, passed through as ``class_filter``.
        minio (required): MinIO connection info, passed through as ``minio_info``.

    Returns the JSON-encoded result of ``ai_image.process_images``. Missing
    inputs give HTTP 400; unexpected errors give HTTP 500.
    """
    try:
        # The module-level name ``process_images`` is rebound by the async
        # ``/yolo/process_images`` route handler defined in this module. Calling
        # it here invoked that handler with unexpected keyword arguments, a
        # TypeError that made every request fail with 500. Import the ai_image
        # worker under an unambiguous alias instead.
        from ai_image import process_images as ai_process_images

        detect_data = request.json
        if not isinstance(detect_data, dict):
            # Missing / non-object body: fall through to the "required" checks.
            detect_data = {}

        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)

        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)

        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)

        # Per-request scratch folders; a shared id makes the input/output pair
        # easy to correlate.
        # NOTE(review): nothing here deletes these folders. Confirm that
        # ai_image.process_images cleans them up.
        run_id = uuid.uuid4()
        input_folder = f"./temp_input_{run_id}"
        output_folder = f"./temp_output_{run_id}"

        # Run the blocking detection in a worker thread to keep the event loop free.
        result = await asyncio.to_thread(
            ai_process_images,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info,
        )

        return json_response(result)

    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
|
|
|
|
# ------------------------------------------------ YOLO training API ------------------------------------------------
# Shared in-memory state for the training scheduler (routes use yolo_tile_blueprint).

# One concurrent training job per visible CUDA device, or a single job without CUDA.
MAX_CONCURRENT_JOBS = 1 if not torch.cuda.is_available() else torch.cuda.device_count()

lock = threading.Lock()            # guards active_jobs and scheduling decisions
task_queue = Queue()               # ids of jobs waiting for a free slot
active_jobs: List[str] = []        # ids of jobs currently occupying a slot
tasks: Dict[str, Dict] = dict()    # task id -> {"status", "params", optional "error"}
|
|
|
|
|
|
# ------------------ Request model ------------------
class TrainRequest(BaseModel):
    """Validated body of POST /yolo/train.

    Every field is required (none has a default). ``run_training`` forwards
    each field unchanged, as a keyword argument of the same name, to
    ``yolo_train.train_main``. The meaning of each field is defined there.
    """

    # --- dataset selection (presumably a DB lookup; confirm in train_main) ---
    config_name: str
    table_name: str
    column_name: str
    search_condition: str

    # --- filesystem locations ---
    aim_path: str
    image_dir: str
    label_dir: str
    output_path: str
    pt_path: str  # presumably the starting weights file — confirm in train_main

    # --- training hyper-parameters ---
    imgsz: int
    epochs: int
    device: List[int]  # presumably device indices to train on — confirm in train_main
    hsv_v: float
    cos_lr: bool
    batch: int

    project_dir: str
    class_names: List[str]
|
|
|
|
|
|
# ------------------ Core execution ------------------


def _start_training_thread_locked(task_id: str, params: TrainRequest) -> None:
    """Claim a concurrency slot for ``task_id`` and launch its training thread.

    The caller must hold ``lock``. The slot is recorded in ``active_jobs``
    *before* the thread starts, so the admission checks in
    ``schedule_next_job`` / ``submit_train_job`` see it immediately.
    Previously the thread registered itself later, so ``schedule_next_job``'s
    loop (which holds ``lock`` and keeps the new threads from registering)
    could start every queued job at once and exceed ``MAX_CONCURRENT_JOBS``.
    """
    active_jobs.append(task_id)
    threading.Thread(target=run_training, args=(task_id, params), daemon=True).start()


def run_training(task_id: str, params: TrainRequest):
    """Run one training job (blocking; meant for a worker thread).

    Sets ``tasks[task_id]["status"]`` to "running", then to "finished", or to
    "failed" with ``tasks[task_id]["error"]``. Afterwards it frees the job's
    slot and starts queued jobs that now fit.
    """
    try:
        with lock:
            # Usually already claimed by _start_training_thread_locked;
            # register here too in case run_training is invoked directly.
            if task_id not in active_jobs:
                active_jobs.append(task_id)
            tasks[task_id]["status"] = "running"

        train_main(
            config_name=params.config_name,
            table_name=params.table_name,
            column_name=params.column_name,
            search_condition=params.search_condition,
            aim_path=params.aim_path,
            image_dir=params.image_dir,
            label_dir=params.label_dir,
            output_path=params.output_path,
            pt_path=params.pt_path,
            imgsz=params.imgsz,
            epochs=params.epochs,
            device=params.device,
            hsv_v=params.hsv_v,
            cos_lr=params.cos_lr,
            batch=params.batch,
            project_dir=params.project_dir,
            class_names=params.class_names
        )

        tasks[task_id]["status"] = "finished"
    except Exception as e:
        logger.exception(f"Training task {task_id} failed")
        tasks[task_id]["status"] = "failed"
        tasks[task_id]["error"] = str(e)
    finally:
        with lock:
            if task_id in active_jobs:
                active_jobs.remove(task_id)
        # Must run after releasing the lock: schedule_next_job acquires it
        # itself and threading.Lock is not re-entrant. The old code called it
        # while still holding the lock. That deadlocked this thread and then
        # blocked every later /yolo/train request (and the event loop) on `lock`.
        schedule_next_job()


# ------------------ Scheduler ------------------
def schedule_next_job():
    """Start queued training jobs while free concurrency slots remain."""
    with lock:
        while len(active_jobs) < MAX_CONCURRENT_JOBS and not task_queue.empty():
            next_id = task_queue.get()
            _start_training_thread_locked(next_id, tasks[next_id]["params"])


# ------------------ Endpoints ------------------
@yolo_tile_blueprint.post("/train")
async def submit_train_job(request):
    """Validate a training request, then start it immediately or queue it.

    Returns ``{"success": True, "task_id": ..., "message": ...}``. On a
    validation failure it returns ``{"success": False, "error": [...]}``.
    Both are sent with HTTP status 200.
    """
    data = request.json
    try:
        # A missing or non-object body becomes a ValidationError listing the
        # required fields, instead of an unhandled TypeError from ``**None``.
        params = TrainRequest(**(data if isinstance(data, dict) else {}))
    except ValidationError as e:
        return json({"success": False, "error": e.errors()})

    task_id = str(uuid.uuid4())
    tasks[task_id] = {"status": "queued", "params": params}

    with lock:
        if len(active_jobs) < MAX_CONCURRENT_JOBS:
            _start_training_thread_locked(task_id, params)
        else:
            task_queue.put(task_id)
            tasks[task_id]["status"] = "waiting"

    return json({"success": True, "task_id": task_id, "message": "任务已提交"})
|
|
|
|
|
|
@yolo_tile_blueprint.get("/train_status/<task_id>")
async def train_task_status(request, task_id: str):
    """Report the status of a training job submitted via POST /yolo/train.

    Served at ``/yolo/train_status/<task_id>``. It was previously declared on
    ``/yolo/task_status/<task_id>`` with the handler name ``task_status``.
    Both the path and the name were already used by the image-processing
    status endpoint, a conflicting duplicate GET registration in Sanic.

    Returns ``{"success": False, "message": ...}`` for an unknown id.
    Otherwise returns ``{"success": True, "status": ..., "error": ...}``, where
    ``error`` is None unless the job failed.
    """
    if task_id not in tasks:
        return json({"success": False, "message": "任务ID不存在"})

    task_info = tasks[task_id]
    return json({
        "success": True,
        "status": task_info["status"],
        "error": task_info.get("error", None)
    })
|
|
|
|
|
|
@yolo_tile_blueprint.get("/tasks")
async def all_tasks(request):
    """List every known training task id together with its current status."""
    summary = {}
    for task_key, task_entry in tasks.items():
        summary[task_key] = {"status": task_entry["status"]}
    return json(summary)
|
|
|
|
|
|
@yolo_tile_blueprint.get("/system_status")
async def system_status(request):
    """Summarise GPU availability and the training scheduler's occupancy."""
    payload = {
        "gpu_available": torch.cuda.is_available(),
        "max_concurrent": MAX_CONCURRENT_JOBS,
    }
    payload["running_jobs"] = len(active_jobs)
    payload["waiting_jobs"] = task_queue.qsize()
    payload["active_task_ids"] = active_jobs
    return json(payload)
|
|
if __name__ == "__main__":
    # Single worker: task_progress, tasks and the training queue are
    # module-level, per-process state and would not be shared across workers.
    # NOTE(review): debug=True is convenient in development; consider turning
    # it off for production deployments.
    app.run(
        host="0.0.0.0",
        port=12366,
        debug=True,
        workers=1,
    )
|
|
|