From 23b990c811972f63fd52beb5d2a806789eced85e Mon Sep 17 00:00:00 2001 From: yooooger <761181201@qq.com> Date: Fri, 26 Sep 2025 18:03:49 +0800 Subject: [PATCH] loggoing --- Ai_tottle/yolo_train.py | 349 +++++++++++++++++----------------------- 1 file changed, 148 insertions(+), 201 deletions(-) diff --git a/Ai_tottle/yolo_train.py b/Ai_tottle/yolo_train.py index 771e307..1e93c61 100644 --- a/Ai_tottle/yolo_train.py +++ b/Ai_tottle/yolo_train.py @@ -1,28 +1,104 @@ +"""" +main() + | + v +setup_logger(project) + | + v +get_last_model_path(project) + | + v ++-------------------------+ +| 有 last.pt | 无 last.pt | ++-------------------------+ + | | + v v +load_last_model() start_new_training() + | | + +-------+--------+ + | + v + check_dataset(root) + | + v + split_dataset(root, ratios) + | + v + clean_labels(root) + | + v + generate_yaml(dataset_dir) + | + v + train_yolo(model, data_yaml) + | + v + 保存 last.pt + | + v +logger.info("Saved last model path") + | + v + 写入 logs/{project}.log + + + """ -本程序用以自动训练ai -1.下载图片与对应的yolo格式标签txt文档 -2.保证文件内容格式正确 -3.分割数据集 -""" + import os import shutil import datetime import torch from ultralytics import YOLO -import os -import shutil import random -from tqdm import tqdm import math import stat import yaml -import os import psycopg2 -import miniohelp as miniohelp from psycopg2 import OperationalError from collections import Counter import pandas as pd -####################################### 𝓒𝓸𝓶𝓹𝓾𝓽𝓮 𝓵𝓪𝓫𝓮𝓵 𝓬𝓸𝓾𝓷𝓽 𝓯𝓸𝓻 𝓮𝓪𝓬𝓱 𝓬𝓵𝓪𝓼𝓼 ############################## +import logging +from tqdm import tqdm +import miniohelp as miniohelp + +######################################## Logging ######################################## +def setup_logger(project: str): + os.makedirs("logs", exist_ok=True) + log_file = os.path.join("logs", f"{project}.log") + + logger = logging.getLogger(project) + if not logger.handlers: + logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s") + + fh = logging.FileHandler(log_file, encoding="utf-8") + fh.setFormatter(formatter) + sh = logging.StreamHandler() + sh.setFormatter(formatter) + + logger.addHandler(fh) + logger.addHandler(sh) + + return logger + +def get_last_model_from_log(project: str, default_model: str): + """从日志中解析上一次训练的 last.pt 路径""" + log_file = os.path.join("logs", f"{project}.log") + if not os.path.exists(log_file): + return default_model + + with open(log_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + for line in reversed(lines): + if "Saved last model path:" in line: + path = line.strip().split("Saved last model path:")[-1].strip() + if os.path.exists(path): + return path + return default_model + +####################################### 工具函数 ####################################### def count_labels_by_class(label_dir): class_counter = Counter() for file in os.listdir(label_dir): @@ -32,12 +108,11 @@ def count_labels_by_class(label_dir): class_id = line.strip().split()[0] class_counter[class_id] += 1 return dict(class_counter) -##################################### 𝓢𝓽𝓪𝓽𝓲𝓼𝓽𝓲𝓬𝓼 𝓸𝓯 𝓹𝓻𝓮𝓬𝓲𝓼𝓲𝓸𝓷 𝓪𝓷𝓭 𝓻𝓮𝓬𝓪𝓵𝓵 ##################################### + def evaluate_model_per_class(model_path, dataset_yaml, class_names): model = YOLO(model_path) metrics = model.val(data=dataset_yaml, split='val') - - class_ids = range(len(metrics.box.p)) # class id list + class_ids = range(len(metrics.box.p)) results = {} for i in class_ids: name = class_names.get(str(i), str(i)) @@ -49,10 +124,7 @@ def evaluate_model_per_class(model_path, dataset_yaml, class_names): } return results -##################################### 𝓓𝓸𝔀𝓷𝓵𝓸𝓪𝓭 𝓲𝓶𝓪𝓰𝓮𝓼 & 𝔂𝓸𝓵𝓸 𝓽𝓮𝔁𝓽 𝓵𝓪𝓫𝓮𝓵 𝓼######################################## - def link_database(db_database, db_user, db_password, db_host, db_port, search_query): - """从数据库查询并更新符合条件的记录""" try: with psycopg2.connect( database=db_database, @@ -70,22 +142,11 @@ def link_database(db_database, db_user, db_password, db_host, db_port, search_qu except Exception as e: print(f"发生了其他错误: {e}") -def down_dataset(db_database, db_user, db_password, db_host, db_port, model): - ''' - db_database: 数据库名 - db_user: 数据库用户名 - db_password: 数据库密码 - db_host: 数据库地址 - db_port: 数据库端口 - model: 模型ID - ''' - search_query = f""" - SELECT * FROM aidataset - WHERE model = '{model}'; - """ +def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger): + search_query = f"SELECT * FROM aidataset WHERE model = '{model}';" records = link_database(db_database, db_user, db_password, db_host, db_port, search_query) if not records: - print("没有查询到数据。") + logger.warning("没有查询到数据。") return os.makedirs('./dataset/images', exist_ok=True) @@ -95,31 +156,27 @@ def down_dataset(db_database, db_user, db_password, db_host, db_port, model): img_path = r[4] label_content = r[5] - # 下载图片到 images 文件夹 local_img_name = img_path.split('/')[-1] local_img_path = os.path.join('./dataset/images', local_img_name) miniohelp.downFile(img_path, local_img_path) - # 写标签到 labels 文件夹 txt_name = os.path.splitext(local_img_name)[0] + '.txt' txt_path = os.path.join('./dataset/labels', txt_name) with open(txt_path, 'w', encoding='utf-8') as f: f.write(label_content + '\n') -#############################################################################保证文件内容格式正确########################################################### + + logger.info("数据下载完成") def make_writable(file_path): os.chmod(file_path, stat.S_IWRITE) -def process_files_in_folder(folder_path): +def process_files_in_folder(folder_path, logger): for root, _, files in os.walk(folder_path): for file_name in files: if file_name.endswith('.txt'): file_path = os.path.join(root, file_name) - - # 确保文件可写 make_writable(file_path) - # 读取文件内容并进行处理 with open(file_path, 'r') as file: lines = file.readlines() @@ -127,62 +184,42 @@ def process_files_in_folder(folder_path): for line in lines: numbers = line.split() processed_numbers = [] - - # 确保第一列为整数 0 或 1,不处理为浮点数 - if numbers[0] == '0' or numbers[0] == '1' or numbers[0] == '2' or numbers[0] == '3' or numbers[0] == '4' or numbers[0] == '5' or numbers[0] == '6' or numbers[0] == '7' or numbers[0] == '8' or numbers[0] == '9': + if numbers[0].isdigit(): processed_numbers.append(numbers[0]) else: - print(f"Unexpected value in first column: {numbers[0]}") + logger.warning(f"Unexpected value in first column: {numbers[0]}") continue - # 处理后面的列,保留原始格式并确保负数变成正数,且删除 NaN 数据 - skip_line = False # 用于标记是否跳过这一行 + skip_line = False for number in numbers[1:]: try: number = float(number) - if math.isnan(number): # 检查是否为NaN + if math.isnan(number): skip_line = True - print(f"NaN detected in file: {file_path}, line: {line}") + logger.warning(f"NaN detected in {file_path}: {line}") break if number < 0: - number = abs(number) # 将负数转换为正数 - processed_numbers.append(str(number)) # 保留原始格式 + number = abs(number) + processed_numbers.append(str(number)) except ValueError: - processed_numbers.append(number) # 非数字列保持原样 - - # 如果该行没有NaN数据,则加入结果列表 + processed_numbers.append(number) + if not skip_line: processed_line = ' '.join(processed_numbers) processed_lines.append(processed_line) - # 将处理后的内容写回文件 with open(file_path, 'w') as file: file.write('\n'.join(processed_lines)) - print(f"Finished processing: {file_path}") - -###################################################################################分割数据集##################################################################### -def split_img(img_path, label_path, split_list, new_path, class_names): + logger.info(f"Processed {file_path}") + +def split_img(img_path, label_path, split_list, new_path, class_names, logger): try: Data = os.path.abspath(new_path) os.makedirs(Data, exist_ok=True) - - train_img_dir = os.path.abspath(os.path.join(Data, 'train', 'images')) - val_img_dir = os.path.abspath(os.path.join(Data, 'val', 'images')) - test_img_dir = os.path.abspath(os.path.join(Data, 'test', 'images')) - - train_label_dir = os.path.abspath(os.path.join(Data, 'train', 'labels')) - val_label_dir = os.path.abspath(os.path.join(Data, 'val', 'labels')) - test_label_dir = os.path.abspath(os.path.join(Data, 'test', 'labels')) - - os.makedirs(train_img_dir, exist_ok=True) - os.makedirs(train_label_dir, exist_ok=True) - os.makedirs(val_img_dir, exist_ok=True) - os.makedirs(val_label_dir, exist_ok=True) - os.makedirs(test_img_dir, exist_ok=True) - os.makedirs(test_label_dir, exist_ok=True) - + dirs = ['train/images','val/images','test/images','train/labels','val/labels','test/labels'] + for d in dirs: os.makedirs(os.path.join(Data, d), exist_ok=True) except Exception as e: - print(f'文件目录创建失败: {e}') + logger.error(f'文件目录创建失败: {e}') return train, val, test = split_list @@ -192,30 +229,31 @@ def split_img(img_path, label_path, split_list, new_path, class_names): train_img = random.sample(all_img_path, int(train * len(all_img_path))) train_label = [toLabelPath(img, label_path) for img in train_img] for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'): - _copy(train_img[i], train_img_dir) - _copy(train_label[i], train_label_dir) + _copy(train_img[i], os.path.join(Data,'train/images')) + _copy(train_label[i], os.path.join(Data,'train/labels')) all_img_path.remove(train_img[i]) val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path))) val_label = [toLabelPath(img, label_path) for img in val_img] for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'): - _copy(val_img[i], val_img_dir) - _copy(val_label[i], val_label_dir) + _copy(val_img[i], os.path.join(Data,'val/images')) + _copy(val_label[i], os.path.join(Data,'val/labels')) all_img_path.remove(val_img[i]) test_img = all_img_path test_label = [toLabelPath(img, label_path) for img in test_img] for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'): - _copy(test_img[i], test_img_dir) - _copy(test_label[i], test_label_dir) + _copy(test_img[i], os.path.join(Data,'test/images')) + _copy(test_label[i], os.path.join(Data,'test/labels')) generate_dataset_yaml( save_path=os.path.join(Data, 'dataset.yaml'), - train_path=train_img_dir, - val_path=val_img_dir, - test_path=test_img_dir, + train_path=os.path.join(Data,'train/images'), + val_path=os.path.join(Data,'val/images'), + test_path=os.path.join(Data,'test/images'), class_names=class_names ) + logger.info("数据集划分完成") def _copy(from_path, to_path): try: @@ -236,162 +274,71 @@ def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_name 'nc': len(class_names), 'names': list(class_names.values()) } - with open(save_path, 'w', encoding='utf-8') as f: yaml.dump(dataset_yaml, f, allow_unicode=True) -######################################################################开训开训开训################################################################## - -# 获取上次训练的模型路径 -def get_last_model(default_model_path, log_file='last_model.log'): - if os.path.exists(log_file): - with open(log_file, 'r') as f: - last_model = f.read().strip() - if last_model and os.path.exists(last_model): - print(f"使用上次训练模型: {last_model}") - return last_model - print(f"使用默认模型: {default_model_path}") - return default_model_path - -# 保存本次训练的模型路径 -def save_last_model(model_path, log_file='last_model.log'): - with open(log_file, 'w') as f: - f.write(model_path) - -# 删除文件夹 -def delete_folder(folder_path): +def delete_folder(folder_path, logger): if os.path.exists(folder_path): shutil.rmtree(folder_path) - print(f"已删除文件夹及其内容: {folder_path}") - else: - print(f"文件夹不存在: {folder_path}") + logger.info(f"已删除文件夹: {folder_path}") -# 训练函数 -def train(project_name, yaml_path, default_model_path): +####################################### 训练 ####################################### +def train(project_name, yaml_path, default_model_path, logger): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") + logger.info(f"Using device: {device}") - model_path = get_last_model(default_model_path) + model_path = get_last_model_from_log(project_name, default_model_path) + logger.info(f"加载模型: {model_path}") model = YOLO(model_path).to(device) current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M") - project = project_name.strip() - model.train( data=yaml_path, epochs=200, pretrained=True, patience=50, imgsz=640, - device=[0], # 如果你有多个显卡可以写 [0,1] + device=[0], workers=0, - project=project, + project=project_name, name=current_date, ) - trained_model_path = os.path.join('runs', 'detect', current_date, 'weights', 'best.pt') + trained_model_path = os.path.join('runs', 'detect', project_name, current_date, 'weights', 'last.pt') if os.path.exists(trained_model_path): - save_last_model(trained_model_path) + logger.info(f"Saved last model path: {trained_model_path}") -#######################################################################################自动训练################################################################################## -def auto_train( - db_host, - db_database, - db_user, - db_password, - db_port, - model_id, - img_path='./dataset/images', - label_path='./dataset/labels', - new_path='./datasets', - split_list=[0.7, 0.2, 0.1], - class_names=None, - project_name='default_project' -): +####################################### 自动训练 ####################################### +def auto_train(db_host, db_database, db_user, db_password, db_port, model_id, + img_path='./dataset/images', label_path='./dataset/labels', + new_path='./datasets', split_list=[0.7, 0.2, 0.1], + class_names=None, project_name='default_project'): if class_names is None: class_names = {} - # 删除旧数据 - delete_folder('dataset') - delete_folder('datasets') + logger = setup_logger(project_name) - # 下载新数据 - down_dataset(db_database, db_user, db_password, db_host, db_port, model_id) + delete_folder('dataset', logger) + delete_folder('datasets', logger) - # 处理 label - process_files_in_folder(img_path) + down_dataset(db_database, db_user, db_password, db_host, db_port, model_id, logger) + process_files_in_folder(img_path, logger) - # 统计标签数量 label_count = count_labels_by_class(label_path) + logger.info(f"标签统计: {label_count}") - # 划分数据集 - split_img(img_path, label_path, split_list, new_path, class_names) + split_img(img_path, label_path, split_list, new_path, class_names, logger) - # 评估训练前模型 base_metrics = evaluate_model_per_class('yolo11n.pt', './datasets/dataset.yaml', class_names) + logger.info(f"训练前基线评估: {base_metrics}") - # 删除原始数据 - delete_folder('dataset') + delete_folder('dataset', logger) - # 开始训练 - train(project_name, './datasets/dataset.yaml', 'yolo11n.pt') + train(project_name, './datasets/dataset.yaml', 'yolo11n.pt', logger) - # 训练后评估 - best_model_path = f"runs/detect/{project_name}/weights/best.pt" - final_metrics = evaluate_model_per_class(best_model_path, './datasets/dataset.yaml', class_names) - - # 删除 datasets - delete_folder('datasets') - - print("训练流程执行完成!") - - return { - "status": "success", - "message": "Train finished", - "project_name": project_name, - "label_count": label_count, - "base_metrics": base_metrics, - "final_metrics": final_metrics - } - -##########################################################################################get the training ################################################################## -def query_progress(project_name, run_time=None): - """ - 查询训练进度 - :param project_name: 训练工程名 - :param run_time: 启动时间戳(默认取最近一次) - :return: 当前进度信息 - """ - base_dir = os.path.join("runs", "detect", project_name) - if not os.path.exists(base_dir): - return {"status": "not_found", "message": f"没有找到 {base_dir}"} - - # 如果没有指定 run_time,取最新目录 - if run_time is None: - dirs = sorted(os.listdir(base_dir), reverse=True) - if not dirs: - # 𝒟𝑜𝒸𝓊𝓂𝑒𝓃𝓉𝒶 𝑒𝓍𝑒𝓇𝒸𝒾𝓉𝒶𝓉𝒾𝑜𝓃𝒾𝓈 𝓃𝑜𝓃 𝒾𝓃𝓋𝑒𝓃𝓉𝒶 𝓈𝓊𝓃𝓉 - return {"status": "not_found", "message": "Documenta exercitationis adhuc inveniuntur non potest"} - run_time = dirs[0] - - log_path = os.path.join(base_dir, run_time, "results.csv") - if not os.path.exists(log_path): - return {"status": "not_found", "message": f"Documenta exercitationis non inventa sunt {log_path}"} - - df = pd.read_csv(log_path) - if df.empty: - return {"status": "running", "message": "日志尚未写入数据"} - - last_row = df.iloc[-1].to_dict() - progress = { - "epoch": int(last_row.get("epoch", -1)), - "precision": float(last_row.get("precision", 0.0)), - "recall": float(last_row.get("recall", 0.0)), - "mAP50": float(last_row.get("mAP50", 0.0)), - "mAP50-95": float(last_row.get("mAP50-95", 0.0)), - } - return {"status": "ok", "run_time": run_time, "progress": progress} + logger.info("训练流程执行完成") +####################################### 主入口 ####################################### if __name__ == '__main__': auto_train( db_host='222.212.85.86',