370 lines
14 KiB
Python
370 lines
14 KiB
Python
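"""
Automated AI (YOLO) training pipeline.

1. Download the images and their matching YOLO-format label .txt files.
2. Make sure the label file contents are correctly formatted.
3. Split the dataset.
"""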
|
||
import os
|
||
import shutil
|
||
import datetime
|
||
import torch
|
||
from ultralytics import YOLO
|
||
import os
|
||
import shutil
|
||
import random
|
||
from tqdm import tqdm
|
||
import math
|
||
import stat
|
||
import yaml
|
||
import os
|
||
import psycopg2
|
||
import miniohelp as miniohelp
|
||
from psycopg2 import OperationalError
|
||
from collections import Counter
|
||
##############################################################################计算每个class的label数量#############################################################
|
||
def count_labels_by_class(label_dir):
    """Count YOLO annotations per class id across the ``.txt`` files of a folder.

    Args:
        label_dir: Directory containing label files; each non-blank line is
            one annotation whose first whitespace-separated field is the
            class id.

    Returns:
        dict mapping class id (the string exactly as found in the file) to
        the number of annotation lines carrying it. Files that do not end in
        ``.txt`` are ignored.
    """
    class_counter = Counter()
    for file_name in os.listdir(label_dir):
        if not file_name.endswith('.txt'):
            continue
        with open(os.path.join(label_dir, file_name), 'r') as f:
            for line in f:
                fields = line.split()
                # Blank / whitespace-only lines hold no annotation; indexing
                # [0] on them used to raise IndexError.
                if fields:
                    class_counter[fields[0]] += 1
    return dict(class_counter)
|
||
#################################################################################统计原始的召回率和准确率#########################################################
|
||
def evaluate_model_per_class(model_path, dataset_yaml, class_names):
    """Validate a YOLO model on the ``val`` split and collect per-class metrics.

    Args:
        model_path: Weights file handed to ``YOLO``.
        dataset_yaml: Dataset YAML path handed to ``model.val``.
        class_names: Mapping from class id as a string (e.g. ``'0'``) to a
            display name; ids without an entry are reported under the id
            string itself.

    Returns:
        dict of ``{name: {"precision", "recall", "mAP50", "mAP50_95"}}`` with
        one entry per class that has validation results.
    """
    model = YOLO(model_path)
    metrics = model.val(data=dataset_yaml, split='val')
    box = metrics.box

    results = {}
    # box.map50 / box.map are scalar means over all classes and cannot be
    # indexed. Per-class values come from class_result(i), where i is a
    # position in box.ap_class_index (only classes present in the validation
    # data), not the class id itself.
    for i, class_id in enumerate(box.ap_class_index):
        precision, recall, ap50, ap50_95 = box.class_result(i)
        key = str(int(class_id))
        name = class_names.get(key, key)
        results[name] = {
            "precision": float(precision),
            "recall": float(recall),
            "mAP50": float(ap50),
            "mAP50_95": float(ap50_95),
        }
    return results
|
||
|
||
#############################################################################下载图片与对应的yolo格式标签txt文档###################################################
|
||
def link_database(db_database, db_user, db_password, db_host, db_port, search_query, params=None):
    """Run one query against PostgreSQL and return all fetched rows.

    Args:
        db_database: Database name.
        db_user: Database user.
        db_password: Database password.
        db_host: Database host.
        db_port: Database port.
        search_query: SQL text; may contain ``%s`` placeholders.
        params: Optional sequence of values bound to the placeholders by the
            driver. ``None`` (default) sends ``search_query`` unchanged.

    Returns:
        List of row tuples, or ``None`` when connecting or querying failed
        (the error is printed, not raised).
    """
    conn = None
    try:
        conn = psycopg2.connect(
            database=db_database,
            user=db_user,
            password=db_password,
            host=db_host,
            port=db_port,
        )
        # `with conn` only scopes the transaction (commit / rollback); it
        # does not close the connection, hence the explicit close below.
        with conn:
            with conn.cursor() as cur:
                cur.execute(search_query, params)
                return cur.fetchall()
    except OperationalError as e:
        print(f"数据库连接或查询时发生错误: {e}")
    except Exception as e:
        print(f"发生了其他错误: {e}")
    finally:
        if conn is not None:
            conn.close()
    return None


def down_dataset(db_database, db_user, db_password, db_host, db_port, model):
    """Download every image of a model's dataset and write its label file.

    Images are stored in ``./dataset/images`` and labels in
    ``./dataset/labels`` (same base name, ``.txt`` extension).

    Args:
        db_database: Database name.
        db_user: Database user.
        db_password: Database password.
        db_host: Database host.
        db_port: Database port.
        model: Model id used to filter the ``aidataset`` table.

    Returns:
        None. Prints a message and returns early when no rows are found.
    """
    # Bind the model id as a parameter instead of formatting it into the SQL
    # text, so quotes inside the value cannot alter the statement.
    search_query = "SELECT * FROM aidataset WHERE model = %s;"
    records = link_database(
        db_database, db_user, db_password, db_host, db_port, search_query, (model,)
    )
    if not records:
        print("没有查询到数据。")
        return

    os.makedirs('./dataset/images', exist_ok=True)
    os.makedirs('./dataset/labels', exist_ok=True)

    for r in records:
        # NOTE(review): column 4 is used as the remote image path and column
        # 5 as the label text; confirm against the aidataset schema.
        img_path = r[4]
        label_content = r[5]

        # Fetch the image into the images folder under its own file name.
        local_img_name = img_path.split('/')[-1]
        local_img_path = os.path.join('./dataset/images', local_img_name)
        miniohelp.downFile(img_path, local_img_path)

        # Write the label text into the labels folder.
        txt_name = os.path.splitext(local_img_name)[0] + '.txt'
        txt_path = os.path.join('./dataset/labels', txt_name)
        with open(txt_path, 'w', encoding='utf-8') as f:
            # A NULL label column yields an empty label file instead of a
            # TypeError on `None + str`.
            f.write((label_content or '') + '\n')
|
||
#############################################################################保证文件内容格式正确###########################################################
|
||
|
||
def make_writable(file_path):
    """Add the owner-write permission to a file while keeping its other bits.

    Args:
        file_path: Path of the file to make writable.
    """
    # os.chmod replaces the whole mode: passing S_IWRITE alone dropped read
    # permission on POSIX, so the file could not be opened for reading next.
    current_mode = stat.S_IMODE(os.stat(file_path).st_mode)
    os.chmod(file_path, current_mode | stat.S_IWRITE)
|
||
|
||
def _clean_label_line(line, file_path):
    """Normalise one label line; return None when the line must be dropped."""
    numbers = line.split()
    if not numbers:
        # Blank line: nothing to keep (indexing [0] used to raise IndexError).
        return None

    class_id = numbers[0]
    # The first column must stay an integer class id. Any non-negative
    # integer is accepted, not just the single digits 0-9.
    if not (class_id.isascii() and class_id.isdigit()):
        print(f"Unexpected value in first column: {class_id}")
        return None

    processed_numbers = [class_id]
    for token in numbers[1:]:
        try:
            value = float(token)
        except ValueError:
            # Non-numeric columns are passed through untouched.
            processed_numbers.append(token)
            continue
        if math.isnan(value):
            print(f"NaN detected in file: {file_path}, line: {line}")
            return None
        # Negative values are flipped to positive.
        processed_numbers.append(str(abs(value)))
    return ' '.join(processed_numbers)


def process_files_in_folder(folder_path):
    """Rewrite every ``.txt`` file under a folder with cleaned label lines.

    For each line: the first column must be a non-negative integer class id
    (otherwise the line is dropped with a message), remaining numeric columns
    are re-emitted as floats with negatives made positive, lines containing
    NaN are dropped, and blank lines are removed.

    Args:
        folder_path: Root folder; it is walked recursively.
    """
    for root, _, files in os.walk(folder_path):
        for file_name in files:
            if not file_name.endswith('.txt'):
                continue
            file_path = os.path.join(root, file_name)

            # The file is rewritten in place, so it must be writable.
            make_writable(file_path)

            with open(file_path, 'r') as file:
                lines = file.readlines()

            processed_lines = []
            for line in lines:
                cleaned = _clean_label_line(line, file_path)
                if cleaned is not None:
                    processed_lines.append(cleaned)

            with open(file_path, 'w') as file:
                file.write('\n'.join(processed_lines))
            print(f"Finished processing: {file_path}")
|
||
|
||
###################################################################################分割数据集#####################################################################
|
||
def _copy_subset(images, label_path, img_dir, label_dir, desc):
    """Copy each image and its label file into one split's folders."""
    for img in tqdm(images, desc=desc, ncols=80, unit='img'):
        _copy(img, img_dir)
        _copy(toLabelPath(img, label_path), label_dir)


def split_img(img_path, label_path, split_list, new_path, class_names):
    """Randomly split images and labels into train / val / test and write dataset.yaml.

    Args:
        img_path: Folder holding the source images.
        label_path: Folder holding the label files.
        split_list: Three ratios ``(train, val, test)``. ``train`` is applied
            to the whole set; the remainder is divided by ``val / (val + test)``.
        new_path: Output root; ``<split>/images`` and ``<split>/labels`` are
            created below it, plus ``dataset.yaml``.
        class_names: Mapping of class id to name, forwarded to
            ``generate_dataset_yaml``.

    Returns:
        None. Returns early (after printing) if the output folders cannot be
        created.
    """
    try:
        data_root = os.path.abspath(new_path)
        os.makedirs(data_root, exist_ok=True)

        train_img_dir = os.path.abspath(os.path.join(data_root, 'train', 'images'))
        val_img_dir = os.path.abspath(os.path.join(data_root, 'val', 'images'))
        test_img_dir = os.path.abspath(os.path.join(data_root, 'test', 'images'))

        train_label_dir = os.path.abspath(os.path.join(data_root, 'train', 'labels'))
        val_label_dir = os.path.abspath(os.path.join(data_root, 'val', 'labels'))
        test_label_dir = os.path.abspath(os.path.join(data_root, 'test', 'labels'))

        for folder in (train_img_dir, train_label_dir,
                       val_img_dir, val_label_dir,
                       test_img_dir, test_label_dir):
            os.makedirs(folder, exist_ok=True)
    except Exception as e:
        print(f'文件目录创建失败: {e}')
        return

    train, val, test = split_list
    all_img_path = [os.path.join(img_path, img) for img in os.listdir(img_path)]

    train_img = random.sample(all_img_path, int(train * len(all_img_path)))
    _copy_subset(train_img, label_path, train_img_dir, train_label_dir, 'train ')

    # Drop the chosen images in one pass (order kept) rather than calling
    # list.remove once per image.
    train_set = set(train_img)
    remaining = [p for p in all_img_path if p not in train_set]

    # With val + test == 0 there is nothing to share out; avoid dividing by zero.
    rest_ratio = val + test
    val_count = int(val / rest_ratio * len(remaining)) if rest_ratio > 0 else 0
    val_img = random.sample(remaining, val_count)
    _copy_subset(val_img, label_path, val_img_dir, val_label_dir, 'val ')

    # Whatever is left after train and val becomes the test split.
    val_set = set(val_img)
    test_img = [p for p in remaining if p not in val_set]
    _copy_subset(test_img, label_path, test_img_dir, test_label_dir, 'test ')

    generate_dataset_yaml(
        save_path=os.path.join(data_root, 'dataset.yaml'),
        train_path=train_img_dir,
        val_path=val_img_dir,
        test_path=test_img_dir,
        class_names=class_names,
    )
|
||
|
||
def _copy(from_path, to_path):
    """Copy ``from_path`` to ``to_path``; a failure is printed, never raised."""
    try:
        shutil.copy(src=from_path, dst=to_path)
    except Exception as err:
        print(f"复制文件时出错: {err}")
|
||
|
||
def toLabelPath(img_path, label_path):
    """Return the label file path that corresponds to an image.

    Args:
        img_path: Path of the image; only its base name is used.
        label_path: Folder containing the label files.

    Returns:
        ``label_path`` joined with the image base name whose final extension
        is replaced by ``.txt``.
    """
    # splitext swaps only the final extension and handles any image type.
    # str.replace('.jpg', ...) ignored .png/.jpeg/.JPG names and rewrote
    # every '.jpg' occurring inside the name.
    stem, _ = os.path.splitext(os.path.basename(img_path))
    return os.path.join(label_path, stem + '.txt')
|
||
|
||
def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names):
    """Write the dataset description YAML used for training and validation.

    Args:
        save_path: Destination file of the YAML.
        train_path: Training images folder.
        val_path: Validation images folder.
        test_path: Test images folder.
        class_names: Mapping from class id (e.g. ``'0'``) to class name.
    """
    # A name's position in the list is its class id, so order entries by
    # numeric id rather than by dict insertion order. Fall back to insertion
    # order when the keys are not integer-like.
    try:
        ordered = sorted(class_names.items(), key=lambda item: int(item[0]))
    except (TypeError, ValueError):
        ordered = list(class_names.items())

    dataset_yaml = {
        # Forward slashes keep Windows paths readable in the YAML.
        'train': train_path.replace('\\', '/'),
        'val': val_path.replace('\\', '/'),
        'test': test_path.replace('\\', '/'),
        'nc': len(class_names),
        'names': [name for _, name in ordered],
    }

    with open(save_path, 'w', encoding='utf-8') as f:
        yaml.dump(dataset_yaml, f, allow_unicode=True)
|
||
|
||
######################################################################开训开训开训##################################################################
|
||
|
||
# 获取上次训练的模型路径
|
||
def get_last_model(default_model_path, log_file='last_model.log'):
    """Pick the weights to start from: the logged last model, else the default.

    Args:
        default_model_path: Returned when no usable logged path exists.
        log_file: Text file holding the path of the previously trained model.

    Returns:
        The stripped path read from ``log_file`` when that file exists and
        the path it names exists; otherwise ``default_model_path``.
    """
    recorded = ''
    if os.path.exists(log_file):
        with open(log_file, 'r') as handle:
            recorded = handle.read().strip()

    if not recorded or not os.path.exists(recorded):
        print(f"使用默认模型: {default_model_path}")
        return default_model_path

    print(f"使用上次训练模型: {recorded}")
    return recorded
|
||
|
||
# 保存本次训练的模型路径
|
||
def save_last_model(model_path, log_file='last_model.log'):
    """Record ``model_path`` in ``log_file``, overwriting any previous content.

    Args:
        model_path: Path string to store.
        log_file: File that receives the path.
    """
    with open(log_file, mode='w') as log_handle:
        log_handle.write(model_path)
|
||
|
||
# 删除文件夹
|
||
def delete_folder(folder_path):
    """Remove a folder tree if it exists; print what happened either way.

    Args:
        folder_path: Folder to delete together with its contents.
    """
    if not os.path.exists(folder_path):
        print(f"文件夹不存在: {folder_path}")
        return

    shutil.rmtree(folder_path)
    print(f"已删除文件夹及其内容: {folder_path}")
|
||
|
||
# 训练函数
|
||
def train(project_name, yaml_path, default_model_path):
    """Train a YOLO model and record the resulting best weights.

    Training starts from the model returned by ``get_last_model`` and runs
    under ``project=<project_name>``, ``name=<timestamp>``.

    Args:
        project_name: Project folder name for the run (stripped of spaces).
        yaml_path: Dataset YAML path.
        default_model_path: Weights used when no previous model is logged.

    Returns:
        Path of the run's ``weights/best.pt`` when that file exists (it is
        also stored via ``save_last_model``), otherwise ``None``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    model_path = get_last_model(default_model_path)
    model = YOLO(model_path).to(device)

    current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    project = project_name.strip()

    model.train(
        data=yaml_path,
        epochs=200,
        pretrained=True,
        patience=50,
        imgsz=640,
        # Use GPU 0 when CUDA is available (list several ids for multi-GPU);
        # otherwise fall back to CPU instead of requesting a missing GPU.
        device=[0] if use_cuda else 'cpu',
        workers=0,
        project=project,
        name=current_date,
    )

    # The run is saved under project/name, not runs/detect/<name>. Prefer the
    # directory reported by the trainer and fall back to project/name.
    trainer = getattr(model, 'trainer', None)
    save_dir = getattr(trainer, 'save_dir', None) or os.path.join(project, current_date)
    trained_model_path = os.path.join(str(save_dir), 'weights', 'best.pt')
    if os.path.exists(trained_model_path):
        save_last_model(trained_model_path)
        return trained_model_path
    return None


# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------
def auto_train(
    db_host,
    db_database,
    db_user,
    db_password,
    db_port,
    model_id,
    img_path='./dataset/images',
    label_path='./dataset/labels',
    new_path='./datasets',
    split_list=(0.7, 0.2, 0.1),
    class_names=None,
    project_name='default_project'
):
    """Run the pipeline: download, clean, split, evaluate, train, re-evaluate.

    Args:
        db_host: Database host.
        db_database: Database name.
        db_user: Database user.
        db_password: Database password.
        db_port: Database port.
        model_id: Model id whose dataset rows are downloaded.
        img_path: Folder expected to hold the downloaded images.
        label_path: Folder expected to hold the downloaded labels.
        new_path: Output root of the split dataset and its ``dataset.yaml``.
        split_list: ``(train, val, test)`` ratios.
        class_names: Mapping from class id string to class name.
        project_name: Training project name.

    Returns:
        dict with ``status``, ``message``, ``project_name``, ``label_count``,
        ``base_metrics`` and ``final_metrics``. ``status`` is ``"error"`` when
        no data was downloaded or training produced no ``best.pt``.
    """
    if class_names is None:
        class_names = {}

    # The YAML is written below new_path by split_img; derive it from there
    # rather than assuming './datasets'.
    dataset_yaml = os.path.join(new_path, 'dataset.yaml')

    # Remove data left over from a previous run.
    delete_folder('dataset')
    delete_folder('datasets')

    # Download the new data.
    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id)

    # down_dataset returns silently when the query yields nothing; stop here
    # instead of failing later on a missing folder.
    if not (os.path.isdir(img_path) and os.path.isdir(label_path)):
        return {
            "status": "error",
            "message": "No dataset downloaded",
            "project_name": project_name,
            "label_count": {},
            "base_metrics": {},
            "final_metrics": {},
        }

    # Clean the label files (they live in label_path, not img_path).
    process_files_in_folder(label_path)

    # Count annotations per class.
    label_count = count_labels_by_class(label_path)

    # Split the dataset.
    split_img(img_path, label_path, split_list, new_path, class_names)

    # Evaluate the base model before training.
    base_metrics = evaluate_model_per_class('yolo11n.pt', dataset_yaml, class_names)

    # The raw download is no longer needed once the split exists.
    delete_folder('dataset')

    # Train and get the real location of the best weights.
    best_model_path = train(project_name, dataset_yaml, 'yolo11n.pt')
    if best_model_path is None:
        delete_folder('datasets')
        return {
            "status": "error",
            "message": "Training finished without best.pt",
            "project_name": project_name,
            "label_count": label_count,
            "base_metrics": base_metrics,
            "final_metrics": {},
        }

    # Evaluate the trained model.
    final_metrics = evaluate_model_per_class(best_model_path, dataset_yaml, class_names)

    # Remove the split dataset.
    delete_folder('datasets')

    print("训练流程执行完成!")

    return {
        "status": "success",
        "message": "Train finished",
        "project_name": project_name,
        "label_count": label_count,
        "base_metrics": base_metrics,
        "final_metrics": final_metrics
    }
|
||
|
||
if __name__ == '__main__':
    # Example invocation: gather the settings first, then launch the pipeline.
    run_config = {
        'db_host': '222.212.85.86',
        'db_database': 'your_database_name',
        'db_user': 'postgres',
        'db_password': 'postgres',
        'db_port': '5432',
        'model_id': 'best.pt',
        'img_path': './dataset/images',
        'label_path': './dataset/labels',
        'new_path': './datasets',
        'split_list': [0.7, 0.2, 0.1],
        'class_names': {'0': 'human', '1': 'car'},
        'project_name': 'my_project',
    }
    auto_train(**run_config)
|