yoooooger

yooooger 2025-11-11 09:43:25 +08:00
parent 49f00429b9
commit 04c0b96f33
8 changed files with 482 additions and 452 deletions


@@ -7,8 +7,15 @@ from sanic_cors import CORS
# ourself imports
from ai_image import process_images
from map_find import map_process_images
from yolo_train import auto_train,query_progress
from yolo_train import train_main
from yolo_photo import map_process_images_with_progress
from pydantic import BaseModel, ValidationError
from typing import List, Dict
import threading
import torch
import uuid
from queue import Queue
# set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -237,83 +244,137 @@ async def yolo_detect_api(request):
"message": f"Internal server error: {str(e)}"
}, status=500)
# YOLO auto_train API
# -------------------------------------------------------------------------- YOLO training APIs --------------------------------------------------------------------------
# blueprint for the YOLO training endpoints
MAX_CONCURRENT_JOBS = torch.cuda.device_count() if torch.cuda.is_available() else 1
tasks: Dict[str, Dict] = {}
task_queue = Queue()
active_jobs: List[str] = []
lock = threading.Lock()
# ------------------ Request model ------------------
class TrainRequest(BaseModel):
config_name: str
table_name: str
column_name: str
search_condition: str
aim_path: str
image_dir: str
label_dir: str
output_path: str
pt_path: str
imgsz: int
epochs: int
device: List[int]
hsv_v: float
cos_lr: bool
batch: int
project_dir: str
class_names: List[str]
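# A request body matching TrainRequest might look like this (values are
# illustrative placeholders mirroring the sample train_main() call in
# yolo_train.py, not service defaults):
# {
#     "config_name": "config",
#     "table_name": "aidataset",
#     "column_name": "image_url",
#     "search_condition": "your_search_id",
#     "aim_path": "./datasets/aidataset_dataset",
#     "image_dir": "./dataset/images",
#     "label_dir": "./dataset/labels",
#     "output_path": "./my_dataset",
#     "pt_path": "yolo11n.pt",
#     "imgsz": 800,
#     "epochs": 500,
#     "device": [0],
#     "hsv_v": 0.3,
#     "cos_lr": true,
#     "batch": 8,
#     "project_dir": "./my_train_runs",
#     "class_names": ["person", "car"]
# }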
# ------------------ Core worker ------------------
def run_training(task_id: str, params: TrainRequest):
try:
with lock:
active_jobs.append(task_id)
tasks[task_id]["status"] = "running"
train_main(
config_name=params.config_name,
table_name=params.table_name,
column_name=params.column_name,
search_condition=params.search_condition,
aim_path=params.aim_path,
image_dir=params.image_dir,
label_dir=params.label_dir,
output_path=params.output_path,
pt_path=params.pt_path,
imgsz=params.imgsz,
epochs=params.epochs,
device=params.device,
hsv_v=params.hsv_v,
cos_lr=params.cos_lr,
batch=params.batch,
project_dir=params.project_dir,
class_names=params.class_names
)
tasks[task_id]["status"] = "finished"
except Exception as e:
tasks[task_id]["status"] = "failed"
tasks[task_id]["error"] = str(e)
finally:
with lock:
if task_id in active_jobs:
active_jobs.remove(task_id)
schedule_next_job()
# ------------------ Scheduler ------------------
def schedule_next_job():
with lock:
while len(active_jobs) < MAX_CONCURRENT_JOBS and not task_queue.empty():
next_id = task_queue.get()
params = tasks[next_id]["params"]
t = threading.Thread(target=run_training, args=(next_id, params), daemon=True)
t.start()
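# Task lifecycle as implemented above: "queued" -> "waiting" (when no GPU slot
# is free) -> "running" -> "finished" or "failed"; schedule_next_job() promotes
# waiting tasks whenever a running job releases its slot.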
# ------------------ Endpoints ------------------
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
"""
auto_train
input JSON:
{
"db_host": str,
"db_database": str,
"db_user": str,
"db_password": str,
"db_port": int,
"model_id": int,
"img_path": str,
"label_path": str,
"new_path": str,
"split_list": List[float],
"class_names": Optional[List[str]],
"project_name": str
}
output JSON:
return {
"status": "success",
"message": "Train finished",
"project_name": project_name,
"label_count": label_count,
"base_metrics": base_metrics,
"final_metrics": final_metrics
}
"""
async def submit_train_job(request):
try:
data = request.json
if not data:
return json_response({"status": "error", "message": "data is required"}, status=400)
# Do the training in a separate thread to avoid blocking the event loop
result = await asyncio.to_thread(
auto_train,
data
)
# return the result as JSON response
return json_response(result)
params = TrainRequest(**data)
except ValidationError as e:
return json({"success": False, "error": e.errors()})
except Exception as e:
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"message": f"Internal server error: {str(e)}"
}, status=500)
task_id = str(uuid.uuid4())
tasks[task_id] = {"status": "queued", "params": params}
with lock:
if len(active_jobs) < MAX_CONCURRENT_JOBS:
t = threading.Thread(target=run_training, args=(task_id, params), daemon=True)
t.start()
else:
task_queue.put(task_id)
tasks[task_id]["status"] = "waiting"
return json({"success": True, "task_id": task_id, "message": "Task submitted"})
# access the training progress
@yolo_tile_blueprint.get("/progress/<project_name>")
async def yolo_train_progress(request, project_name):
'''
input:
to query the latest progress: GET /yolo/progress/my_project
to query the progress at a specific time: GET /yolo/progress/my_project?run_time=20250902_1012
output JSON:
{
"status": "ok",
"run_time": "20250902_1012",
"progress": {
"epoch": 12,
"precision": 0.72,
"recall": 0.64,
"mAP50": 0.68,
"mAP50-95": 0.42
}
}
'''
run_time = request.args.get("run_time") # get the run_time from the query string
# query the progress from the database
if not run_time:
run_time = None # if not provided, query the latest progress
result = await asyncio.to_thread(query_progress, project_name, run_time)
return json_response(result)
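# Example submission (a sketch: it assumes the blueprint is mounted under /yolo,
# as the progress docstring suggests, and the default port 12366):
#
#   import requests
#   payload = {...}  # a dict shaped like the TrainRequest example above
#   r = requests.post("http://localhost:12366/yolo/train", json=payload)
#   task_id = r.json()["task_id"]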
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id: str):
if task_id not in tasks:
return json({"success": False, "message": "Unknown task id"})
task_info = tasks[task_id]
return json({
"success": True,
"status": task_info["status"],
"error": task_info.get("error", None)
})
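# Example status polling (same assumptions as the submission sketch above):
#
#   import time, requests
#   while True:
#       info = requests.get(f"http://localhost:12366/yolo/task_status/{task_id}").json()
#       if info["status"] in ("finished", "failed"):
#           break
#       time.sleep(30)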
@yolo_tile_blueprint.get("/tasks")
async def all_tasks(request):
return json({
tid: {"status": info["status"]}
for tid, info in tasks.items()
})
@yolo_tile_blueprint.get("/system_status")
async def system_status(request):
gpu_available = torch.cuda.is_available()
return json({
"gpu_available": gpu_available,
"max_concurrent": MAX_CONCURRENT_JOBS,
"running_jobs": len(active_jobs),
"waiting_jobs": task_queue.qsize(),
"active_task_ids": active_jobs
})
if __name__ == '__main__':
app.run(host="0.0.0.0", port=12366, debug=True, workers=1)

View File

@@ -8,6 +8,6 @@ minio:
sql:
host: '222.212.85.86'
port: 5432
dbname: 'postgres'
dbname: 'smart_dev'
user: 'postgres'
password: 'root'

Ai_tottle/train/broken.py (new file, +87)

@@ -0,0 +1,87 @@
import os
import shutil
import random
from tqdm import tqdm
import yaml
def split_img(img_path, label_path, split_list, output_path,
class_names=['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']):
try:
# create the target directory structure
for sub in ['images/train', 'images/val', 'images/test',
'labels/train', 'labels/val', 'labels/test']:
os.makedirs(os.path.join(output_path, sub), exist_ok=True)
except Exception as e:
print(f'❌ Failed to create directories: {e}')
return
train, val, test = split_list
all_imgs = [f for f in os.listdir(img_path) if f.endswith(('.jpg', '.png'))]
all_img_paths = [os.path.join(img_path, f) for f in all_imgs]
# training split
train_imgs = random.sample(all_img_paths, int(train * len(all_img_paths)))
move_set(train_imgs, label_path, os.path.join(output_path, 'images/train'), os.path.join(output_path, 'labels/train'))
for f in train_imgs: all_img_paths.remove(f)
# validation split
val_imgs = random.sample(all_img_paths, int(val / (val + test) * len(all_img_paths)))
move_set(val_imgs, label_path, os.path.join(output_path, 'images/val'), os.path.join(output_path, 'labels/val'))
for f in val_imgs: all_img_paths.remove(f)
# remaining images go to the test split
test_imgs = all_img_paths
move_set(test_imgs, label_path, os.path.join(output_path, 'images/test'), os.path.join(output_path, 'labels/test'))
# generate dataset.yaml
generate_yaml(output_path, class_names)
def move_set(img_list, label_root, dst_img_dir, dst_label_dir):
for img_path in tqdm(img_list, desc=f'Copying to {os.path.basename(dst_img_dir)}', ncols=80):
base = os.path.splitext(os.path.basename(img_path))[0]
label_path = os.path.join(label_root, base + '.txt')
shutil.copy(img_path, os.path.join(dst_img_dir, os.path.basename(img_path)))
if os.path.exists(label_path):
shutil.copy(label_path, os.path.join(dst_label_dir, base + '.txt'))
def generate_yaml(dataset_root, class_names):
yaml_content = {
'train': os.path.join('images/train'),
'val': os.path.join('images/val'),
'test': os.path.join('images/test'),
'nc': len(class_names),
'names': class_names
}
with open(os.path.join(dataset_root, 'dataset.yaml'), 'w') as f:
yaml.dump(yaml_content, f, default_flow_style=False)
print(f"✅ Generated YAML: {os.path.join(dataset_root, 'dataset.yaml')}")
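# With the module's default classes, the generated dataset.yaml looks roughly
# like this (yaml.dump sorts keys alphabetically):
#
#   names:
#   - people
#   - car
#   - truck
#   - bicycle
#   - tricycle
#   - ship
#   nc: 6
#   test: images/test
#   train: images/train
#   val: images/val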
def broken_main(aim_path, output_path,
class_names=['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']):
img_path = os.path.join(aim_path, 'images')
label_path = os.path.join(aim_path, 'labels')
split_ratio = [0.7, 0.2, 0.1]
split_img(img_path, label_path, split_ratio, output_path,class_names)
if __name__ == '__main__':
broken_main(
r"D:\Users\76118\Downloads\stanford_campus_dataset\filtered",
r"D:\work\develop\AI\数据集\output",
class_names=[
'people',
'car',
'truck',
'bicycle',
'tricycle',]
)


@@ -0,0 +1,85 @@
import os
import stat
import math
def make_writable(file_path):
os.chmod(file_path, stat.S_IWRITE)
def process_files_in_folder(folder_path):
for root, _, files in os.walk(folder_path):
for file_name in files:
if file_name.endswith(".txt"):
file_path = os.path.join(root, file_name)
# make sure the file is writable
make_writable(file_path)
# read and process the file contents
with open(file_path, "r") as file:
lines = file.readlines()
processed_lines = []
for line in lines:
numbers = line.split()
processed_numbers = []
# make sure the first column is an integer class id, not a float
if numbers[0] in {str(i) for i in range(19)}:  # class ids 0-18
processed_numbers.append(numbers[0])
else:
print(f"Unexpected value in first column: {numbers[0]}")
continue
# process the remaining columns: flip negatives to positive and drop lines containing NaN
skip_line = False # marks whether this line should be skipped
for number in numbers[1:]:
try:
number = float(number)
if math.isnan(number): # check for NaN
skip_line = True
print(
f"NaN detected in file: {file_path}, line: {line}"
)
break
if number < 0:
number = abs(number) # convert negatives to positive
processed_numbers.append(str(number)) # store back as a string
except ValueError:
processed_numbers.append(number) # non-numeric columns are kept as-is
# keep the line only if it contained no NaN values
if not skip_line:
processed_line = " ".join(processed_numbers)
processed_lines.append(processed_line)
# write the processed content back to the file
with open(file_path, "w") as file:
file.write("\n".join(processed_lines))
print(f"Finished processing: {file_path}")
# folder to process
folder_path = r"G:\dataset\PCS\before\labels"
# run the function
process_files_in_folder(folder_path)
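# Example of what this script does to a single label line (values illustrative):
#
#   before: "3 -0.12 0.44 0.10 0.08"
#   after:  "3 0.12 0.44 0.1 0.08"
#
# Note that float() round-trips the coordinates, so "0.10" becomes "0.1", and a
# line such as "3 nan 0.44 0.10 0.08" is dropped entirely.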

Ai_tottle/train/train.py (new file, +25)

@@ -0,0 +1,25 @@
from ultralytics import YOLO
import torch
# check whether CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# load the model
model = YOLO("runs/detect/train6/weights/last.pt").to(device)
# set the training resolution
imgsz = 1024 # train at 1024x1024; adjust to fit your GPU memory
# train the model, passing augmentation parameters
model.train(
data="dataset/dataset.yaml", # your dataset config file
epochs=1000, # number of training epochs
imgsz=imgsz, # use the higher resolution
device=[1], # GPU index to train on; adjust if you have multiple GPUs
hsv_v=0.3, # vary image brightness so the model handles different lighting
cos_lr=True, # enable the cosine learning-rate schedule
batch=-1, # automatically size the batch to fit GPU memory
)


@@ -1,394 +1,166 @@
""""
main()
|
v
setup_logger(project)
|
v
get_last_model_path(project)
|
v
+-------------------------+
| last.pt | last.pt |
+-------------------------+
| |
v v
load_last_model() start_new_training()
| |
+-------+--------+
|
v
check_dataset(root)
|
v
split_dataset(root, ratios)
|
v
clean_labels(root)
|
v
generate_yaml(dataset_dir)
|
v
train_yolo(model, data_yaml)
|
v
save last.pt
|
v
logger.info("Saved last model path")
|
v
write to logs/{project}.log
"""
import os
import shutil
import datetime
import torch
from ultralytics import YOLO
import random
import math
import stat
import yaml
import psycopg2
from psycopg2 import OperationalError
from collections import Counter
import pandas as pd
import logging
from tqdm import tqdm
import miniohelp as miniohelp
from glob import glob
from aboutdataset.download_oss import download_and_save_images_from_oss
from train.let_txt_to_true import process_files_in_folder
from train.broken import broken_main
from ultralytics import YOLO
import torch
######################################## Logging ########################################
def setup_logger(project: str):
os.makedirs("logs", exist_ok=True)
log_file = os.path.join("logs", f"{project}.log")
logger = logging.getLogger(project)
if not logger.handlers:
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
fh = logging.FileHandler(log_file, encoding="utf-8")
fh.setFormatter(formatter)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(sh)
return logger
# ------------------ Download images and labels ------------------
def download_images_and_labels(
config_name, # OSS config file name, used to read connection info
table_name, # OSS table to pull data from
column_name, # column in the OSS table holding the image URL
search_condition, # filter condition for the OSS query
aim_path, # local root directory for the dataset
image_dir, # local directory for downloaded images
label_dir # local directory for the label .txt files
):
os.makedirs(aim_path, exist_ok=True)
os.makedirs(image_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)
download_and_save_images_from_oss(
yaml_name=config_name,
where_clause=f"{column_name} = '{search_condition}'",
image_dir=image_dir,
label_dir=label_dir,
table_name=table_name,
)
return aim_path, image_dir, label_dir
def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"):
"""
从日志解析上一次训练的 last.pt 路径
如果找不到则返回 default_model
支持 default_model .pt .yaml
"""
log_file = os.path.join("logs", f"{project}.log")
if not os.path.exists(log_file):
return default_model
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in reversed(lines):
if "Saved last model path:" in line:
path = line.strip().split("Saved last model path:")[-1].strip()
if os.path.exists(path):
return path
return default_model
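# The parser above looks for log lines written by train(), e.g. (path illustrative):
#   2025-11-11 09:43:25,123 [INFO] Saved last model path: runs/detect/my_project/20251111_0943/weights/last.pt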
def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, table_name):
"""
yaml_name='config'
where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'"
image_dir='images'
label_dir='labels'
table_name = 'aidataset'
Returns:
(image_dir, label_dir)
"""
download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name)
return image_dir, label_dir
####################################### Utilities #######################################
def count_labels_by_class(label_dir):
class_counter = Counter()
for file in os.listdir(label_dir):
if file.endswith('.txt'):
with open(os.path.join(label_dir, file), 'r') as f:
for line in f:
class_id = line.strip().split()[0]
class_counter[class_id] += 1
return dict(class_counter)
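# Returns a mapping of class id (as a string) to label count, e.g. (illustrative):
#   {'0': 1523, '1': 487}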
def evaluate_model_per_class(model_path, dataset_yaml, class_names):
model = YOLO(model_path)
metrics = model.val(data=dataset_yaml, split='val')
class_ids = range(len(metrics.box.p))
results = {}
for i in class_ids:
name = class_names.get(str(i), str(i))
results[name] = {
"precision": float(metrics.box.p[i]),
"recall": float(metrics.box.r[i]),
"mAP50": float(metrics.box.map50[i]),
"mAP50_95": float(metrics.box.map[i])
}
return results
def link_database(db_database, db_user, db_password, db_host, db_port, search_query):
try:
with psycopg2.connect(
database=db_database,
user=db_user,
password=db_password,
host=db_host,
port=db_port
) as conn:
with conn.cursor() as cur:
cur.execute(search_query)
records = cur.fetchall()
return records
except OperationalError as e:
print(f"Database connection or query error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")
def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger):
search_query = f"SELECT * FROM aidataset WHERE model = '{model}';"
records = link_database(db_database, db_user, db_password, db_host, db_port, search_query)
if not records:
logger.warning("No records found.")
return
os.makedirs('./dataset/images', exist_ok=True)
os.makedirs('./dataset/labels', exist_ok=True)
for r in records:
img_path = r[4]
label_content = r[5]
local_img_name = img_path.split('/')[-1]
local_img_path = os.path.join('./dataset/images', local_img_name)
miniohelp.downFile(img_path, local_img_path)
txt_name = os.path.splitext(local_img_name)[0] + '.txt'
txt_path = os.path.join('./dataset/labels', txt_name)
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(label_content + '\n')
logger.info("Data download complete")
def make_writable(file_path):
os.chmod(file_path, stat.S_IWRITE)
def process_files_in_folder(folder_path, logger):
for root, _, files in os.walk(folder_path):
for file_name in files:
if file_name.endswith('.txt'):
file_path = os.path.join(root, file_name)
make_writable(file_path)
with open(file_path, 'r') as file:
lines = file.readlines()
processed_lines = []
for line in lines:
numbers = line.split()
processed_numbers = []
if numbers[0].isdigit():
processed_numbers.append(numbers[0])
else:
logger.warning(f"Unexpected value in first column: {numbers[0]}")
continue
skip_line = False
for number in numbers[1:]:
try:
number = float(number)
if math.isnan(number):
skip_line = True
logger.warning(f"NaN detected in {file_path}: {line}")
break
if number < 0:
number = abs(number)
processed_numbers.append(str(number))
except ValueError:
processed_numbers.append(number)
if not skip_line:
processed_line = ' '.join(processed_numbers)
processed_lines.append(processed_line)
with open(file_path, 'w') as file:
file.write('\n'.join(processed_lines))
logger.info(f"Processed {file_path}")
def split_img(img_path, label_path, split_list, new_path, class_names, logger):
try:
Data = os.path.abspath(new_path)
os.makedirs(Data, exist_ok=True)
dirs = ['train/images','val/images','test/images','train/labels','val/labels','test/labels']
for d in dirs: os.makedirs(os.path.join(Data, d), exist_ok=True)
except Exception as e:
logger.error(f'Failed to create directories: {e}')
return
train, val, test = split_list
all_img = os.listdir(img_path)
all_img_path = [os.path.join(img_path, img) for img in all_img]
train_img = random.sample(all_img_path, int(train * len(all_img_path)))
train_label = [toLabelPath(img, label_path) for img in train_img]
for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'):
_copy(train_img[i], os.path.join(Data,'train/images'))
_copy(train_label[i], os.path.join(Data,'train/labels'))
all_img_path.remove(train_img[i])
val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path)))
val_label = [toLabelPath(img, label_path) for img in val_img]
for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'):
_copy(val_img[i], os.path.join(Data,'val/images'))
_copy(val_label[i], os.path.join(Data,'val/labels'))
all_img_path.remove(val_img[i])
test_img = all_img_path
test_label = [toLabelPath(img, label_path) for img in test_img]
for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'):
_copy(test_img[i], os.path.join(Data,'test/images'))
_copy(test_label[i], os.path.join(Data,'test/labels'))
generate_dataset_yaml(
save_path=os.path.join(Data, 'dataset.yaml'),
train_path=os.path.join(Data,'train/images'),
val_path=os.path.join(Data,'val/images'),
test_path=os.path.join(Data,'test/images'),
class_names=class_names
)
logger.info("Dataset split complete")
def _copy(from_path, to_path):
try:
shutil.copy(from_path, to_path)
except Exception as e:
print(f"Error copying file: {e}")
def toLabelPath(img_path, label_path):
img = os.path.basename(img_path)
label = img.replace('.jpg', '.txt')
return os.path.join(label_path, label)
def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names):
dataset_yaml = {
'train': train_path.replace('\\', '/'),
'val': val_path.replace('\\', '/'),
'test': test_path.replace('\\', '/'),
'nc': len(class_names),
'names': list(class_names.values())
}
with open(save_path, 'w', encoding='utf-8') as f:
yaml.dump(dataset_yaml, f, allow_unicode=True)
# ------------------ Label fixing and dataset shuffling ------------------
def broken_and_convert_txt_to_yolo_format(
aim_path, # dataset root directory
output_path, # directory for the shuffled, split dataset
image_dir, # image directory
label_dir, # label directory
class_names # list of dataset class names
):
process_files_in_folder(label_dir) # fix the labels into YOLO format
broken_main(aim_path, output_path, class_names) # shuffle/split the dataset and generate dataset.yaml
yaml_path = os.path.join(output_path, 'dataset.yaml')
return output_path, yaml_path
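# After this step output_path contains the layout produced by broken_main():
#
#   output_path/
#       images/train  images/val  images/test
#       labels/train  labels/val  labels/test
#       dataset.yaml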
def delete_folder(folder_path, logger):
if os.path.exists(folder_path):
shutil.rmtree(folder_path)
logger.info(f"Deleted folder: {folder_path}")
####################################### Training #######################################
def train(project_name, yaml_path, default_model_path, logger):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
model_path = get_last_model_from_log(project_name, default_model_path)
logger.info(f"Loading model: {model_path}")
model = YOLO(model_path).to(device)
current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
model.train(
data=yaml_path,
epochs=200,
pretrained=True,
patience=50,
imgsz=640,
device=device,
workers=0,
project=project_name,
name=current_date,
)
trained_model_path = os.path.join('runs', 'detect', project_name, current_date, 'weights', 'last.pt')
if os.path.exists(trained_model_path):
logger.info(f"Saved last model path: {trained_model_path}")
# ------------------ Fetch the latest .pt model ------------------
def get_latest_pt(project_dir, pt_path):
"""
Check whether the training output directory contains any .pt model files.
Return the newest one if it does, otherwise return the given pt_path.
"""
if not os.path.exists(project_dir):
print(f"[INFO] Project directory {project_dir} does not exist, using the given model {pt_path}")
return pt_path
pt_files = glob(os.path.join(project_dir, "*.pt"))
if not pt_files:
print(f"[INFO] No .pt files in the directory, using the given model {pt_path}")
return pt_path
latest_pt = max(pt_files, key=os.path.getmtime)
print(f"[INFO] Newest model detected: {latest_pt}")
return latest_pt
# ------------------ Training ------------------
def train(
yaml_path, # path to the YOLO dataset config file
pt_path, # initial .pt weights to train from
imgsz, # input image resolution
epochs, # number of training epochs
device, # list of GPU indices, e.g. [0] or [0,1]
hsv_v, # image brightness augmentation factor
cos_lr, # whether to use a cosine learning-rate schedule
batch, # batch size
project_dir # training output directory (weights, logs, etc.)
):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
pt_path = get_latest_pt(project_dir, pt_path) # auto-detect the newest .pt file
model = YOLO(pt_path).to(device)
model.train(
data=yaml_path,
epochs=epochs,
imgsz=imgsz,
device=device,
hsv_v=hsv_v,
cos_lr=cos_lr,
batch=batch,
project=project_dir,
)
# ------------------ Main pipeline ------------------
def train_main(
# OSS download parameters
config_name, # sql config file name
table_name, # sql table name
column_name, # column name in the sql table
search_condition, # filter condition for the sql query
# dataset paths
aim_path, # local root directory for the downloaded dataset
image_dir, # local directory for downloaded images
label_dir, # local directory for downloaded labels
output_path, # directory for the shuffled, split dataset
# YOLO training parameters
pt_path, # initial weights file path
imgsz, # input image resolution
epochs, # number of training epochs
device, # list of GPU indices
hsv_v, # image brightness augmentation factor
cos_lr, # whether to use a cosine learning-rate schedule
batch, # batch size
project_dir, # training output directory
# classes
class_names # list of dataset class names
):
aim_path, image_dir, label_dir = download_images_and_labels(
config_name, table_name, column_name, search_condition,
aim_path, image_dir, label_dir
)
output_path, yaml_path = broken_and_convert_txt_to_yolo_format(
aim_path, output_path, image_dir, label_dir, class_names
)
train(
yaml_path=yaml_path,
pt_path=pt_path,
imgsz=imgsz,
epochs=epochs,
device=device,
hsv_v=hsv_v,
cos_lr=cos_lr,
batch=batch,
project_dir=project_dir
)
####################################### Auto training #######################################
def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
img_path='./dataset/images', label_path='./dataset/labels',
new_path='./datasets', split_list=[0.7, 0.2, 0.1],
class_names=None, project_name='default_project'):
if class_names is None:
class_names = {}
logger = setup_logger(project_name)
delete_folder('dataset', logger)
delete_folder('datasets', logger)
down_dataset(db_database, db_user, db_password, db_host, db_port, model_id, logger)
process_files_in_folder(img_path, logger)
label_count = count_labels_by_class(label_path)
logger.info(f"Label counts: {label_count}")
split_img(img_path, label_path, split_list, new_path, class_names, logger)
base_metrics = evaluate_model_per_class('yolo11n.pt', './datasets/dataset.yaml', class_names)
logger.info(f"Baseline evaluation before training: {base_metrics}")
delete_folder('dataset', logger)
train(project_name, './datasets/dataset.yaml', 'yolo11n.pt', logger)
logger.info("Training pipeline finished")
def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id, image_dir, label_dir, yaml_name, where_clause, table_name):
imag_path, label_path = get_img_and_label_paths(yaml_name, where_clause, image_dir, label_dir, table_name)
auto_train(
db_host=db_host,
db_database=db_database,
db_user=db_user,
db_password=db_password,
db_port=db_port,
model_id=model_id,
imag_path=imag_path, # fixed here so imag_path is passed as a keyword argument
label_path=label_path, # fixed here so label_path is passed as a keyword argument
new_path='./datasets',
split_list=[0.7, 0.2, 0.1],
class_names={'0': 'human', '1': 'car'},
project_name='my_project'
)
####################################### Main entry #######################################
if __name__ == '__main__':
down_and_train(
db_host='222.212.85.86',
db_database='your_database_name',
db_user='postgres',
db_password='postgres',
db_port='5432',
model_id='best.pt',
img_path='./dataset/images', # image path before shuffling
label_path='./dataset/labels', # label path before shuffling
new_path='./datasets', # output path after shuffling
split_list=[0.7, 0.2, 0.1],
class_names={'0': 'human', '1': 'car'},
project_name='my_project'
)
# ------------------ Entry point ------------------
if __name__ == "__main__":
train_main(
config_name="config",
table_name="aidataset",
column_name="image_url",
search_condition="your_search_id",
aim_path="./datasets/aidataset_dataset",
image_dir="./dataset/aidataset_dataset_images",
label_dir="./dataset/aidataset_dataset_labels",
output_path="./my_dataset",
pt_path="custom_model.pt",
imgsz=800,
epochs=500,
device=[0],
hsv_v=0.3,
cos_lr=True,
batch=8,
project_dir="./my_train_runs",
class_names=['person','car']
)