"""" main() | v setup_logger(project) | v get_last_model_path(project) | v +-------------------------+ | 有 last.pt | 无 last.pt | +-------------------------+ | | v v load_last_model() start_new_training() | | +-------+--------+ | v check_dataset(root) | v split_dataset(root, ratios) | v clean_labels(root) | v generate_yaml(dataset_dir) | v train_yolo(model, data_yaml) | v 保存 last.pt | v logger.info("Saved last model path") | v 写入 logs/{project}.log """ import os import shutil import datetime import torch from ultralytics import YOLO import random import math import stat import yaml import psycopg2 from psycopg2 import OperationalError from collections import Counter import pandas as pd import logging from tqdm import tqdm import miniohelp as miniohelp from aboutdataset.download_oss import download_and_save_images_from_oss ######################################## Logging ######################################## def setup_logger(project: str): os.makedirs("logs", exist_ok=True) log_file = os.path.join("logs", f"{project}.log") logger = logging.getLogger(project) if not logger.handlers: logger.setLevel(logging.INFO) formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s") fh = logging.FileHandler(log_file, encoding="utf-8") fh.setFormatter(formatter) sh = logging.StreamHandler() sh.setFormatter(formatter) logger.addHandler(fh) logger.addHandler(sh) return logger def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"): """ 从日志解析上一次训练的 last.pt 路径 如果找不到则返回 default_model 支持 default_model 为 .pt 或 .yaml """ log_file = os.path.join("logs", f"{project}.log") if not os.path.exists(log_file): return default_model with open(log_file, "r", encoding="utf-8") as f: lines = f.readlines() for line in reversed(lines): if "Saved last model path:" in line: path = line.strip().split("Saved last model path:")[-1].strip() if os.path.exists(path): return path return default_model def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, 
table_name): """ yaml_name='config' where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'" image_dir='images' label_dir='labels' table_name = 'aidataset' Returns: (image_dir, label_dir) """ download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name) return image_dir, label_dir ####################################### 工具函数 ####################################### def count_labels_by_class(label_dir): class_counter = Counter() for file in os.listdir(label_dir): if file.endswith('.txt'): with open(os.path.join(label_dir, file), 'r') as f: for line in f: class_id = line.strip().split()[0] class_counter[class_id] += 1 return dict(class_counter) def evaluate_model_per_class(model_path, dataset_yaml, class_names): model = YOLO(model_path) metrics = model.val(data=dataset_yaml, split='val') class_ids = range(len(metrics.box.p)) results = {} for i in class_ids: name = class_names.get(str(i), str(i)) results[name] = { "precision": float(metrics.box.p[i]), "recall": float(metrics.box.r[i]), "mAP50": float(metrics.box.map50[i]), "mAP50_95": float(metrics.box.map[i]) } return results def link_database(db_database, db_user, db_password, db_host, db_port, search_query): try: with psycopg2.connect( database=db_database, user=db_user, password=db_password, host=db_host, port=db_port ) as conn: with conn.cursor() as cur: cur.execute(search_query) records = cur.fetchall() return records except OperationalError as e: print(f"数据库连接或查询时发生错误: {e}") except Exception as e: print(f"发生了其他错误: {e}") def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger): search_query = f"SELECT * FROM aidataset WHERE model = '{model}';" records = link_database(db_database, db_user, db_password, db_host, db_port, search_query) if not records: logger.warning("没有查询到数据。") return os.makedirs('./dataset/images', exist_ok=True) os.makedirs('./dataset/labels', exist_ok=True) for r in records: img_path = r[4] label_content = r[5] local_img_name = 
img_path.split('/')[-1] local_img_path = os.path.join('./dataset/images', local_img_name) miniohelp.downFile(img_path, local_img_path) txt_name = os.path.splitext(local_img_name)[0] + '.txt' txt_path = os.path.join('./dataset/labels', txt_name) with open(txt_path, 'w', encoding='utf-8') as f: f.write(label_content + '\n') logger.info("数据下载完成") def make_writable(file_path): os.chmod(file_path, stat.S_IWRITE) def process_files_in_folder(folder_path, logger): for root, _, files in os.walk(folder_path): for file_name in files: if file_name.endswith('.txt'): file_path = os.path.join(root, file_name) make_writable(file_path) with open(file_path, 'r') as file: lines = file.readlines() processed_lines = [] for line in lines: numbers = line.split() processed_numbers = [] if numbers[0].isdigit(): processed_numbers.append(numbers[0]) else: logger.warning(f"Unexpected value in first column: {numbers[0]}") continue skip_line = False for number in numbers[1:]: try: number = float(number) if math.isnan(number): skip_line = True logger.warning(f"NaN detected in {file_path}: {line}") break if number < 0: number = abs(number) processed_numbers.append(str(number)) except ValueError: processed_numbers.append(number) if not skip_line: processed_line = ' '.join(processed_numbers) processed_lines.append(processed_line) with open(file_path, 'w') as file: file.write('\n'.join(processed_lines)) logger.info(f"Processed {file_path}") def split_img(img_path, label_path, split_list, new_path, class_names, logger): try: Data = os.path.abspath(new_path) os.makedirs(Data, exist_ok=True) dirs = ['train/images','val/images','test/images','train/labels','val/labels','test/labels'] for d in dirs: os.makedirs(os.path.join(Data, d), exist_ok=True) except Exception as e: logger.error(f'文件目录创建失败: {e}') return train, val, test = split_list all_img = os.listdir(img_path) all_img_path = [os.path.join(img_path, img) for img in all_img] train_img = random.sample(all_img_path, int(train * len(all_img_path))) 
train_label = [toLabelPath(img, label_path) for img in train_img] for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'): _copy(train_img[i], os.path.join(Data,'train/images')) _copy(train_label[i], os.path.join(Data,'train/labels')) all_img_path.remove(train_img[i]) val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path))) val_label = [toLabelPath(img, label_path) for img in val_img] for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'): _copy(val_img[i], os.path.join(Data,'val/images')) _copy(val_label[i], os.path.join(Data,'val/labels')) all_img_path.remove(val_img[i]) test_img = all_img_path test_label = [toLabelPath(img, label_path) for img in test_img] for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'): _copy(test_img[i], os.path.join(Data,'test/images')) _copy(test_label[i], os.path.join(Data,'test/labels')) generate_dataset_yaml( save_path=os.path.join(Data, 'dataset.yaml'), train_path=os.path.join(Data,'train/images'), val_path=os.path.join(Data,'val/images'), test_path=os.path.join(Data,'test/images'), class_names=class_names ) logger.info("数据集划分完成") def _copy(from_path, to_path): try: shutil.copy(from_path, to_path) except Exception as e: print(f"复制文件时出错: {e}") def toLabelPath(img_path, label_path): img = os.path.basename(img_path) label = img.replace('.jpg', '.txt') return os.path.join(label_path, label) def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names): dataset_yaml = { 'train': train_path.replace('\\', '/'), 'val': val_path.replace('\\', '/'), 'test': test_path.replace('\\', '/'), 'nc': len(class_names), 'names': list(class_names.values()) } with open(save_path, 'w', encoding='utf-8') as f: yaml.dump(dataset_yaml, f, allow_unicode=True) def delete_folder(folder_path, logger): if os.path.exists(folder_path): shutil.rmtree(folder_path) logger.info(f"已删除文件夹: {folder_path}") ####################################### 训练 
#######################################

def train(project_name, yaml_path, default_model_path, logger):
    """Train YOLO, resuming from the last logged checkpoint when available.

    On completion the new last.pt path is logged with the exact
    "Saved last model path:" prefix that get_last_model_from_log parses,
    so the next run resumes from it.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    logger.info(f"Using device: {device}")
    model_path = get_last_model_from_log(project_name, default_model_path)
    logger.info(f"加载模型: {model_path}")
    model = YOLO(model_path).to(device)
    current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    model.train(
        data=yaml_path,
        epochs=200,
        pretrained=True,
        patience=50,
        imgsz=640,
        # Follow the detected device: the previous hard-coded device=[0]
        # crashed on CPU-only machines despite the check above.
        device=[0] if use_cuda else "cpu",
        workers=0,
        project=project_name,
        name=current_date,
    )
    # With project= set, ultralytics saves under {project}/{name}/weights/;
    # the old 'runs/detect/...' guess is kept as a fallback for versions that
    # ignore project=.  The silent exists-check previously meant a wrong path
    # skipped logging and broke resume.
    candidates = [
        os.path.join(project_name, current_date, 'weights', 'last.pt'),
        os.path.join('runs', 'detect', project_name, current_date,
                     'weights', 'last.pt'),
    ]
    for trained_model_path in candidates:
        if os.path.exists(trained_model_path):
            logger.info(f"Saved last model path: {trained_model_path}")
            break
    else:
        logger.warning(f"last.pt not found in any of: {candidates}")


####################################### 自动训练 #######################################

def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
               img_path='./dataset/images', label_path='./dataset/labels',
               new_path='./datasets', split_list=None, class_names=None,
               project_name='default_project', base_model='yolo11n.pt'):
    """End-to-end pipeline: download -> clean -> split -> baseline eval -> train.

    ``split_list`` defaults to [0.7, 0.2, 0.1]; it is defaulted via None to
    avoid the shared-mutable-default pitfall.  ``base_model`` (new, optional)
    replaces the previously hard-coded 'yolo11n.pt' used both for the
    pre-training baseline evaluation and as the training fallback checkpoint.
    """
    if split_list is None:
        split_list = [0.7, 0.2, 0.1]
    if class_names is None:
        class_names = {}
    logger = setup_logger(project_name)
    delete_folder('dataset', logger)
    delete_folder('datasets', logger)
    down_dataset(db_database, db_user, db_password, db_host, db_port,
                 model_id, logger)
    # Clean the *label* directory: the cleaner rewrites .txt files, so running
    # it on the images directory (as before) was a silent no-op.
    process_files_in_folder(label_path, logger)
    label_count = count_labels_by_class(label_path)
    logger.info(f"标签统计: {label_count}")
    split_img(img_path, label_path, split_list, new_path, class_names, logger)
    base_metrics = evaluate_model_per_class(base_model,
                                            './datasets/dataset.yaml',
                                            class_names)
    logger.info(f"训练前基线评估: {base_metrics}")
    delete_folder('dataset', logger)
    train(project_name, './datasets/dataset.yaml', base_model, logger)
    logger.info("训练流程执行完成")


def down_and_train(db_host, db_database, db_user, db_password, db_port,
                   model_id, image_dir, label_dir, yaml_name, where_clause,
                   table_name, new_path='./datasets', split_list=None,
                   class_names=None, project_name='my_project'):
    """Download images/labels from OSS, then run the auto_train pipeline.

    Split ratios, class map, and project name are now overridable keyword
    arguments; their defaults reproduce the previously hard-coded values.
    """
    img_path, label_path = get_img_and_label_paths(
        yaml_name, where_clause, image_dir, label_dir, table_name)
    auto_train(
        db_host=db_host,
        db_database=db_database,
        db_user=db_user,
        db_password=db_password,
        db_port=db_port,
        model_id=model_id,
        # Fixed keyword: auto_train's parameter is img_path; the old
        # imag_path= keyword raised TypeError on every call.
        img_path=img_path,
        label_path=label_path,
        new_path=new_path,
        split_list=[0.7, 0.2, 0.1] if split_list is None else split_list,
        class_names=({'0': 'human', '1': 'car'} if class_names is None
                     else class_names),
        project_name=project_name,
    )


####################################### 主入口 #######################################

if __name__ == '__main__':
    # The previous call passed img_path/label_path keywords that
    # down_and_train does not accept and omitted its required OSS arguments
    # (image_dir, label_dir, yaml_name, where_clause, table_name), so the
    # script crashed with TypeError before doing any work.
    down_and_train(
        db_host='222.212.85.86',
        db_database='your_database_name',
        db_user='postgres',
        db_password='postgres',
        db_port='5432',
        model_id='best.pt',
        image_dir='./dataset/images',
        label_dir='./dataset/labels',
        yaml_name='config',                # NOTE(review): sample values from
        where_clause="model = 'best.pt'",  # get_img_and_label_paths docs —
        table_name='aidataset',            # confirm against deployment.
        new_path='./datasets',
        split_list=[0.7, 0.2, 0.1],
        class_names={'0': 'human', '1': 'car'},
        project_name='my_project',
    )