yoooooger
This commit is contained in:
parent
c4572d32fa
commit
16a353b1ee
@ -10,7 +10,7 @@ import asyncio
|
||||
from ai_image import process_images # 你实现的图片处理函数
|
||||
from queue import Queue
|
||||
from map_find import map_process_images
|
||||
from yolo_train import auto_train
|
||||
from yolo_train import auto_train,query_progress
|
||||
import torch
|
||||
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
|
||||
# 日志配置
|
||||
@ -261,18 +261,19 @@ async def yolo_train_api(request):
|
||||
"project_name": str
|
||||
}
|
||||
输出 JSON:
|
||||
{
|
||||
"base_metrics": Dict[str, float],
|
||||
"best_model_path": str,
|
||||
"final_metrics": Dict[str, float]
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Train finished",
|
||||
"project_name": project_name,
|
||||
"label_count": label_count,
|
||||
"base_metrics": base_metrics,
|
||||
"final_metrics": final_metrics
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
# 修改为直接访问 request.json 而不是调用它
|
||||
data = request.json
|
||||
|
||||
|
||||
if not data:
|
||||
return json_response({"status": "error", "message": "data is required"}, status=400)
|
||||
|
||||
@ -291,5 +292,30 @@ async def yolo_train_api(request):
|
||||
"message": f"Internal server error: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# 查询训练进度接口
|
||||
@yolo_tile_blueprint.get("/progress/<project_name>")
|
||||
async def yolo_train_progress(request, project_name):
|
||||
'''
|
||||
输入参数:
|
||||
如果想查询最新一次训练:GET /yolo/progress/my_project
|
||||
如果想查询某次特定时间:GET /yolo/progress/my_project?run_time=20250902_1012
|
||||
输出 JSON:
|
||||
{
|
||||
"status": "ok",
|
||||
"run_time": "20250902_1012",
|
||||
"progress": {
|
||||
"epoch": 12,
|
||||
"precision": 0.72,
|
||||
"recall": 0.64,
|
||||
"mAP50": 0.68,
|
||||
"mAP50-95": 0.42
|
||||
}
|
||||
}
|
||||
|
||||
'''
|
||||
run_time = request.args.get("run_time") # 可选参数
|
||||
result = await asyncio.to_thread(query_progress, project_name, run_time)
|
||||
return json_response(result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
|
||||
|
@ -21,6 +21,7 @@ import psycopg2
|
||||
import miniohelp as miniohelp
|
||||
from psycopg2 import OperationalError
|
||||
from collections import Counter
|
||||
import pandas as pd
|
||||
##############################################################################计算每个class的label数量#############################################################
|
||||
def count_labels_by_class(label_dir):
|
||||
class_counter = Counter()
|
||||
@ -291,7 +292,7 @@ def train(project_name, yaml_path, default_model_path):
|
||||
if os.path.exists(trained_model_path):
|
||||
save_last_model(trained_model_path)
|
||||
|
||||
#######################################################################################主流程##################################################################################
|
||||
#######################################################################################自动训练##################################################################################
|
||||
def auto_train(
|
||||
db_host,
|
||||
db_database,
|
||||
@ -352,6 +353,43 @@ def auto_train(
|
||||
"final_metrics": final_metrics
|
||||
}
|
||||
|
||||
##########################################################################################查询训练进度##################################################################
|
||||
def query_progress(project_name, run_time=None):
|
||||
"""
|
||||
查询训练进度
|
||||
:param project_name: 训练工程名
|
||||
:param run_time: 启动时间戳(默认取最近一次)
|
||||
:return: 当前进度信息
|
||||
"""
|
||||
base_dir = os.path.join("runs", "detect", project_name)
|
||||
if not os.path.exists(base_dir):
|
||||
return {"status": "not_found", "message": f"没有找到 {base_dir}"}
|
||||
|
||||
# 如果没有指定 run_time,取最新目录
|
||||
if run_time is None:
|
||||
dirs = sorted(os.listdir(base_dir), reverse=True)
|
||||
if not dirs:
|
||||
return {"status": "not_found", "message": "没有找到训练记录"}
|
||||
run_time = dirs[0]
|
||||
|
||||
log_path = os.path.join(base_dir, run_time, "results.csv")
|
||||
if not os.path.exists(log_path):
|
||||
return {"status": "not_found", "message": f"没有找到日志 {log_path}"}
|
||||
|
||||
df = pd.read_csv(log_path)
|
||||
if df.empty:
|
||||
return {"status": "running", "message": "日志尚未写入数据"}
|
||||
|
||||
last_row = df.iloc[-1].to_dict()
|
||||
progress = {
|
||||
"epoch": int(last_row.get("epoch", -1)),
|
||||
"precision": float(last_row.get("precision", 0.0)),
|
||||
"recall": float(last_row.get("recall", 0.0)),
|
||||
"mAP50": float(last_row.get("mAP50", 0.0)),
|
||||
"mAP50-95": float(last_row.get("mAP50-95", 0.0)),
|
||||
}
|
||||
return {"status": "ok", "run_time": run_time, "progress": progress}
|
||||
|
||||
if __name__ == '__main__':
|
||||
auto_train(
|
||||
db_host='222.212.85.86',
|
||||
|
Loading…
x
Reference in New Issue
Block a user