2025-09-02 11:50:56 +08:00
import logging , os , uuid , asyncio , torch
# sanic imports
2025-07-10 09:41:26 +08:00
from sanic import Sanic , json , Blueprint , response
2025-07-24 14:03:12 +08:00
from sanic . exceptions import Unauthorized
2025-07-10 09:41:26 +08:00
from sanic . response import json as json_response
from sanic_cors import CORS
2025-09-02 11:50:56 +08:00
# local project imports
from ai_image import process_images
2025-07-10 09:41:26 +08:00
from map_find import map_process_images
2025-11-11 09:43:25 +08:00
from yolo_train import train_main
2025-09-02 11:50:56 +08:00
from yolo_photo import map_process_images_with_progress
2025-11-11 09:43:25 +08:00
from pydantic import BaseModel , ValidationError
from typing import List , Dict
import threading
import torch
import uuid
from queue import Queue
2025-09-02 11:50:56 +08:00
# set up logging
2025-07-10 09:41:26 +08:00
# Root logging configuration shared by the whole service.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)  # module logger
###################################################################################验证中间件和管理件##############################################################################################
async def token_and_resource_check(request):
    """Authenticate a request and refuse it while a GPU's memory is too full.

    Reads ``VALID_TOKEN`` and ``MAX_GPU_USAGE`` (default 0.9) from
    ``request.app.config``.

    Returns:
        ``None`` to let the request continue, or a 503 JSON response when
        some GPU's reserved-memory ratio exceeds ``MAX_GPU_USAGE``.

    Raises:
        Unauthorized: the ``X-API-Token`` header is missing or does not match.
    """
    import hmac  # constant-time comparison for the shared secret

    # --- Token check ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    # compare_digest does not stop at the first differing byte, so response
    # timing does not reveal how much of a guessed token was correct. Compare
    # bytes: compare_digest rejects non-ASCII str arguments.
    if not token or not expected_token or not hmac.compare_digest(
        token.encode("utf-8"), str(expected_token).encode("utf-8")
    ):
        # Do not log the submitted value: it may be a real (mistyped) credential.
        logger.warning("Unauthorized request: %s X-API-Token", "invalid" if token else "missing")
        raise Unauthorized("Invalid token")

    # --- GPU memory check ---
    try:
        if torch.cuda.is_available():
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default 90%
            for i in range(torch.cuda.device_count()):
                # Memory held by this process's caching allocator on device i.
                used = torch.cuda.memory_reserved(i)
                # The denominator must be the device capacity. max_memory_reserved(i),
                # used previously, is this process's *peak* reservation, so the
                # ratio stayed at or near 100% once anything had been allocated.
                total = torch.cuda.get_device_properties(i).total_memory
                ratio = used / total if total else 0
                logger.info(f"GPU {i} Usage: {ratio:.2%}")
                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        # Best effort: a failing probe must not block the service.
        logger.error(f"GPU check failed: {e}")
    return None  # let the request continue
2025-09-02 11:50:56 +08:00
################################################################# set up app and blueprints ########################################################################################################
# create app and cors
2025-07-10 09:41:26 +08:00
app = Sanic("ai_Service_v2")
CORS(app)  # allow cross-origin requests
# In-memory progress of /yolo/process_images jobs:
# task_id -> {"status": ..., "progress": ..., "result": ...}
task_progress = {}


@app.middleware("request")
async def global_middleware(request):
    """Run the token and GPU checks before every request.

    Returning a response from a request middleware short-circuits the route
    handler; returning nothing lets the request proceed.
    """
    result = await token_and_resource_check(request)
    if result is not None:
        return result


# API token and maximum GPU memory ratio. The token can be supplied through
# the VALID_TOKEN environment variable so the secret need not live in source
# control; the built-in value is kept as the fallback for compatibility.
app.config.update({
    "VALID_TOKEN": os.environ.get("VALID_TOKEN", "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a"),
    "MAX_GPU_USAGE": 0.9
})
######################################################################地图切割相关的API########################################################################################################
#创建地图的蓝图
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)


# Semantic recognition of UAV imagery
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """Run semantic recognition on UAV images and upload the results.

    Endpoint: POST /map/uav

    Request JSON:
    {
        "urls": ["http://example.com/img1.jpg", "http://example.com/img2.jpg"],
        "yaml_name": "config",
        "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
        "bucket_directory": "2025/seg",
        "model_path": "deeplabv3plus_best.pth"
    }
    ``model_path`` is a file name resolved inside ``map/checkpoints``; values
    that would escape that directory are rejected.

    Response JSON is whatever ``map_process_images`` returns, e.g.
    {
        "code": 200,
        "msg": "success",
        "data": ["http://minio.example.com/uav-results/2025/seg/result1.png", ...]
    }
    Validation errors return {"code": 400, "msg": ...} and unexpected errors
    {"code": 500, "msg": ...}; both are sent with HTTP status 200.
    """
    try:
        body = request.json
        if not isinstance(body, dict):
            return json({"code": 400, "msg": "Request body must be a JSON object"})
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        model_name = body.get("model_path")

        # Validate parameters
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory]):
            return json({"code": 400, "msg": "Missing required parameters"})
        if not model_name or not isinstance(model_name, str):
            return json({"code": 400, "msg": "Missing or invalid 'model_path'"})

        checkpoint_dir = os.path.abspath(os.path.join("map", "checkpoints"))
        model_path = os.path.join("map", "checkpoints", model_name)
        # An absolute path or ".." segments would make os.path.join point
        # outside the checkpoint directory; refuse those.
        if os.path.commonpath([checkpoint_dir, os.path.abspath(model_path)]) != checkpoint_dir:
            return json({"code": 400, "msg": "Invalid 'model_path'"})

        # The processing call is synchronous; run it in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        result = await asyncio.to_thread(
            map_process_images, urls, yaml_name, bucket_name, bucket_directory, model_path
        )
        return json(result)
    except Exception as e:
        logger.exception("Error while processing /map/uav request")
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
######################################################################yolo相关的API########################################################################################################
#创建yolo的蓝图
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)

# YOLO URL API
# Progress of /process_images jobs is kept in memory in ``task_progress``.

# The event loop keeps only weak references to tasks created with
# asyncio.create_task(); hold strong references here until each job finishes
# so a running job cannot be garbage-collected.
_background_tasks = set()


# The handler must NOT be called ``process_images``: a module-level def with
# that name rebinds the ``process_images`` imported from ai_image, which the
# /yolo/picture endpoint relies on. ``name=`` keeps the original route name.
@yolo_tile_blueprint.post("/process_images", name="process_images")
async def submit_image_processing(request):
    """Start a background UAV image-processing job and return its task id.

    Request JSON:
    {
        "urls": ["http://example.com/image1.jpg", "http://example.com/image2.jpg"],
        "yaml_name": "your_minio_config",
        "bucket_name": "my-bucket",
        "bucket_directory": "2025/uav-results",
        "uav_model_path": "<model file>"
    }
    ``urls``, ``yaml_name``, ``bucket_name`` and ``uav_model_path`` are required;
    ``bucket_directory`` is passed through as given (may be absent).

    Response: {"code": 200, "msg": "Task started", "task_id": "<uuid>"}; poll
    GET /yolo/task_status/<task_id> for progress. Missing parameters -> HTTP 400.
    """
    data = request.json
    if not isinstance(data, dict):
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path")
    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)

    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}

    # Start the background job and keep it referenced until it completes.
    job = asyncio.create_task(
        run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path)
    )
    _background_tasks.add(job)
    job.add_done_callback(_background_tasks.discard)
    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Return the progress record of a /process_images background job.

    Responds 404 when the id is unknown (or its record is empty), otherwise
    200 with the record under ``data``.
    """
    entry = task_progress.get(task_id)
    if entry:
        return response.json({"code": 200, "msg": "Task status", "data": entry})
    return response.json({"code": 404, "msg": "Task not found"}, status=404)
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Drive one /yolo/process_images job and record its lifecycle.

    Runs ``map_process_images_with_progress`` on a worker thread and writes
    into ``task_progress[task_id]``: "running"/10 at start, then whatever the
    progress callback reports, and finally "completed"/100 with the result or
    "failed"/100 with the error text.
    """
    record = task_progress[task_id]

    def on_progress(stage, percent):
        # Handed to the worker function; each call overwrites status/progress.
        record["status"] = stage
        record["progress"] = percent

    try:
        record["status"], record["progress"] = "running", 10
        outcome = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, on_progress,
        )
    except Exception as exc:
        record.update(status="failed", progress=100, result=str(exc))
    else:
        record.update(status="completed", progress=100, result=outcome)
2025-09-02 11:50:56 +08:00
# YOLO detect API
2025-07-10 09:41:26 +08:00
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection on a list of images and return the result as JSON.

    Request JSON keys:
        image_list (required): images to process.
        yolo_model (optional, default "best.pt"): model weights name.
        class (optional): forwarded as ``class_filter``.
        minio (required): forwarded as ``minio_info``.

    Returns whatever ``ai_image.process_images`` returns; HTTP 400 for a
    missing/invalid body or fields, HTTP 500 for any other error.
    """
    try:
        # Bind the detection routine from its module explicitly so that a
        # module-level definition sharing the name ``process_images`` (such as
        # a route handler) cannot shadow the top-level import and be called here.
        from ai_image import process_images as run_detection

        detect_data = request.json
        if not isinstance(detect_data, dict):
            return json_response({"status": "error", "message": "Request body must be a JSON object"}, status=400)

        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)
        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)
        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)

        # Unique per-request working directories for input and output images.
        # NOTE(review): nothing here removes them; confirm process_images cleans up.
        input_folder = f"./temp_input_{str(uuid.uuid4())}"
        output_folder = f"./temp_output_{str(uuid.uuid4())}"

        # Detection is blocking; run it in a thread to keep the event loop free.
        result = await asyncio.to_thread(
            run_detection,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info
        )
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
2025-11-11 09:43:25 +08:00
#--------------------------------------------------------------------------yolo训练相关的API----------------------------------------------------------------########################################
# YOLO training API (routes are added to the existing yolo blueprint)
MAX_CONCURRENT_JOBS = torch.cuda.device_count() if torch.cuda.is_available() else 1
tasks: Dict[str, Dict] = {}   # task_id -> {"status", "params", optional "error"}
task_queue = Queue()          # task_ids waiting for a free slot, FIFO
active_jobs: List[str] = []   # task_ids that currently hold a training slot
lock = threading.Lock()       # guards active_jobs and queue draining (not re-entrant)


# ------------------ request model ------------------
class TrainRequest(BaseModel):
    """Validated body of POST /yolo/train; every field is passed to ``train_main``."""
    config_name: str
    table_name: str
    column_name: str
    search_condition: str
    aim_path: str
    image_dir: str
    label_dir: str
    output_path: str
    pt_path: str
    imgsz: int
    epochs: int
    device: List[int]
    hsv_v: float
    cos_lr: bool
    batch: int
    project_dir: str
    class_names: List[str]


# ------------------ worker ------------------
def run_training(task_id: str, params: TrainRequest):
    """Run one training job on the calling (worker) thread.

    The slot in ``active_jobs`` is claimed by ``_start_job_locked`` before this
    thread starts; this function releases it when training ends and then
    starts the next queued job, if any.
    """
    try:
        train_main(
            config_name=params.config_name,
            table_name=params.table_name,
            column_name=params.column_name,
            search_condition=params.search_condition,
            aim_path=params.aim_path,
            image_dir=params.image_dir,
            label_dir=params.label_dir,
            output_path=params.output_path,
            pt_path=params.pt_path,
            imgsz=params.imgsz,
            epochs=params.epochs,
            device=params.device,
            hsv_v=params.hsv_v,
            cos_lr=params.cos_lr,
            batch=params.batch,
            project_dir=params.project_dir,
            class_names=params.class_names
        )
        tasks[task_id]["status"] = "finished"
    except Exception as e:
        logger.exception("Training task %s failed", task_id)
        tasks[task_id]["status"] = "failed"
        tasks[task_id]["error"] = str(e)
    finally:
        with lock:
            if task_id in active_jobs:
                active_jobs.remove(task_id)
        # Called after leaving the ``with lock`` block: schedule_next_job()
        # acquires ``lock`` itself and threading.Lock is not re-entrant.
        schedule_next_job()


def _start_job_locked(task_id: str, params: TrainRequest) -> None:
    """Claim a slot for ``task_id`` and start its worker thread; caller holds ``lock``."""
    # Reserve the slot *before* starting the thread. Previously the thread
    # appended itself only once it could take ``lock``, so a caller still
    # holding the lock saw stale capacity and over-scheduled jobs.
    active_jobs.append(task_id)
    tasks[task_id]["status"] = "running"
    worker = threading.Thread(target=run_training, args=(task_id, params), daemon=True)
    try:
        worker.start()
    except RuntimeError as e:
        # The thread could not be created: give the slot back.
        active_jobs.remove(task_id)
        tasks[task_id]["status"] = "failed"
        tasks[task_id]["error"] = str(e)


# ------------------ scheduler ------------------
def schedule_next_job():
    """Start queued jobs while free slots remain."""
    with lock:
        while len(active_jobs) < MAX_CONCURRENT_JOBS and not task_queue.empty():
            next_id = task_queue.get()
            _start_job_locked(next_id, tasks[next_id]["params"])


# ------------------ endpoints ------------------
@yolo_tile_blueprint.post("/train")
async def submit_train_job(request):
    """Validate a training request and start it now or queue it.

    Returns {"success": True, "task_id": ..., "message": ...} on submission or
    {"success": False, "error": ...} when the body is invalid.
    """
    try:
        data = request.json
        if not isinstance(data, dict):
            return json({"success": False, "error": "Request body must be a JSON object"})
        params = TrainRequest(**data)
    except ValidationError as e:
        return json({"success": False, "error": e.errors()})

    task_id = str(uuid.uuid4())
    tasks[task_id] = {"status": "queued", "params": params}

    with lock:
        if len(active_jobs) < MAX_CONCURRENT_JOBS:
            _start_job_locked(task_id, params)
        else:
            task_queue.put(task_id)
            tasks[task_id]["status"] = "waiting"
    return json({"success": True, "task_id": task_id, "message": "任务已提交"})
@yolo_tile_blueprint.get("/train_status/<task_id>")
async def train_task_status(request, task_id: str):
    """Report the state of a training job submitted to POST /yolo/train.

    Served at /yolo/train_status/<task_id>. It was previously declared as GET
    /yolo/task_status/<task_id> with handler name ``task_status``. That is the
    same path and name as the /process_images status endpoint, a duplicate
    route that Sanic rejects at registration.

    Returns {"success": True, "status": ..., "error": <str or None>}, or
    {"success": False, "message": ...} when the id is unknown.
    """
    task_info = tasks.get(task_id)
    if task_info is None:
        return json({"success": False, "message": "任务ID不存在"})
    return json({
        "success": True,
        "status": task_info["status"],
        "error": task_info.get("error", None)
    })
@yolo_tile_blueprint.get("/tasks")
async def all_tasks(request):
    """List every known training task id with its current status."""
    summary = {}
    for task_id, info in tasks.items():
        summary[task_id] = {"status": info["status"]}
    return json(summary)
@yolo_tile_blueprint.get("/system_status")
async def system_status(request):
    """Report GPU availability and training-slot occupancy.

    Worker threads remove ids from ``active_jobs`` when they finish. Copying
    the list once keeps ``running_jobs`` and ``active_task_ids`` consistent
    with each other and avoids handing the live list to the serializer.
    """
    running = list(active_jobs)
    return json({
        "gpu_available": torch.cuda.is_available(),
        "max_concurrent": MAX_CONCURRENT_JOBS,
        "running_jobs": len(running),
        "waiting_jobs": task_queue.qsize(),
        "active_task_ids": running
    })
2025-07-10 09:41:26 +08:00
if __name__ == '__main__':
    # Defaults equal the previous hard-coded values; override per deployment.
    # NOTE(review): Sanic debug mode includes exception details in error
    # responses; consider AI_SERVICE_DEBUG=0 outside development.
    # workers=1 is required: task state (task_progress, tasks, queue) lives in
    # this process's memory and would not be shared between workers.
    app.run(
        host=os.environ.get("AI_SERVICE_HOST", "0.0.0.0"),
        port=int(os.environ.get("AI_SERVICE_PORT", "12366")),
        debug=os.environ.get("AI_SERVICE_DEBUG", "1").strip().lower() not in ("0", "false", "no"),
        workers=1,
    )
2025-10-09 09:29:18 +08:00