"""
本程序用以自动训练ai
1.下载图片与对应的yolo格式标签txt文档
2.保证文件内容格式正确
3.分割数据集
"""
import os
import shutil
import datetime
import random
import math
import stat
from collections import Counter

import torch
import yaml
import psycopg2
from psycopg2 import OperationalError
from tqdm import tqdm
from ultralytics import YOLO

import miniohelp as miniohelp
############################################################ Count labels per class ############################################################
def count_labels_by_class(label_dir):
    """Count how many label lines belong to each class id."""
    class_counter = Counter()
    for file in os.listdir(label_dir):
        if file.endswith('.txt'):
            with open(os.path.join(label_dir, file), 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if not parts:  # skip blank lines
                        continue
                    class_counter[parts[0]] += 1
    return dict(class_counter)
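
# Illustrative output (example counts), assuming labels for two classes exist:
#   count_labels_by_class('./dataset/labels') -> {'0': 1523, '1': 87}
# Keys are the class ids exactly as they appear in the first label column.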
############################################################ Baseline per-class precision/recall ############################################################
def evaluate_model_per_class(model_path, dataset_yaml, class_names):
    model = YOLO(model_path)
    metrics = model.val(data=dataset_yaml, split='val')
    results = {}
    # metrics.box.p / .r / .ap50 / .ap are per-class arrays aligned with
    # metrics.box.ap_class_index (only classes present in the val set);
    # metrics.box.map50 / .map are scalar means and cannot be indexed per class
    for idx, class_id in enumerate(metrics.box.ap_class_index):
        name = class_names.get(str(int(class_id)), str(int(class_id)))
        results[name] = {
            "precision": float(metrics.box.p[idx]),
            "recall": float(metrics.box.r[idx]),
            "mAP50": float(metrics.box.ap50[idx]),
            "mAP50_95": float(metrics.box.ap[idx])
        }
    return results
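
# Illustrative return value (example numbers), assuming
# class_names={'0': 'human', '1': 'car'}:
#   {'human': {'precision': 0.91, 'recall': 0.88, 'mAP50': 0.90, 'mAP50_95': 0.62},
#    'car':   {'precision': 0.87, 'recall': 0.84, 'mAP50': 0.86, 'mAP50_95': 0.58}}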
############################################################ Download images and YOLO label .txt files ############################################################
def link_database(db_database, db_user, db_password, db_host, db_port, search_query, params=None):
    """Query the database and return all matching records."""
    try:
        with psycopg2.connect(
            database=db_database,
            user=db_user,
            password=db_password,
            host=db_host,
            port=db_port
        ) as conn:
            with conn.cursor() as cur:
                cur.execute(search_query, params)
                records = cur.fetchall()
                return records
    except OperationalError as e:
        print(f"Database connection or query error: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")
def down_dataset(db_database, db_user, db_password, db_host, db_port, model):
    '''
    db_database: database name
    db_user: database user
    db_password: database password
    db_host: database host
    db_port: database port
    model: model ID
    '''
    # Parameterized query to avoid SQL injection
    search_query = "SELECT * FROM aidataset WHERE model = %s;"
    records = link_database(db_database, db_user, db_password, db_host, db_port,
                            search_query, params=(model,))
    if not records:
        print("No records found.")
        return
    os.makedirs('./dataset/images', exist_ok=True)
    os.makedirs('./dataset/labels', exist_ok=True)
    for r in records:
        img_path = r[4]       # column 5: image path in MinIO
        label_content = r[5]  # column 6: YOLO label text
        # Download the image into the images folder
        local_img_name = img_path.split('/')[-1]
        local_img_path = os.path.join('./dataset/images', local_img_name)
        miniohelp.downFile(img_path, local_img_path)
        # Write the label into the labels folder
        txt_name = os.path.splitext(local_img_name)[0] + '.txt'
        txt_path = os.path.join('./dataset/labels', txt_name)
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(label_content + '\n')
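
# Assumed aidataset row layout (only columns 5 and 6 are used here): r[4] holds
# the MinIO object path, r[5] the YOLO label text, e.g.:
#   (..., 'bucket/images/0001.jpg', '0 0.5 0.5 0.2 0.3', ...)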
############################################################ Validate and normalize label files ############################################################
def make_writable(file_path):
    os.chmod(file_path, stat.S_IWRITE)

def process_files_in_folder(folder_path):
    for root, _, files in os.walk(folder_path):
        for file_name in files:
            if file_name.endswith('.txt'):
                file_path = os.path.join(root, file_name)
                # Make sure the file is writable
                make_writable(file_path)
                # Read the file content for processing
                with open(file_path, 'r') as file:
                    lines = file.readlines()
                processed_lines = []
                for line in lines:
                    numbers = line.split()
                    if not numbers:  # skip blank lines
                        continue
                    processed_numbers = []
                    # The first column must be an integer class id, not a float
                    if numbers[0].isdigit():
                        processed_numbers.append(numbers[0])
                    else:
                        print(f"Unexpected value in first column: {numbers[0]}")
                        continue
                    # For the remaining columns: flip negatives to positive and
                    # drop any line that contains NaN values
                    skip_line = False
                    for number in numbers[1:]:
                        try:
                            value = float(number)
                            if math.isnan(value):
                                skip_line = True
                                print(f"NaN detected in file: {file_path}, line: {line}")
                                break
                            if value < 0:
                                value = abs(value)
                            processed_numbers.append(str(value))
                        except ValueError:
                            processed_numbers.append(number)  # keep non-numeric columns as-is
                    # Keep the line only if it contained no NaN values
                    if not skip_line:
                        processed_lines.append(' '.join(processed_numbers))
                # Write the cleaned content back to the file
                with open(file_path, 'w') as file:
                    file.write('\n'.join(processed_lines))
                print(f"Finished processing: {file_path}")
############################################################ Split the dataset ############################################################
def split_img(img_path, label_path, split_list, new_path, class_names):
    try:
        Data = os.path.abspath(new_path)
        os.makedirs(Data, exist_ok=True)
        train_img_dir = os.path.abspath(os.path.join(Data, 'train', 'images'))
        val_img_dir = os.path.abspath(os.path.join(Data, 'val', 'images'))
        test_img_dir = os.path.abspath(os.path.join(Data, 'test', 'images'))
        train_label_dir = os.path.abspath(os.path.join(Data, 'train', 'labels'))
        val_label_dir = os.path.abspath(os.path.join(Data, 'val', 'labels'))
        test_label_dir = os.path.abspath(os.path.join(Data, 'test', 'labels'))
        os.makedirs(train_img_dir, exist_ok=True)
        os.makedirs(train_label_dir, exist_ok=True)
        os.makedirs(val_img_dir, exist_ok=True)
        os.makedirs(val_label_dir, exist_ok=True)
        os.makedirs(test_img_dir, exist_ok=True)
        os.makedirs(test_label_dir, exist_ok=True)
    except Exception as e:
        print(f'Failed to create dataset directories: {e}')
        return
    train, val, test = split_list
    all_img = os.listdir(img_path)
    all_img_path = [os.path.join(img_path, img) for img in all_img]
    # Draw the train split first, then divide the remainder between val and test
    train_img = random.sample(all_img_path, int(train * len(all_img_path)))
    train_label = [toLabelPath(img, label_path) for img in train_img]
    for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'):
        _copy(train_img[i], train_img_dir)
        _copy(train_label[i], train_label_dir)
        all_img_path.remove(train_img[i])
    # val gets val/(val+test) of what is left, keeping the ratios proportional
    val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path)))
    val_label = [toLabelPath(img, label_path) for img in val_img]
    for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'):
        _copy(val_img[i], val_img_dir)
        _copy(val_label[i], val_label_dir)
        all_img_path.remove(val_img[i])
    # Everything remaining becomes the test split
    test_img = all_img_path
    test_label = [toLabelPath(img, label_path) for img in test_img]
    for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'):
        _copy(test_img[i], test_img_dir)
        _copy(test_label[i], test_label_dir)
    generate_dataset_yaml(
        save_path=os.path.join(Data, 'dataset.yaml'),
        train_path=train_img_dir,
        val_path=val_img_dir,
        test_path=test_img_dir,
        class_names=class_names
    )
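
# Resulting layout under new_path (e.g. './datasets'):
#   datasets/train/images, datasets/train/labels
#   datasets/val/images,   datasets/val/labels
#   datasets/test/images,  datasets/test/labels
#   datasets/dataset.yaml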
def _copy(from_path, to_path):
    try:
        shutil.copy(from_path, to_path)
    except Exception as e:
        print(f"Error while copying file: {e}")

def toLabelPath(img_path, label_path):
    # Assumes .jpg images; the label file shares the basename with a .txt suffix
    img = os.path.basename(img_path)
    label = img.replace('.jpg', '.txt')
    return os.path.join(label_path, label)
def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names):
    dataset_yaml = {
        'train': train_path.replace('\\', '/'),
        'val': val_path.replace('\\', '/'),
        'test': test_path.replace('\\', '/'),
        'nc': len(class_names),
        'names': list(class_names.values())
    }
    with open(save_path, 'w', encoding='utf-8') as f:
        yaml.dump(dataset_yaml, f, allow_unicode=True)
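
# Illustrative dataset.yaml for class_names={'0': 'human', '1': 'car'}
# (yaml.dump sorts keys alphabetically by default):
#   names:
#   - human
#   - car
#   nc: 2
#   test: /abs/path/datasets/test/images
#   train: /abs/path/datasets/train/images
#   val: /abs/path/datasets/val/images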
############################################################ Training ############################################################
# Get the model path from the previous training run
def get_last_model(default_model_path, log_file='last_model.log'):
    if os.path.exists(log_file):
        with open(log_file, 'r') as f:
            last_model = f.read().strip()
        if last_model and os.path.exists(last_model):
            print(f"Using model from previous run: {last_model}")
            return last_model
    print(f"Using default model: {default_model_path}")
    return default_model_path

# Save the model path from this training run
def save_last_model(model_path, log_file='last_model.log'):
    with open(log_file, 'w') as f:
        f.write(model_path)

# Delete a folder and its contents
def delete_folder(folder_path):
    if os.path.exists(folder_path):
        shutil.rmtree(folder_path)
        print(f"Deleted folder and its contents: {folder_path}")
    else:
        print(f"Folder does not exist: {folder_path}")
# Training entry point
def train(project_name, yaml_path, default_model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model_path = get_last_model(default_model_path)
    model = YOLO(model_path).to(device)
    current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    project = project_name.strip()
    model.train(
        data=yaml_path,
        epochs=200,
        pretrained=True,
        patience=50,
        imgsz=640,
        device=[0],  # use [0, 1] if more than one GPU is available
        workers=0,
        project=project,
        name=current_date,
    )
    # With project/name set, Ultralytics saves results under {project}/{name},
    # not the default runs/detect directory
    trained_model_path = os.path.join(project, current_date, 'weights', 'best.pt')
    if os.path.exists(trained_model_path):
        save_last_model(trained_model_path)
    return trained_model_path
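
# With project='my_project' and name='20250710_1004', outputs land under
#   my_project/20250710_1004/weights/best.pt (and last.pt),
# and that best.pt path is saved to last_model.log for the next run.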
############################################################ Main pipeline ############################################################
def auto_train(
    db_host,
    db_database,
    db_user,
    db_password,
    db_port,
    model_id,
    img_path='./dataset/images',
    label_path='./dataset/labels',
    new_path='./datasets',
    split_list=[0.7, 0.2, 0.1],
    class_names=None,
    project_name='default_project'
):
    if class_names is None:
        class_names = {}
    # Remove stale data
    delete_folder('dataset')
    delete_folder('datasets')
    # Download fresh data
    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id)
    # Clean the label files (the .txt files live in label_path, not img_path)
    process_files_in_folder(label_path)
    # Count labels per class
    label_count = count_labels_by_class(label_path)
    # Split the dataset
    split_img(img_path, label_path, split_list, new_path, class_names)
    # Evaluate the baseline model before training
    base_metrics = evaluate_model_per_class('yolo11n.pt', './datasets/dataset.yaml', class_names)
    # Remove the raw data
    delete_folder('dataset')
    # Start training; train() returns the path to the new best.pt
    best_model_path = train(project_name, './datasets/dataset.yaml', 'yolo11n.pt')
    # Evaluate after training
    final_metrics = evaluate_model_per_class(best_model_path, './datasets/dataset.yaml', class_names)
    # Remove the split datasets
    delete_folder('datasets')
    print("Training pipeline finished!")
    return {
        "status": "success",
        "message": "Train finished",
        "project_name": project_name,
        "label_count": label_count,
        "base_metrics": base_metrics,
        "final_metrics": final_metrics
    }
if __name__ == '__main__':
    auto_train(
        db_host='222.212.85.86',
        db_database='your_database_name',
        db_user='postgres',
        db_password='postgres',
        db_port='5432',
        model_id='best.pt',
        img_path='./dataset/images',
        label_path='./dataset/labels',
        new_path='./datasets',
        split_list=[0.7, 0.2, 0.1],
        class_names={'0': 'human', '1': 'car'},
        project_name='my_project'
    )