diff --git a/Ai_tottle/ai_tottle_api.py b/Ai_tottle/ai_tottle_api.py index 3230506..eaf1d88 100644 --- a/Ai_tottle/ai_tottle_api.py +++ b/Ai_tottle/ai_tottle_api.py @@ -125,7 +125,7 @@ yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/') app.blueprint(yolo_tile_blueprint) # YOLO URL APT -# 存储任务进度和结果(内存示例,可用 Redis 或 DB 持久化) +# save the task progress in memory @yolo_tile_blueprint.post("/process_images") async def process_images(request): diff --git a/Ai_tottle/yolo_train.py b/Ai_tottle/yolo_train.py index dfd321f..771e307 100644 --- a/Ai_tottle/yolo_train.py +++ b/Ai_tottle/yolo_train.py @@ -22,7 +22,7 @@ import miniohelp as miniohelp from psycopg2 import OperationalError from collections import Counter import pandas as pd -##############################################################################计算每个class的label数量############################################################# +####################################### 𝓒𝓸𝓶𝓹𝓾𝓽𝓮 𝓵𝓪𝓫𝓮𝓵 𝓬𝓸𝓾𝓷𝓽 𝓯𝓸𝓻 𝓮𝓪𝓬𝓱 𝓬𝓵𝓪𝓼𝓼 ############################## def count_labels_by_class(label_dir): class_counter = Counter() for file in os.listdir(label_dir): @@ -32,7 +32,7 @@ def count_labels_by_class(label_dir): class_id = line.strip().split()[0] class_counter[class_id] += 1 return dict(class_counter) -#################################################################################统计原始的召回率和准确率######################################################### +##################################### 𝓢𝓽𝓪𝓽𝓲𝓼𝓽𝓲𝓬𝓼 𝓸𝓯 𝓹𝓻𝓮𝓬𝓲𝓼𝓲𝓸𝓷 𝓪𝓷𝓭 𝓻𝓮𝓬𝓪𝓵𝓵 ##################################### def evaluate_model_per_class(model_path, dataset_yaml, class_names): model = YOLO(model_path) metrics = model.val(data=dataset_yaml, split='val') @@ -49,7 +49,8 @@ def evaluate_model_per_class(model_path, dataset_yaml, class_names): } return results -#############################################################################下载图片与对应的yolo格式标签txt文档################################################### +##################################### 𝓓𝓸𝔀𝓷𝓵𝓸𝓪𝓭 𝓲𝓶𝓪𝓰𝓮𝓼 & 𝔂𝓸𝓵𝓸 𝓽𝓮𝔁𝓽 𝓵𝓪𝓫𝓮𝓵 𝓼######################################## + def link_database(db_database, db_user, db_password, db_host, db_port, search_query): """从数据库查询并更新符合条件的记录""" try: @@ -285,7 +286,7 @@ def train(project_name, yaml_path, default_model_path): device=[0], # 如果你有多个显卡可以写 [0,1] workers=0, project=project, - name=current_date, + name=current_date, ) trained_model_path = os.path.join('runs', 'detect', current_date, 'weights', 'best.pt') @@ -353,7 +354,7 @@ def auto_train( "final_metrics": final_metrics } -##########################################################################################查询训练进度################################################################## +##########################################################################################get the training ################################################################## def query_progress(project_name, run_time=None): """ 查询训练进度 @@ -369,7 +370,8 @@ def query_progress(project_name, run_time=None): if run_time is None: dirs = sorted(os.listdir(base_dir), reverse=True) if not dirs: - return {"status": "not_found", "message": "Documenta exercitationis non inventa sunt"} + # 𝒟𝑜𝒸𝓊𝓂𝑒𝓃𝓉𝒶 𝑒𝓍𝑒𝓇𝒸𝒾𝓉𝒶𝓉𝒾𝑜𝓃𝒾𝓈 𝓃𝑜𝓃 𝒾𝓃𝓋𝑒𝓃𝓉𝒶 𝓈𝓊𝓃𝓉 + return {"status": "not_found", "message": "Documenta exercitationis adhuc inveniuntur non potest"} run_time = dirs[0] log_path = os.path.join(base_dir, run_time, "results.csv")