2025-10-09 09:29:18 +08:00

394 lines
14 KiB
Python

"""Training pipeline overview (conceptual call flow).
main()
|
v
setup_logger(project)
|
v
get_last_model_path(project)
|
v
+-------------------------+
| 有 last.pt | 无 last.pt |
+-------------------------+
| |
v v
load_last_model() start_new_training()
| |
+-------+--------+
|
v
check_dataset(root)
|
v
split_dataset(root, ratios)
|
v
clean_labels(root)
|
v
generate_yaml(dataset_dir)
|
v
train_yolo(model, data_yaml)
|
v
保存 last.pt
|
v
logger.info("Saved last model path")
|
v
写入 logs/{project}.log
"""
import os
import shutil
import datetime
import torch
from ultralytics import YOLO
import random
import math
import stat
import yaml
import psycopg2
from psycopg2 import OperationalError
from collections import Counter
import pandas as pd
import logging
from tqdm import tqdm
import miniohelp as miniohelp
from aboutdataset.download_oss import download_and_save_images_from_oss
######################################## Logging ########################################
def setup_logger(project: str):
    """Return the logger named ``project``, configuring it on first use.

    The ``logs`` directory is created when missing. If the logger has no
    handlers yet, its level is set to INFO and two handlers sharing one
    format are attached, in this order: a UTF-8 file handler writing to
    ``logs/<project>.log`` and a stream handler. A logger that already has
    handlers is returned untouched, so repeated calls never duplicate output.

    Args:
        project: Logger name; also used as the log file's base name.

    Returns:
        The ``logging.Logger`` registered under ``project``.
    """
    os.makedirs("logs", exist_ok=True)
    project_logger = logging.getLogger(project)
    if project_logger.handlers:
        # Already configured by an earlier call.
        return project_logger

    project_logger.setLevel(logging.INFO)
    line_format = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    new_handlers = (
        logging.FileHandler(os.path.join("logs", f"{project}.log"), encoding="utf-8"),
        logging.StreamHandler(),
    )
    for handler in new_handlers:
        handler.setFormatter(line_format)
        project_logger.addHandler(handler)
    return project_logger
def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"):
    """Find the most recent usable ``last.pt`` path recorded in the project log.

    ``logs/<project>.log`` is scanned from its newest line to its oldest for
    the marker ``Saved last model path:``. The text following the marker is
    treated as a path, and the first such path that exists on disk wins.

    Args:
        project: Project name; selects ``logs/<project>.log``.
        default_model: Value returned unchanged when nothing usable is
            found (the original notes mention both ``.pt`` and ``.yaml``).

    Returns:
        An existing path taken from the log, or ``default_model`` when the
        log file is missing or none of the recorded paths exists.
    """
    marker = "Saved last model path:"
    log_file = os.path.join("logs", f"{project}.log")
    if not os.path.exists(log_file):
        return default_model

    with open(log_file, "r", encoding="utf-8") as handle:
        entries = handle.readlines()

    # Newest entries first; the path is whatever follows the last marker.
    recorded_paths = (
        entry.strip().split(marker)[-1].strip()
        for entry in reversed(entries)
        if marker in entry
    )
    return next(
        (candidate for candidate in recorded_paths if os.path.exists(candidate)),
        default_model,
    )
def get_img_and_label_paths(yaml_name, where_clause, image_dir, label_dir, table_name):
    """Run the OSS download helper, then hand back the two target directories.

    All five arguments are forwarded unchanged to
    ``download_and_save_images_from_oss``. Example values from the original
    notes::

        yaml_name='config'
        where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'"
        image_dir='images'
        label_dir='labels'
        table_name='aidataset'

    Returns:
        The tuple ``(image_dir, label_dir)`` exactly as passed in.
    """
    download_and_save_images_from_oss(
        yaml_name,
        where_clause,
        image_dir,
        label_dir,
        table_name,
    )
    target_dirs = (image_dir, label_dir)
    return target_dirs
####################################### 工具函数 #######################################
def count_labels_by_class(label_dir):
    """Count annotation lines per class id across the ``.txt`` files in a folder.

    Only files directly inside ``label_dir`` whose names end in ``.txt`` are
    read. The first whitespace-separated token of each non-blank line is
    taken as the class id.

    Args:
        label_dir: Directory holding the label files.

    Returns:
        A plain ``dict`` mapping class id (as a string) to its line count.
    """
    class_counter = Counter()
    for file_name in os.listdir(label_dir):
        if not file_name.endswith('.txt'):
            continue
        # UTF-8 matches how down_dataset writes the label files.
        with open(os.path.join(label_dir, file_name), 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                # Blank or whitespace-only lines carry no annotation; indexing
                # them used to raise IndexError.
                if fields:
                    class_counter[fields[0]] += 1
    return dict(class_counter)
def evaluate_model_per_class(model_path, dataset_yaml, class_names):
    """Validate a YOLO model on the ``val`` split and report per-class metrics.

    Args:
        model_path: Weights or config path handed to ``YOLO``.
        dataset_yaml: Dataset YAML passed as ``data=`` to ``model.val``.
        class_names: Mapping from class id as a string (e.g. ``'0'``) to a
            display name; ids missing from it are reported under the id string.

    Returns:
        ``{name: {"precision", "recall", "mAP50", "mAP50_95"}}`` with one
        entry per class listed in the validation metrics.
    """
    model = YOLO(model_path)
    metrics = model.val(data=dataset_yaml, split='val')
    box = metrics.box
    results = {}
    # The per-class arrays (p, r, ap50, ap) have one row per entry of
    # ``ap_class_index`` rather than one per class id, so the row position
    # must be mapped back to the real class id before looking up its name.
    for row, class_id in enumerate(box.ap_class_index):
        key = str(int(class_id))
        name = class_names.get(key, key)
        results[name] = {
            "precision": float(box.p[row]),
            "recall": float(box.r[row]),
            # ``map50`` / ``map`` are scalar means over all classes and cannot
            # be indexed; ``ap50`` / ``ap`` hold the per-class values.
            "mAP50": float(box.ap50[row]),
            "mAP50_95": float(box.ap[row]),
        }
    return results
def link_database(db_database, db_user, db_password, db_host, db_port, search_query):
    """Run one query against PostgreSQL and return every resulting row.

    Args:
        db_database: Database name.
        db_user: Login user.
        db_password: Login password.
        db_host: Server host.
        db_port: Server port.
        search_query: SQL text executed as-is. The caller is responsible for
            making it safe; no parameters are bound here.

    Returns:
        The rows from ``cursor.fetchall()``, or ``None`` when connecting or
        querying failed (the error is printed, not raised).
    """
    conn = None
    try:
        conn = psycopg2.connect(
            database=db_database,
            user=db_user,
            password=db_password,
            host=db_host,
            port=db_port,
        )
        # ``with conn`` only commits or rolls back the transaction; it does
        # NOT close the connection, hence the explicit close in ``finally``.
        with conn:
            with conn.cursor() as cur:
                cur.execute(search_query)
                return cur.fetchall()
    except OperationalError as e:
        print(f"数据库连接或查询时发生错误: {e}")
    except Exception as e:
        print(f"发生了其他错误: {e}")
    finally:
        if conn is not None:
            conn.close()
    return None
def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger):
    """Download the images and label text of one model's dataset rows.

    Rows are selected from the ``aidataset`` table where ``model`` equals the
    given value. For each row, column 4 is used as the remote image path and
    column 5 as the label text (NOTE(review): column meanings are inferred
    from how they are used here; confirm against the table schema). Images
    are fetched with ``miniohelp.downFile`` into ``./dataset/images`` and the
    label text is written to ``./dataset/labels/<image stem>.txt``.

    Args:
        db_database, db_user, db_password, db_host, db_port: Connection
            settings forwarded to ``link_database``.
        model: Value matched against the ``model`` column.
        logger: Logger used for progress and warnings.

    Returns:
        None. Logs a warning and returns early when no rows come back.
    """
    # link_database accepts only a finished SQL string, so the value cannot
    # be bound as a parameter here. Doubling embedded single quotes keeps it
    # from terminating the string literal early.
    safe_model = str(model).replace("'", "''")
    search_query = f"SELECT * FROM aidataset WHERE model = '{safe_model}';"
    records = link_database(db_database, db_user, db_password, db_host, db_port, search_query)
    if not records:
        logger.warning("没有查询到数据。")
        return

    image_dir = './dataset/images'
    label_dir = './dataset/labels'
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    for record in records:
        img_path = record[4]
        label_content = record[5]
        local_img_name = img_path.split('/')[-1]
        miniohelp.downFile(img_path, os.path.join(image_dir, local_img_name))
        txt_name = os.path.splitext(local_img_name)[0] + '.txt'
        with open(os.path.join(label_dir, txt_name), 'w', encoding='utf-8') as f:
            # A NULL label column arrives as None; write an empty label file
            # instead of failing on ``None + str``.
            f.write((label_content or '') + '\n')
    logger.info("数据下载完成")
def make_writable(file_path):
    """Grant the owner read and write permission without dropping other bits.

    The previous ``os.chmod(path, stat.S_IWRITE)`` replaced the whole mode
    with owner-write only, which on POSIX removes read permission and makes
    a following ``open(path, 'r')`` fail. The current mode is now kept and
    the owner read/write bits are OR-ed in.

    Args:
        file_path: Path of the file whose permissions are adjusted.
    """
    current_mode = stat.S_IMODE(os.stat(file_path).st_mode)
    os.chmod(file_path, current_mode | stat.S_IREAD | stat.S_IWRITE)
def process_files_in_folder(folder_path, logger):
    """Clean every ``.txt`` label file found under ``folder_path``, in place.

    The tree is walked recursively. Each ``.txt`` file is made writable,
    read, cleaned line by line with ``_sanitize_label_line``, and rewritten
    with the surviving lines joined by ``'\\n'`` (no trailing newline).

    Args:
        folder_path: Root directory to walk.
        logger: Logger receiving warnings about dropped lines and one info
            message per rewritten file.
    """
    for root, _, files in os.walk(folder_path):
        for file_name in files:
            if not file_name.endswith('.txt'):
                continue
            file_path = os.path.join(root, file_name)
            make_writable(file_path)
            # UTF-8 matches how down_dataset writes the label files.
            with open(file_path, 'r', encoding='utf-8') as file:
                lines = file.readlines()
            cleaned = (_sanitize_label_line(line, file_path, logger) for line in lines)
            processed_lines = [line for line in cleaned if line is not None]
            with open(file_path, 'w', encoding='utf-8') as file:
                file.write('\n'.join(processed_lines))
            logger.info(f"Processed {file_path}")


def _sanitize_label_line(line, file_path, logger):
    """Return the cleaned form of one label line, or None to drop the line.

    A line is dropped when it is blank, when its first token is not made of
    digits, or when any later token parses to NaN. Negative numbers are
    replaced by their absolute value; tokens that are not numbers are kept
    verbatim. Numeric tokens are re-emitted via ``str(float(token))``.
    """
    numbers = line.split()
    if not numbers:
        # Blank line: nothing to keep (indexing it used to raise IndexError).
        return None
    if not numbers[0].isdigit():
        logger.warning(f"Unexpected value in first column: {numbers[0]}")
        return None

    processed_numbers = [numbers[0]]
    for token in numbers[1:]:
        try:
            value = float(token)
        except ValueError:
            processed_numbers.append(token)
            continue
        if math.isnan(value):
            logger.warning(f"NaN detected in {file_path}: {line}")
            return None
        if value < 0:
            value = abs(value)
        processed_numbers.append(str(value))
    return ' '.join(processed_numbers)
def split_img(img_path, label_path, split_list, new_path, class_names, logger):
    """Randomly split images and labels into train/val/test and write the YAML.

    Creates ``<new_path>/{train,val,test}/{images,labels}``, copies every
    entry of ``img_path`` (with its label file as resolved by
    ``toLabelPath``) into exactly one subset, then writes
    ``<new_path>/dataset.yaml`` through ``generate_dataset_yaml``.

    Subset sizes follow the original arithmetic: ``int(train * n)`` images go
    to train, ``int(val / (val + test) * remaining)`` to val, and the rest
    to test.

    Args:
        img_path: Directory whose entries are treated as images.
        label_path: Directory holding the matching label files.
        split_list: Three ratios ``(train, val, test)``.
        new_path: Output root; converted to an absolute path.
        class_names: Mapping forwarded to ``generate_dataset_yaml``.
        logger: Logger for errors and the completion message.

    Returns:
        None. Returns early, after logging, if the output directories cannot
        be created.
    """
    try:
        Data = os.path.abspath(new_path)
        os.makedirs(Data, exist_ok=True)
        for subset in ('train', 'val', 'test'):
            for kind in ('images', 'labels'):
                os.makedirs(os.path.join(Data, f'{subset}/{kind}'), exist_ok=True)
    except Exception as e:
        logger.error(f'文件目录创建失败: {e}')
        return

    # Local names avoid shadowing the module-level ``train`` function.
    train_ratio, val_ratio, test_ratio = split_list
    all_img_path = [os.path.join(img_path, img) for img in os.listdir(img_path)]

    # One shuffle plus slicing replaces repeated sampling and O(n) removals.
    shuffled = random.sample(all_img_path, len(all_img_path))
    total = len(shuffled)
    n_train = max(0, min(int(train_ratio * total), total))
    remaining = total - n_train
    holdout = val_ratio + test_ratio
    # Guard the division: with val == test == 0 nothing goes to val.
    n_val = int(val_ratio / holdout * remaining) if holdout > 0 else 0
    n_val = max(0, min(n_val, remaining))

    subsets = (
        ('train', shuffled[:n_train]),
        ('val', shuffled[n_train:n_train + n_val]),
        ('test', shuffled[n_train + n_val:]),
    )
    for subset, images in subsets:
        image_dst = os.path.join(Data, f'{subset}/images')
        label_dst = os.path.join(Data, f'{subset}/labels')
        for img in tqdm(images, desc=f'{subset} ', ncols=80, unit='img'):
            _copy(img, image_dst)
            _copy(toLabelPath(img, label_path), label_dst)

    generate_dataset_yaml(
        save_path=os.path.join(Data, 'dataset.yaml'),
        train_path=os.path.join(Data, 'train/images'),
        val_path=os.path.join(Data, 'val/images'),
        test_path=os.path.join(Data, 'test/images'),
        class_names=class_names,
    )
    logger.info("数据集划分完成")
def _copy(from_path, to_path):
    """Best-effort file copy: failures are printed, never raised.

    Args:
        from_path: Source file.
        to_path: Destination file or directory, as accepted by ``shutil.copy``.

    Returns:
        None, whether or not the copy succeeded.
    """
    try:
        shutil.copy(from_path, to_path)
    except Exception as copy_error:
        print(f"复制文件时出错: {copy_error}")
    return None
def toLabelPath(img_path, label_path):
    """Return the label file path that corresponds to an image path.

    The image's base name has its final extension replaced by ``.txt``,
    whatever that extension is. This is the same stem rule ``down_dataset``
    uses when it writes label files. The earlier ``.jpg``-only string
    replacement returned the image's own name for other extensions.

    Args:
        img_path: Path to the image file; only its base name is used.
        label_path: Directory containing the label files.

    Returns:
        ``<label_path>/<image stem>.txt``.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return os.path.join(label_path, stem + '.txt')
def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names):
    """Write the dataset description YAML to ``save_path``.

    The file holds the three image directories (backslashes turned into
    forward slashes), the class count ``nc`` and the class ``names`` list,
    taken from ``class_names.values()`` in the mapping's own order.

    Args:
        save_path: Destination YAML file (written as UTF-8).
        train_path: Training images directory.
        val_path: Validation images directory.
        test_path: Test images directory.
        class_names: Mapping whose values are the class display names.
    """
    def _forward_slashes(path):
        return path.replace('\\', '/')

    config = dict(
        train=_forward_slashes(train_path),
        val=_forward_slashes(val_path),
        test=_forward_slashes(test_path),
        nc=len(class_names),
        names=list(class_names.values()),
    )
    with open(save_path, 'w', encoding='utf-8') as out:
        yaml.dump(config, out, allow_unicode=True)
def delete_folder(folder_path, logger):
    """Remove ``folder_path`` recursively when it exists, and log the removal.

    Nothing happens (and nothing is logged) when the path does not exist.

    Args:
        folder_path: Directory tree to delete.
        logger: Logger that receives the confirmation message.
    """
    if not os.path.exists(folder_path):
        return
    shutil.rmtree(folder_path)
    logger.info(f"已删除文件夹: {folder_path}")
####################################### 训练 #######################################
def train(project_name, yaml_path, default_model_path, logger):
    """Train a YOLO model, resuming from the last logged weights if available.

    The starting weights come from ``get_last_model_from_log``; when the
    project log holds no usable path, ``default_model_path`` is used. After
    training, the location of ``last.pt`` is written to the log as
    ``Saved last model path: <path>`` so that the next run can pick it up.

    Args:
        project_name: Project name; used for the log lookup and passed as
            ``project=`` to the trainer.
        yaml_path: Dataset YAML passed as ``data=``.
        default_model_path: Fallback weights/config when no previous run is
            recorded.
        logger: Logger for progress messages.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    logger.info(f"Using device: {device}")

    model_path = get_last_model_from_log(project_name, default_model_path)
    logger.info(f"加载模型: {model_path}")
    model = YOLO(model_path).to(device)

    current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    model.train(
        data=yaml_path,
        epochs=200,
        pretrained=True,
        patience=50,
        imgsz=640,
        # GPU 0 was requested unconditionally before, which cannot work on a
        # machine without CUDA even though availability was already checked.
        device=[0] if use_cuda else "cpu",
        workers=0,
        project=project_name,
        name=current_date,
    )

    # Ask the trainer where it saved last.pt instead of guessing the run
    # directory. Fall back to the previously assumed layout if the trainer
    # does not expose the path.
    trainer = getattr(model, "trainer", None)
    trained_model_path = getattr(trainer, "last", None)
    if trained_model_path is None:
        trained_model_path = os.path.join(
            'runs', 'detect', project_name, current_date, 'weights', 'last.pt'
        )
    trained_model_path = str(trained_model_path)

    if os.path.exists(trained_model_path):
        logger.info(f"Saved last model path: {trained_model_path}")
    else:
        logger.warning(f"last.pt not found at {trained_model_path}; next run will start from the default model")
####################################### 自动训练 #######################################
def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
               img_path='./dataset/images', label_path='./dataset/labels',
               new_path='./datasets', split_list=(0.7, 0.2, 0.1),
               class_names=None, project_name='default_project'):
    """Run the whole pipeline: download, clean, split, baseline-evaluate, train.

    Steps, in order: remove the ``dataset`` and ``datasets`` folders,
    download rows for ``model_id`` with ``down_dataset``, clean the label
    files, log per-class label counts, split into train/val/test under
    ``new_path``, evaluate ``yolo11n.pt`` as a baseline, remove the raw
    ``dataset`` folder, then train.

    Args:
        db_host, db_database, db_user, db_password, db_port: Database
            connection settings.
        model_id: Value matched against the ``model`` column when downloading.
        img_path: Directory of downloaded images.
        label_path: Directory of downloaded label files.
        new_path: Output root for the split dataset and its ``dataset.yaml``.
        split_list: ``(train, val, test)`` ratios.
        class_names: Mapping of class id string to name; defaults to empty.
        project_name: Used for logging and as the training project name.
    """
    if class_names is None:
        class_names = {}
    logger = setup_logger(project_name)

    delete_folder('dataset', logger)
    delete_folder('datasets', logger)
    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id, logger)

    # down_dataset returns without creating anything when the query is
    # empty; stop here rather than crash on a missing directory below.
    if not (os.path.isdir(img_path) and os.path.isdir(label_path)):
        logger.error(f"Dataset folders missing after download: {img_path}, {label_path}")
        return

    # Clean the LABEL files; walking the image folder found no .txt files,
    # so the cleaning step never ran.
    process_files_in_folder(label_path, logger)
    label_count = count_labels_by_class(label_path)
    logger.info(f"标签统计: {label_count}")

    split_img(img_path, label_path, split_list, new_path, class_names, logger)
    # split_img writes dataset.yaml under the absolute form of new_path; use
    # that same location instead of a hard-coded './datasets'.
    dataset_yaml = os.path.join(os.path.abspath(new_path), 'dataset.yaml')

    base_metrics = evaluate_model_per_class('yolo11n.pt', dataset_yaml, class_names)
    logger.info(f"训练前基线评估: {base_metrics}")

    delete_folder('dataset', logger)
    train(project_name, dataset_yaml, 'yolo11n.pt', logger)
    logger.info("训练流程执行完成")
def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id, image_dir, label_dir, yaml_name, where_clause, table_name):
    """Fetch images/labels through the OSS helper, then run ``auto_train``.

    Args:
        db_host, db_database, db_user, db_password, db_port: Database
            connection settings forwarded to ``auto_train``.
        model_id: Model identifier forwarded to ``auto_train``.
        image_dir: Directory the OSS helper should fill with images.
        label_dir: Directory the OSS helper should fill with labels.
        yaml_name, where_clause, table_name: Forwarded unchanged to
            ``get_img_and_label_paths``.

    The split ratios, output path, class names and project name are fixed
    inside this function.
    """
    img_path, label_path = get_img_and_label_paths(
        yaml_name, where_clause, image_dir, label_dir, table_name
    )
    auto_train(
        db_host=db_host,
        db_database=db_database,
        db_user=db_user,
        db_password=db_password,
        db_port=db_port,
        model_id=model_id,
        # auto_train's parameter is ``img_path``; the former ``imag_path``
        # keyword raised TypeError before anything ran.
        img_path=img_path,
        label_path=label_path,
        new_path='./datasets',
        split_list=[0.7, 0.2, 0.1],
        class_names={'0': 'human', '1': 'car'},
        project_name='my_project',
    )
####################################### 主入口 #######################################
if __name__ == '__main__':
    # down_and_train accepts only connection settings, model_id, image_dir,
    # label_dir, yaml_name, where_clause and table_name. The previous call
    # passed img_path/new_path/split_list/class_names/project_name, which it
    # does not accept, and omitted the required ones. The split ratios,
    # class names and project name are fixed inside down_and_train.
    # NOTE(review): host and credentials are hard-coded placeholders; move
    # them to configuration before real use.
    model_id = 'best.pt'
    down_and_train(
        db_host='222.212.85.86',
        db_database='your_database_name',
        db_user='postgres',
        db_password='postgres',
        db_port='5432',
        model_id=model_id,
        image_dir='./dataset/images',   # raw images before splitting
        label_dir='./dataset/labels',   # raw labels before splitting
        # NOTE(review): 'config' and 'aidataset' follow the example values in
        # get_img_and_label_paths' docstring; confirm for your deployment.
        yaml_name='config',
        where_clause=f"model = '{model_id}'",
        table_name='aidataset',
    )