logging
This commit is contained in:
parent
3f3fa0163b
commit
23b990c811
@@ -1,28 +1,104 @@
+"""
+main()
+   |
+   v
+setup_logger(project)
+   |
+   v
+get_last_model_path(project)
+   |
+   v
++-----------------------------+
+| last.pt exists | no last.pt |
++-----------------------------+
+     |                |
+     v                v
+load_last_model()   start_new_training()
+     |                |
+     +-------+--------+
+             |
+             v
+check_dataset(root)
+             |
+             v
+split_dataset(root, ratios)
+             |
+             v
+clean_labels(root)
+             |
+             v
+generate_yaml(dataset_dir)
+             |
+             v
+train_yolo(model, data_yaml)
+             |
+             v
+save last.pt
+             |
+             v
+logger.info("Saved last model path")
+             |
+             v
+write logs/{project}.log
+
 """
-This program automates AI training:
-1. Download the images and their matching YOLO-format label .txt files
-2. Ensure the file contents are correctly formatted
-3. Split the dataset
-"""
 import os
 import shutil
 import datetime
 import torch
 from ultralytics import YOLO
-import os
-import shutil
 import random
-from tqdm import tqdm
 import math
 import stat
 import yaml
-import os
 import psycopg2
-import miniohelp as miniohelp
 from psycopg2 import OperationalError
 from collections import Counter
 import pandas as pd
-####################################### Compute label count for each class ##############################
+import logging
+from tqdm import tqdm
+import miniohelp as miniohelp
+
+######################################## Logging ########################################
+def setup_logger(project: str):
+    os.makedirs("logs", exist_ok=True)
+    log_file = os.path.join("logs", f"{project}.log")
+
+    logger = logging.getLogger(project)
+    if not logger.handlers:
+        logger.setLevel(logging.INFO)
+        formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
+
+        fh = logging.FileHandler(log_file, encoding="utf-8")
+        fh.setFormatter(formatter)
+        sh = logging.StreamHandler()
+        sh.setFormatter(formatter)
+
+        logger.addHandler(fh)
+        logger.addHandler(sh)
+
+    return logger
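The if not logger.handlers guard makes the setup idempotent: repeated calls for the same project reuse the existing handlers instead of duplicating every record. A minimal usage sketch (the project name "demo" is illustrative):

    logger = setup_logger("demo")
    logger = setup_logger("demo")   # second call attaches no extra handlers
    logger.info("hello")            # emitted once to the console and once to logs/demo.log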
+
+def get_last_model_from_log(project: str, default_model: str):
+    """Recover the last.pt path of the previous training run from the log."""
+    log_file = os.path.join("logs", f"{project}.log")
+    if not os.path.exists(log_file):
+        return default_model
+
+    with open(log_file, "r", encoding="utf-8") as f:
+        lines = f.readlines()
+
+    for line in reversed(lines):
+        if "Saved last model path:" in line:
+            path = line.strip().split("Saved last model path:")[-1].strip()
+            if os.path.exists(path):
+                return path
+    return default_model
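This parser is the other half of the persistence scheme that replaces the old last_model.log file: train() further down logs a "Saved last model path: ..." marker, and the next run scans the log backwards for the newest one. A sketch of the round trip, with an invented project name and path:

    logger = setup_logger("fire_demo")
    logger.info("Saved last model path: runs/detect/fire_demo/20250101_0900/weights/last.pt")
    # Returns the logged path if it still exists on disk, otherwise the default:
    model_path = get_last_model_from_log("fire_demo", "yolo11n.pt")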
 
+####################################### Utilities #######################################
 def count_labels_by_class(label_dir):
     class_counter = Counter()
     for file in os.listdir(label_dir):
@@ -32,12 +108,11 @@ def count_labels_by_class(label_dir):
             class_id = line.strip().split()[0]
             class_counter[class_id] += 1
     return dict(class_counter)
-##################################### Statistics of precision and recall #####################################
 def evaluate_model_per_class(model_path, dataset_yaml, class_names):
     model = YOLO(model_path)
     metrics = model.val(data=dataset_yaml, split='val')
-    class_ids = range(len(metrics.box.p))  # class id list
+    class_ids = range(len(metrics.box.p))
     results = {}
     for i in class_ids:
         name = class_names.get(str(i), str(i))
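The hunk elides the loop body. Since metrics.box.p and metrics.box.r hold one precision/recall value per class in Ultralytics validation results, the elided part presumably builds something like this sketch (the result field names are assumptions):

    for i in class_ids:
        name = class_names.get(str(i), str(i))
        results[name] = {
            'precision': float(metrics.box.p[i]),  # per-class precision from model.val()
            'recall': float(metrics.box.r[i]),     # per-class recall
        }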
@@ -49,10 +124,7 @@ def evaluate_model_per_class(model_path, dataset_yaml, class_names):
         }
     return results
 
-##################################### Download images & YOLO text labels ########################################
-
 def link_database(db_database, db_user, db_password, db_host, db_port, search_query):
-    """Query the database and update the matching records."""
     try:
         with psycopg2.connect(
             database=db_database,
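A side note on the next hunk: down_dataset interpolates model into its SQL with an f-string. If that value can ever come from outside the pipeline, psycopg2's parameter binding is the safer pattern; a sketch, not part of this commit:

    query = "SELECT * FROM aidataset WHERE model = %s;"
    with psycopg2.connect(database=db_database, user=db_user, password=db_password,
                          host=db_host, port=db_port) as conn:
        with conn.cursor() as cur:
            cur.execute(query, (model,))   # psycopg2 quotes the value itself
            records = cur.fetchall()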
@@ -70,22 +142,11 @@ def link_database(db_database, db_user, db_password, db_host, db_port, search_query):
     except Exception as e:
         print(f"An unexpected error occurred: {e}")
 
-def down_dataset(db_database, db_user, db_password, db_host, db_port, model):
-    '''
-    db_database: database name
-    db_user: database user
-    db_password: database password
-    db_host: database host
-    db_port: database port
-    model: model ID
-    '''
-    search_query = f"""
-        SELECT * FROM aidataset
-        WHERE model = '{model}';
-    """
+def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger):
+    search_query = f"SELECT * FROM aidataset WHERE model = '{model}';"
     records = link_database(db_database, db_user, db_password, db_host, db_port, search_query)
     if not records:
-        print("No records found.")
+        logger.warning("No records found.")
         return
 
     os.makedirs('./dataset/images', exist_ok=True)
@@ -95,31 +156,27 @@ def down_dataset(db_database, db_user, db_password, db_host, db_port, model):
         img_path = r[4]
         label_content = r[5]
 
-        # Download the image into the images folder
         local_img_name = img_path.split('/')[-1]
         local_img_path = os.path.join('./dataset/images', local_img_name)
         miniohelp.downFile(img_path, local_img_path)
 
-        # Write the label into the labels folder
         txt_name = os.path.splitext(local_img_name)[0] + '.txt'
         txt_path = os.path.join('./dataset/labels', txt_name)
         with open(txt_path, 'w', encoding='utf-8') as f:
             f.write(label_content + '\n')
-############################################################################# Ensure file contents are correctly formatted ###########################################################
+    logger.info("Dataset download finished")
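miniohelp is a project-local module, so its API is not shown in this diff; judging by the call site, downFile(remote_path, local_path) fetches a single object. A hypothetical implementation on top of the official MinIO SDK (endpoint, credentials, and bucket name are placeholders):

    from minio import Minio

    _client = Minio("minio.example.com:9000",
                    access_key="<access-key>", secret_key="<secret-key>", secure=False)

    def downFile(remote_path, local_path):
        # fget_object streams the object straight into a local file
        _client.fget_object("aidataset", remote_path, local_path)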
 
 def make_writable(file_path):
     os.chmod(file_path, stat.S_IWRITE)
 
-def process_files_in_folder(folder_path):
+def process_files_in_folder(folder_path, logger):
     for root, _, files in os.walk(folder_path):
         for file_name in files:
             if file_name.endswith('.txt'):
                 file_path = os.path.join(root, file_name)
 
-                # Make sure the file is writable
                 make_writable(file_path)
 
-                # Read the file contents for processing
                 with open(file_path, 'r') as file:
                     lines = file.readlines()
 
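One portability caveat that predates this commit: os.chmod(file_path, stat.S_IWRITE) replaces the whole permission mask, which works on Windows but strips read/execute bits on POSIX. A sketch that only adds the owner-write bit instead:

    def make_writable(file_path):
        mode = os.stat(file_path).st_mode
        os.chmod(file_path, mode | stat.S_IWUSR)  # keep existing bits, add owner write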
@@ -127,62 +184,42 @@ def process_files_in_folder(folder_path):
                 for line in lines:
                     numbers = line.split()
                     processed_numbers = []
-                    # Make sure the first column is an integer class id, not a float
-                    if numbers[0] == '0' or numbers[0] == '1' or numbers[0] == '2' or numbers[0] == '3' or numbers[0] == '4' or numbers[0] == '5' or numbers[0] == '6' or numbers[0] == '7' or numbers[0] == '8' or numbers[0] == '9':
+                    if numbers[0].isdigit():
                         processed_numbers.append(numbers[0])
                     else:
-                        print(f"Unexpected value in first column: {numbers[0]}")
+                        logger.warning(f"Unexpected value in first column: {numbers[0]}")
                         continue
 
-                    # Process the remaining columns: keep the original format, turn negatives positive, drop NaN rows
-                    skip_line = False  # marks whether this line is skipped
+                    skip_line = False
                     for number in numbers[1:]:
                         try:
                             number = float(number)
-                            if math.isnan(number):  # check for NaN
+                            if math.isnan(number):
                                 skip_line = True
-                                print(f"NaN detected in file: {file_path}, line: {line}")
+                                logger.warning(f"NaN detected in {file_path}: {line}")
                                 break
                             if number < 0:
-                                number = abs(number)  # turn negatives positive
-                            processed_numbers.append(str(number))  # keep the original format
+                                number = abs(number)
+                            processed_numbers.append(str(number))
                         except ValueError:
-                            processed_numbers.append(number)  # non-numeric columns stay as-is
+                            processed_numbers.append(number)
 
-                    # Keep the line only if it contained no NaN values
                     if not skip_line:
                         processed_line = ' '.join(processed_numbers)
                         processed_lines.append(processed_line)
 
-                # Write the processed content back to the file
                 with open(file_path, 'w') as file:
                     file.write('\n'.join(processed_lines))
-                print(f"Finished processing: {file_path}")
+                logger.info(f"Processed {file_path}")
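For reference, the cleaning loop's effect on a few label lines (values invented for illustration):

    # '3 -0.41 0.52 nan 0.11'  -> dropped entirely (NaN coordinate, logged as a warning)
    # '3 -0.41 0.52 0.63 0.11' -> kept as '3 0.41 0.52 0.63 0.11' (negative flipped)
    # 'x 0.41 0.52 0.63 0.11'  -> skipped with a warning (non-numeric class id)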
 
-################################################################################### Split the dataset #####################################################################
-def split_img(img_path, label_path, split_list, new_path, class_names):
+def split_img(img_path, label_path, split_list, new_path, class_names, logger):
     try:
         Data = os.path.abspath(new_path)
         os.makedirs(Data, exist_ok=True)
-        train_img_dir = os.path.abspath(os.path.join(Data, 'train', 'images'))
-        val_img_dir = os.path.abspath(os.path.join(Data, 'val', 'images'))
-        test_img_dir = os.path.abspath(os.path.join(Data, 'test', 'images'))
-
-        train_label_dir = os.path.abspath(os.path.join(Data, 'train', 'labels'))
-        val_label_dir = os.path.abspath(os.path.join(Data, 'val', 'labels'))
-        test_label_dir = os.path.abspath(os.path.join(Data, 'test', 'labels'))
-
-        os.makedirs(train_img_dir, exist_ok=True)
-        os.makedirs(train_label_dir, exist_ok=True)
-        os.makedirs(val_img_dir, exist_ok=True)
-        os.makedirs(val_label_dir, exist_ok=True)
-        os.makedirs(test_img_dir, exist_ok=True)
-        os.makedirs(test_label_dir, exist_ok=True)
+        dirs = ['train/images', 'val/images', 'test/images', 'train/labels', 'val/labels', 'test/labels']
+        for d in dirs: os.makedirs(os.path.join(Data, d), exist_ok=True)
     except Exception as e:
-        print(f'Failed to create dataset directories: {e}')
+        logger.error(f'Failed to create dataset directories: {e}')
         return
 
     train, val, test = split_list
@@ -192,30 +229,31 @@ def split_img(img_path, label_path, split_list, new_path, class_names):
     train_img = random.sample(all_img_path, int(train * len(all_img_path)))
     train_label = [toLabelPath(img, label_path) for img in train_img]
     for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'):
-        _copy(train_img[i], train_img_dir)
-        _copy(train_label[i], train_label_dir)
+        _copy(train_img[i], os.path.join(Data, 'train/images'))
+        _copy(train_label[i], os.path.join(Data, 'train/labels'))
         all_img_path.remove(train_img[i])
 
     val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path)))
     val_label = [toLabelPath(img, label_path) for img in val_img]
     for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'):
-        _copy(val_img[i], val_img_dir)
-        _copy(val_label[i], val_label_dir)
+        _copy(val_img[i], os.path.join(Data, 'val/images'))
+        _copy(val_label[i], os.path.join(Data, 'val/labels'))
         all_img_path.remove(val_img[i])
 
     test_img = all_img_path
     test_label = [toLabelPath(img, label_path) for img in test_img]
     for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'):
-        _copy(test_img[i], test_img_dir)
-        _copy(test_label[i], test_label_dir)
+        _copy(test_img[i], os.path.join(Data, 'test/images'))
+        _copy(test_label[i], os.path.join(Data, 'test/labels'))
 
     generate_dataset_yaml(
         save_path=os.path.join(Data, 'dataset.yaml'),
-        train_path=train_img_dir,
-        val_path=val_img_dir,
-        test_path=test_img_dir,
+        train_path=os.path.join(Data, 'train/images'),
+        val_path=os.path.join(Data, 'val/images'),
+        test_path=os.path.join(Data, 'test/images'),
         class_names=class_names
     )
+    logger.info("Dataset split finished")
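The second random.sample uses val / (val + test) because only the images left after the train draw remain in the pool. Worked through for 100 images and split_list = [0.7, 0.2, 0.1]:

    len(train_img) == int(0.7 * 100)             # 70 images drawn first
    # 30 images remain in the pool:
    len(val_img) == int(0.2 / (0.2 + 0.1) * 30)  # ~20 (float truncation can yield 19)
    len(test_img)                                # whatever is left over, ~10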
 
 def _copy(from_path, to_path):
     try:
@@ -236,162 +274,71 @@ def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names):
         'nc': len(class_names),
         'names': list(class_names.values())
     }
 
     with open(save_path, 'w', encoding='utf-8') as f:
         yaml.dump(dataset_yaml, f, allow_unicode=True)
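For a hypothetical two-class project the dump would produce roughly the file below; the train/val/test keys are inferred from the function's parameters, and the class names are invented:

    # datasets/dataset.yaml
    # train: /abs/path/datasets/train/images
    # val: /abs/path/datasets/val/images
    # test: /abs/path/datasets/test/images
    # nc: 2
    # names: ['smoke', 'fire']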
 
-###################################################################### Start training ##################################################################
-
-# Fetch the model path from the previous training run
-def get_last_model(default_model_path, log_file='last_model.log'):
-    if os.path.exists(log_file):
-        with open(log_file, 'r') as f:
-            last_model = f.read().strip()
-        if last_model and os.path.exists(last_model):
-            print(f"Using the model from the previous run: {last_model}")
-            return last_model
-    print(f"Using the default model: {default_model_path}")
-    return default_model_path
-
-# Persist the model path of this training run
-def save_last_model(model_path, log_file='last_model.log'):
-    with open(log_file, 'w') as f:
-        f.write(model_path)
-
-# Delete a folder
-def delete_folder(folder_path):
+def delete_folder(folder_path, logger):
     if os.path.exists(folder_path):
         shutil.rmtree(folder_path)
-        print(f"Deleted the folder and its contents: {folder_path}")
-    else:
-        print(f"Folder does not exist: {folder_path}")
+        logger.info(f"Deleted folder: {folder_path}")
 
-# Training function
-def train(project_name, yaml_path, default_model_path):
+####################################### Training #######################################
+def train(project_name, yaml_path, default_model_path, logger):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"Using device: {device}")
+    logger.info(f"Using device: {device}")
 
-    model_path = get_last_model(default_model_path)
+    model_path = get_last_model_from_log(project_name, default_model_path)
+    logger.info(f"Loading model: {model_path}")
     model = YOLO(model_path).to(device)
 
     current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
-    project = project_name.strip()
 
     model.train(
         data=yaml_path,
         epochs=200,
         pretrained=True,
         patience=50,
         imgsz=640,
-        device=[0],  # with multiple GPUs you can pass [0, 1]
+        device=[0],
         workers=0,
-        project=project,
+        project=project_name,
         name=current_date,
     )
 
-    trained_model_path = os.path.join('runs', 'detect', current_date, 'weights', 'best.pt')
+    trained_model_path = os.path.join('runs', 'detect', project_name, current_date, 'weights', 'last.pt')
     if os.path.exists(trained_model_path):
-        save_last_model(trained_model_path)
+        logger.info(f"Saved last model path: {trained_model_path}")
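Worth double-checking against the installed Ultralytics version: with project= and name= set, train() writes its run under {project_name}/{current_date}/ rather than under runs/detect/, so the hard-coded path above may never exist and the "Saved last model path" marker would then never be written. A sketch that reads the save directory from the trainer instead (attribute present on recent Ultralytics releases):

    model.train(data=yaml_path, project=project_name, name=current_date)  # plus the other arguments above
    save_dir = str(model.trainer.save_dir)   # e.g. '<project_name>/<current_date>'
    trained_model_path = os.path.join(save_dir, 'weights', 'last.pt')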
 
-####################################################################################### Auto training ##################################################################################
-def auto_train(
-    db_host,
-    db_database,
-    db_user,
-    db_password,
-    db_port,
-    model_id,
-    img_path='./dataset/images',
-    label_path='./dataset/labels',
-    new_path='./datasets',
-    split_list=[0.7, 0.2, 0.1],
-    class_names=None,
-    project_name='default_project'
-):
+####################################### Auto training #######################################
+def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
+               img_path='./dataset/images', label_path='./dataset/labels',
+               new_path='./datasets', split_list=[0.7, 0.2, 0.1],
+               class_names=None, project_name='default_project'):
     if class_names is None:
         class_names = {}
 
-    # Delete stale data
-    delete_folder('dataset')
-    delete_folder('datasets')
-
-    # Download fresh data
-    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id)
-
-    # Clean the labels
-    process_files_in_folder(img_path)
+    logger = setup_logger(project_name)
+
+    delete_folder('dataset', logger)
+    delete_folder('datasets', logger)
+
+    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id, logger)
+    process_files_in_folder(img_path, logger)
 
-    # Count labels per class
     label_count = count_labels_by_class(label_path)
+    logger.info(f"Label counts: {label_count}")
 
-    # Split the dataset
-    split_img(img_path, label_path, split_list, new_path, class_names)
+    split_img(img_path, label_path, split_list, new_path, class_names, logger)
 
-    # Evaluate the baseline model before training
     base_metrics = evaluate_model_per_class('yolo11n.pt', './datasets/dataset.yaml', class_names)
+    logger.info(f"Baseline evaluation before training: {base_metrics}")
 
-    # Delete the raw data
-    delete_folder('dataset')
+    delete_folder('dataset', logger)
 
-    # Start training
-    train(project_name, './datasets/dataset.yaml', 'yolo11n.pt')
+    train(project_name, './datasets/dataset.yaml', 'yolo11n.pt', logger)
 
-    # Evaluate after training
-    best_model_path = f"runs/detect/{project_name}/weights/best.pt"
-    final_metrics = evaluate_model_per_class(best_model_path, './datasets/dataset.yaml', class_names)
-
-    # Delete datasets
-    delete_folder('datasets')
-
-    print("Training pipeline finished!")
-
+    logger.info("Training pipeline finished")
 
-    return {
-        "status": "success",
-        "message": "Train finished",
-        "project_name": project_name,
-        "label_count": label_count,
-        "base_metrics": base_metrics,
-        "final_metrics": final_metrics
-    }
-
-########################################################################################## Get the training progress ##################################################################
-def query_progress(project_name, run_time=None):
-    """
-    Query the training progress
-    :param project_name: training project name
-    :param run_time: run timestamp (defaults to the most recent run)
-    :return: current progress information
-    """
-    base_dir = os.path.join("runs", "detect", project_name)
-    if not os.path.exists(base_dir):
-        return {"status": "not_found", "message": f"Could not find {base_dir}"}
-
-    # If no run_time is given, take the newest directory
-    if run_time is None:
-        dirs = sorted(os.listdir(base_dir), reverse=True)
-        if not dirs:
-            # No training records found
-            return {"status": "not_found", "message": "No training records found yet"}
-        run_time = dirs[0]
-
-    log_path = os.path.join(base_dir, run_time, "results.csv")
-    if not os.path.exists(log_path):
-        return {"status": "not_found", "message": f"No training records found at {log_path}"}
-
-    df = pd.read_csv(log_path)
-    if df.empty:
-        return {"status": "running", "message": "The log has no data yet"}
-
-    last_row = df.iloc[-1].to_dict()
-    progress = {
-        "epoch": int(last_row.get("epoch", -1)),
-        "precision": float(last_row.get("precision", 0.0)),
-        "recall": float(last_row.get("recall", 0.0)),
-        "mAP50": float(last_row.get("mAP50", 0.0)),
-        "mAP50-95": float(last_row.get("mAP50-95", 0.0)),
-    }
-    return {"status": "ok", "run_time": run_time, "progress": progress}
-
+####################################### Main entry point #######################################
 if __name__ == '__main__':
     auto_train(
         db_host='222.212.85.86',