diff --git a/Ai_tottle/ai_tottle_api.py b/Ai_tottle/ai_tottle_api.py index 2cb30b1..315055f 100644 --- a/Ai_tottle/ai_tottle_api.py +++ b/Ai_tottle/ai_tottle_api.py @@ -10,7 +10,7 @@ import asyncio from ai_image import process_images # 你实现的图片处理函数 from queue import Queue from map_find import map_process_images -from yolo_train import auto_train +from yolo_train import auto_train,query_progress import torch from yolo_photo import map_process_images_with_progress # 引入你的处理函数 # 日志配置 @@ -261,18 +261,19 @@ async def yolo_train_api(request): "project_name": str } 输出 JSON: - { - "base_metrics": Dict[str, float], - "best_model_path": str, - "final_metrics": Dict[str, float] + return { + "status": "success", + "message": "Train finished", + "project_name": project_name, + "label_count": label_count, + "base_metrics": base_metrics, + "final_metrics": final_metrics } - """ try: # 修改为直接访问 request.json 而不是调用它 data = request.json - if not data: return json_response({"status": "error", "message": "data is required"}, status=400) @@ -290,6 +291,31 @@ async def yolo_train_api(request): "status": "error", "message": f"Internal server error: {str(e)}" }, status=500) + +# 查询训练进度接口 +@yolo_tile_blueprint.get("/progress/") +async def yolo_train_progress(request, project_name): + ''' + 输入参数: + 如果想查询最新一次训练:GET /yolo/progress/my_project + 如果想查询某次特定时间:GET /yolo/progress/my_project?run_time=20250902_1012 + 输出 JSON: + { + "status": "ok", + "run_time": "20250902_1012", + "progress": { + "epoch": 12, + "precision": 0.72, + "recall": 0.64, + "mAP50": 0.68, + "mAP50-95": 0.42 + } + } + + ''' + run_time = request.args.get("run_time") # 可选参数 + result = await asyncio.to_thread(query_progress, project_name, run_time) + return json_response(result) if __name__ == '__main__': app.run(host="0.0.0.0", port=12366, debug=True,workers=1) diff --git a/Ai_tottle/yolo_train.py b/Ai_tottle/yolo_train.py index 0ea5eb7..4e10a99 100644 --- a/Ai_tottle/yolo_train.py +++ b/Ai_tottle/yolo_train.py @@ -21,6 +21,7 @@ import psycopg2 import miniohelp as miniohelp from psycopg2 import OperationalError from collections import Counter +import pandas as pd ##############################################################################计算每个class的label数量############################################################# def count_labels_by_class(label_dir): class_counter = Counter() @@ -291,7 +292,7 @@ def train(project_name, yaml_path, default_model_path): if os.path.exists(trained_model_path): save_last_model(trained_model_path) -#######################################################################################主流程################################################################################## +#######################################################################################自动训练################################################################################## def auto_train( db_host, db_database, @@ -352,6 +353,43 @@ def auto_train( "final_metrics": final_metrics } +##########################################################################################查询训练进度################################################################## +def query_progress(project_name, run_time=None): + """ + 查询训练进度 + :param project_name: 训练工程名 + :param run_time: 启动时间戳(默认取最近一次) + :return: 当前进度信息 + """ + base_dir = os.path.join("runs", "detect", project_name) + if not os.path.exists(base_dir): + return {"status": "not_found", "message": f"没有找到 {base_dir}"} + + # 如果没有指定 run_time,取最新目录 + if run_time is None: + dirs = sorted(os.listdir(base_dir), reverse=True) + if not dirs: + return {"status": "not_found", "message": "没有找到训练记录"} + run_time = dirs[0] + + log_path = os.path.join(base_dir, run_time, "results.csv") + if not os.path.exists(log_path): + return {"status": "not_found", "message": f"没有找到日志 {log_path}"} + + df = pd.read_csv(log_path) + if df.empty: + return {"status": "running", "message": "日志尚未写入数据"} + + last_row = df.iloc[-1].to_dict() + progress = { + "epoch": int(last_row.get("epoch", -1)), + "precision": float(last_row.get("precision", 0.0)), + "recall": float(last_row.get("recall", 0.0)), + "mAP50": float(last_row.get("mAP50", 0.0)), + "mAP50-95": float(last_row.get("mAP50-95", 0.0)), + } + return {"status": "ok", "run_time": run_time, "progress": progress} + if __name__ == '__main__': auto_train( db_host='222.212.85.86',