yoooooger

This commit is contained in:
yooooger 2025-09-02 10:23:21 +08:00
parent c4572d32fa
commit 16a353b1ee
2 changed files with 72 additions and 8 deletions

View File

@ -10,7 +10,7 @@ import asyncio
from ai_image import process_images # 你实现的图片处理函数
from queue import Queue
from map_find import map_process_images
from yolo_train import auto_train
from yolo_train import auto_train,query_progress
import torch
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
# 日志配置
@ -261,18 +261,19 @@ async def yolo_train_api(request):
"project_name": str
}
输出 JSON:
{
"base_metrics": Dict[str, float],
"best_model_path": str,
"final_metrics": Dict[str, float]
return {
"status": "success",
"message": "Train finished",
"project_name": project_name,
"label_count": label_count,
"base_metrics": base_metrics,
"final_metrics": final_metrics
}
"""
try:
# 修改为直接访问 request.json 而不是调用它
data = request.json
if not data:
return json_response({"status": "error", "message": "data is required"}, status=400)
@ -291,5 +292,30 @@ async def yolo_train_api(request):
"message": f"Internal server error: {str(e)}"
}, status=500)
# 查询训练进度接口
@yolo_tile_blueprint.get("/progress/<project_name>")
async def yolo_train_progress(request, project_name):
'''
输入参数:
如果想查询最新一次训练GET /yolo/progress/my_project
如果想查询某次特定时间GET /yolo/progress/my_project?run_time=20250902_1012
输出 JSON:
{
"status": "ok",
"run_time": "20250902_1012",
"progress": {
"epoch": 12,
"precision": 0.72,
"recall": 0.64,
"mAP50": 0.68,
"mAP50-95": 0.42
}
}
'''
run_time = request.args.get("run_time") # 可选参数
result = await asyncio.to_thread(query_progress, project_name, run_time)
return json_response(result)
if __name__ == '__main__':
app.run(host="0.0.0.0", port=12366, debug=True,workers=1)

View File

@ -21,6 +21,7 @@ import psycopg2
import miniohelp as miniohelp
from psycopg2 import OperationalError
from collections import Counter
import pandas as pd
##############################################################################计算每个class的label数量#############################################################
def count_labels_by_class(label_dir):
class_counter = Counter()
@ -291,7 +292,7 @@ def train(project_name, yaml_path, default_model_path):
if os.path.exists(trained_model_path):
save_last_model(trained_model_path)
#######################################################################################主流程##################################################################################
#######################################################################################自动训练##################################################################################
def auto_train(
db_host,
db_database,
@ -352,6 +353,43 @@ def auto_train(
"final_metrics": final_metrics
}
##########################################################################################查询训练进度##################################################################
def query_progress(project_name, run_time=None):
"""
查询训练进度
:param project_name: 训练工程名
:param run_time: 启动时间戳默认取最近一次
:return: 当前进度信息
"""
base_dir = os.path.join("runs", "detect", project_name)
if not os.path.exists(base_dir):
return {"status": "not_found", "message": f"没有找到 {base_dir}"}
# 如果没有指定 run_time取最新目录
if run_time is None:
dirs = sorted(os.listdir(base_dir), reverse=True)
if not dirs:
return {"status": "not_found", "message": "没有找到训练记录"}
run_time = dirs[0]
log_path = os.path.join(base_dir, run_time, "results.csv")
if not os.path.exists(log_path):
return {"status": "not_found", "message": f"没有找到日志 {log_path}"}
df = pd.read_csv(log_path)
if df.empty:
return {"status": "running", "message": "日志尚未写入数据"}
last_row = df.iloc[-1].to_dict()
progress = {
"epoch": int(last_row.get("epoch", -1)),
"precision": float(last_row.get("precision", 0.0)),
"recall": float(last_row.get("recall", 0.0)),
"mAP50": float(last_row.get("mAP50", 0.0)),
"mAP50-95": float(last_row.get("mAP50-95", 0.0)),
}
return {"status": "ok", "run_time": run_time, "progress": progress}
if __name__ == '__main__':
auto_train(
db_host='222.212.85.86',