yoooooger
This commit is contained in:
parent 49f00429b9 · commit 04c0b96f33
BIN  Ai_tottle/aboutdataset/__pycache__/download_oss.cpython-312.pyc  (new file; binary file not shown)
@@ -7,8 +7,15 @@ from sanic_cors import CORS
 # ourself imports
 from ai_image import process_images
 from map_find import map_process_images
-from yolo_train import auto_train,query_progress
+from yolo_train import train_main
 from yolo_photo import map_process_images_with_progress
+from pydantic import BaseModel, ValidationError
+from typing import List, Dict
+import threading
+import torch
+import uuid
+from queue import Queue
+

 # set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -237,83 +244,137 @@ async def yolo_detect_api(request):
             "message": f"Internal server error: {str(e)}"
         }, status=500)

-# YOLO auto_train API
+# -------------------------------------------------- YOLO training APIs --------------------------------------------------
+# create the blueprint for YOLO training
+
+MAX_CONCURRENT_JOBS = torch.cuda.device_count() if torch.cuda.is_available() else 1
+tasks: Dict[str, Dict] = {}
+task_queue = Queue()
+active_jobs: List[str] = []
+lock = threading.Lock()
+
+
+# ------------------ Request model ------------------
+class TrainRequest(BaseModel):
+    config_name: str
+    table_name: str
+    column_name: str
+    search_condition: str
+    aim_path: str
+    image_dir: str
+    label_dir: str
+    output_path: str
+    pt_path: str
+    imgsz: int
+    epochs: int
+    device: List[int]
+    hsv_v: float
+    cos_lr: bool
+    batch: int
+    project_dir: str
+    class_names: List[str]
+
+
+# ------------------ Core worker ------------------
+def run_training(task_id: str, params: TrainRequest):
+    try:
+        with lock:
+            active_jobs.append(task_id)
+            tasks[task_id]["status"] = "running"
+
+        train_main(
+            config_name=params.config_name,
+            table_name=params.table_name,
+            column_name=params.column_name,
+            search_condition=params.search_condition,
+            aim_path=params.aim_path,
+            image_dir=params.image_dir,
+            label_dir=params.label_dir,
+            output_path=params.output_path,
+            pt_path=params.pt_path,
+            imgsz=params.imgsz,
+            epochs=params.epochs,
+            device=params.device,
+            hsv_v=params.hsv_v,
+            cos_lr=params.cos_lr,
+            batch=params.batch,
+            project_dir=params.project_dir,
+            class_names=params.class_names
+        )
+
+        tasks[task_id]["status"] = "finished"
+    except Exception as e:
+        tasks[task_id]["status"] = "failed"
+        tasks[task_id]["error"] = str(e)
+    finally:
+        with lock:
+            if task_id in active_jobs:
+                active_jobs.remove(task_id)
+        schedule_next_job()
+
+
+# ------------------ Scheduler ------------------
+def schedule_next_job():
+    with lock:
+        while len(active_jobs) < MAX_CONCURRENT_JOBS and not task_queue.empty():
+            next_id = task_queue.get()
+            params = tasks[next_id]["params"]
+            t = threading.Thread(target=run_training, args=(next_id, params), daemon=True)
+            t.start()
+
+
+# ------------------ Endpoints ------------------
 @yolo_tile_blueprint.post("/train")
-async def yolo_train_api(request):
-    """
-    auto_train
-    input JSON:
-    {
-        "db_host": str,
-        "db_database": str,
-        "db_user": str,
-        "db_password": str,
-        "db_port": int,
-        "model_id": int,
-        "img_path": str,
-        "label_path": str,
-        "new_path": str,
-        "split_list": List[float],
-        "class_names": Optional[List[str]],
-        "project_name": str
-    }
-    output JSON:
-    return {
-        "status": "success",
-        "message": "Train finished",
-        "project_name": project_name,
-        "label_count": label_count,
-        "base_metrics": base_metrics,
-        "final_metrics": final_metrics
-    }
-    """
+async def submit_train_job(request):
     try:
         data = request.json
-        if not data:
-            return json_response({"status": "error", "message": "data is required"}, status=400)
-        # Do the training in a separate thread to avoid blocking the event loop
-        result = await asyncio.to_thread(
-            auto_train,
-            data
-        )
-        # return the result as JSON response
-        return json_response(result)
-
-    except Exception as e:
-        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
-        return json_response({
-            "status": "error",
-            "message": f"Internal server error: {str(e)}"
-        }, status=500)
-
-# access the training progress
-@yolo_tile_blueprint.get("/progress/<project_name>")
-async def yolo_train_progress(request, project_name):
-    '''
-    input:
-        to query the latest progress: GET /yolo/progress/my_project
-        to query the progress at a specific time: GET /yolo/progress/my_project?run_time=20250902_1012
-    output JSON:
-    {
-        "status": "ok",
-        "run_time": "20250902_1012",
-        "progress": {
-            "epoch": 12,
-            "precision": 0.72,
-            "recall": 0.64,
-            "mAP50": 0.68,
-            "mAP50-95": 0.42
-        }
-    }
-    '''
-    run_time = request.args.get("run_time")  # get the run_time from the query string
-    # query the progress from the database
-    if not run_time:
-        run_time = None  # if not provided, query the latest progress
-
-    result = await asyncio.to_thread(query_progress, project_name, run_time)
-    return json_response(result)
+        params = TrainRequest(**data)
+    except ValidationError as e:
+        return json({"success": False, "error": e.errors()})
+
+    task_id = str(uuid.uuid4())
+    tasks[task_id] = {"status": "queued", "params": params}
+
+    with lock:
+        if len(active_jobs) < MAX_CONCURRENT_JOBS:
+            t = threading.Thread(target=run_training, args=(task_id, params), daemon=True)
+            t.start()
+        else:
+            task_queue.put(task_id)
+            tasks[task_id]["status"] = "waiting"
+
+    return json({"success": True, "task_id": task_id, "message": "Task submitted"})
+
+
+@yolo_tile_blueprint.get("/task_status/<task_id>")
+async def task_status(request, task_id: str):
+    if task_id not in tasks:
+        return json({"success": False, "message": "Unknown task ID"})
+
+    task_info = tasks[task_id]
+    return json({
+        "success": True,
+        "status": task_info["status"],
+        "error": task_info.get("error", None)
+    })
+
+
+@yolo_tile_blueprint.get("/tasks")
+async def all_tasks(request):
+    return json({
+        tid: {"status": info["status"]}
+        for tid, info in tasks.items()
+    })
+
+
+@yolo_tile_blueprint.get("/system_status")
+async def system_status(request):
+    gpu_available = torch.cuda.is_available()
+    return json({
+        "gpu_available": gpu_available,
+        "max_concurrent": MAX_CONCURRENT_JOBS,
+        "running_jobs": len(active_jobs),
+        "waiting_jobs": task_queue.qsize(),
+        "active_task_ids": active_jobs
+    })

 if __name__ == '__main__':
     app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
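A minimal client sketch for the new queueing endpoints, assuming the server above runs on localhost:12366 and yolo_tile_blueprint is mounted under /yolo (neither the host nor the URL prefix appears in this hunk, and the payload values are placeholders):

# Hypothetical client for the /train and /task_status endpoints above.
import time
import requests

BASE = "http://localhost:12366/yolo"   # assumed host and blueprint prefix

payload = {                            # one value per TrainRequest field
    "config_name": "config", "table_name": "aidataset",
    "column_name": "image_url", "search_condition": "your_search_id",
    "aim_path": "./datasets/demo", "image_dir": "./dataset/images",
    "label_dir": "./dataset/labels", "output_path": "./my_dataset",
    "pt_path": "custom_model.pt", "imgsz": 800, "epochs": 500,
    "device": [0], "hsv_v": 0.3, "cos_lr": True, "batch": 8,
    "project_dir": "./my_train_runs", "class_names": ["person", "car"],
}

task_id = requests.post(f"{BASE}/train", json=payload).json()["task_id"]

# Poll until the worker thread reports a terminal state.
while True:
    status = requests.get(f"{BASE}/task_status/{task_id}").json()["status"]
    if status in ("finished", "failed"):
        break
    time.sleep(30)
print("final status:", status)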
@@ -8,6 +8,6 @@ minio:
 sql:
   host: '222.212.85.86'
   port: 5432
-  dbname: 'postgres'
+  dbname: 'smart_dev'
   user: 'postgres'
   password: 'root'
87  Ai_tottle/train/broken.py  (new file)
@@ -0,0 +1,87 @@
+import os
+import shutil
+import random
+from tqdm import tqdm
+import yaml
+
+
+def split_img(img_path, label_path, split_list, output_path, class_names=[
+        'people',
+        'car',
+        'truck',
+        'bicycle',
+        'tricycle',
+        'ship']):
+    try:
+        # create the target directory structure
+        for sub in ['images/train', 'images/val', 'images/test',
+                    'labels/train', 'labels/val', 'labels/test']:
+            os.makedirs(os.path.join(output_path, sub), exist_ok=True)
+    except Exception as e:
+        print(f'❌ Failed to create directory structure: {e}')
+        return
+
+    train, val, test = split_list
+    all_imgs = [f for f in os.listdir(img_path) if f.endswith(('.jpg', '.png'))]
+    all_img_paths = [os.path.join(img_path, f) for f in all_imgs]
+
+    # assign the training set
+    train_imgs = random.sample(all_img_paths, int(train * len(all_img_paths)))
+    move_set(train_imgs, label_path, os.path.join(output_path, 'images/train'), os.path.join(output_path, 'labels/train'))
+    for f in train_imgs: all_img_paths.remove(f)
+
+    # assign the validation set
+    val_imgs = random.sample(all_img_paths, int(val / (val + test) * len(all_img_paths)))
+    move_set(val_imgs, label_path, os.path.join(output_path, 'images/val'), os.path.join(output_path, 'labels/val'))
+    for f in val_imgs: all_img_paths.remove(f)
+
+    # the remainder becomes the test set
+    test_imgs = all_img_paths
+    move_set(test_imgs, label_path, os.path.join(output_path, 'images/test'), os.path.join(output_path, 'labels/test'))
+
+    # generate dataset.yaml
+    generate_yaml(output_path, class_names)
+
+
+def move_set(img_list, label_root, dst_img_dir, dst_label_dir):
+    for img_path in tqdm(img_list, desc=f'Copying to {os.path.basename(dst_img_dir)}', ncols=80):
+        base = os.path.splitext(os.path.basename(img_path))[0]
+        label_path = os.path.join(label_root, base + '.txt')
+
+        shutil.copy(img_path, os.path.join(dst_img_dir, os.path.basename(img_path)))
+        if os.path.exists(label_path):
+            shutil.copy(label_path, os.path.join(dst_label_dir, base + '.txt'))
+
+
+def generate_yaml(dataset_root, class_names):
+    yaml_content = {
+        'train': os.path.join('images/train'),
+        'val': os.path.join('images/val'),
+        'test': os.path.join('images/test'),
+        'nc': len(class_names),
+        'names': class_names
+    }
+    with open(os.path.join(dataset_root, 'dataset.yaml'), 'w') as f:
+        yaml.dump(yaml_content, f, default_flow_style=False)
+    print(f"✅ Generated YAML: {os.path.join(dataset_root, 'dataset.yaml')}")
+
+
+def broken_main(aim_path, output_path, class_names=[
+        'people',
+        'car',
+        'truck',
+        'bicycle',
+        'tricycle',
+        'ship']):
+    img_path = os.path.join(aim_path, 'images')
+    label_path = os.path.join(aim_path, 'labels')
+    split_ratio = [0.7, 0.2, 0.1]
+    split_img(img_path, label_path, split_ratio, output_path, class_names)
+
+
+if __name__ == '__main__':
+    broken_main(
+        r"D:\Users\76118\Downloads\stanford_campus_dataset\filtered",
+        r"D:\work\develop\AI\数据集\output",
+        class_names=[
+            'people',
+            'car',
+            'truck',
+            'bicycle',
+            'tricycle',]
+    )
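A quick sanity check of the split arithmetic in split_img above (a minimal sketch; the dataset size of 100 is hypothetical): train is sampled first, val is drawn from the remainder at a val/(val+test) rate, and whatever is left becomes test.

# Illustrative check of split_img's ratios with a hypothetical 100-image dataset.
split_list = [0.7, 0.2, 0.1]
n = 100
train, val, test = split_list
n_train = int(train * n)                  # 70 images sampled first
n_rest = n - n_train                      # 30 remain
n_val = int(val / (val + test) * n_rest)  # 0.2/0.3 * 30 = 20
n_test = n_rest - n_val                   # the 10 left over
print(n_train, n_val, n_test)             # -> 70 20 10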
85  Ai_tottle/train/let_txt_to_true.py  (new file)
@@ -0,0 +1,85 @@
+import os
+import stat
+import math
+
+
+def make_writable(file_path):
+    os.chmod(file_path, stat.S_IWRITE)
+
+
+def process_files_in_folder(folder_path):
+    for root, _, files in os.walk(folder_path):
+        for file_name in files:
+            if file_name.endswith(".txt"):
+                file_path = os.path.join(root, file_name)
+
+                # make sure the file is writable
+                make_writable(file_path)
+
+                # read the file contents for processing
+                with open(file_path, "r") as file:
+                    lines = file.readlines()
+
+                processed_lines = []
+                for line in lines:
+                    numbers = line.split()
+                    processed_numbers = []
+
+                    # keep the first column as an integer class id; do not turn it into a float
+                    if (
+                        numbers[0] == "0"
+                        or numbers[0] == "1"
+                        or numbers[0] == "2"
+                        or numbers[0] == "3"
+                        or numbers[0] == "4"
+                        or numbers[0] == "5"
+                        or numbers[0] == "6"
+                        or numbers[0] == "7"
+                        or numbers[0] == "8"
+                        or numbers[0] == "9"
+                        or numbers[0] == "10"
+                        or numbers[0] == "11"
+                        or numbers[0] == "12"
+                        or numbers[0] == "13"
+                        or numbers[0] == "14"
+                        or numbers[0] == "15"
+                        or numbers[0] == "16"
+                        or numbers[0] == "17"
+                        or numbers[0] == "18"
+                    ):
+                        processed_numbers.append(numbers[0])
+                    else:
+                        print(f"Unexpected value in first column: {numbers[0]}")
+                        continue
+
+                    # process the remaining columns: keep the original format,
+                    # turn negatives positive, and drop lines containing NaN
+                    skip_line = False  # marks whether this line should be skipped
+                    for number in numbers[1:]:
+                        try:
+                            number = float(number)
+                            if math.isnan(number):  # check for NaN
+                                skip_line = True
+                                print(
+                                    f"NaN detected in file: {file_path}, line: {line}"
+                                )
+                                break
+                            if number < 0:
+                                number = abs(number)  # turn the negative into a positive
+                            processed_numbers.append(str(number))  # keep the original format
+                        except ValueError:
+                            processed_numbers.append(number)  # leave non-numeric columns as-is
+
+                    # keep the line only if it contained no NaN
+                    if not skip_line:
+                        processed_line = " ".join(processed_numbers)
+                        processed_lines.append(processed_line)
+
+                # write the processed content back to the file
+                with open(file_path, "w") as file:
+                    file.write("\n".join(processed_lines))
+                print(f"Finished processing: {file_path}")
+
+
+# folder to process
+folder_path = r"G:\dataset\PCS\before\labels"
+# run the function
+process_files_in_folder(folder_path)
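The 19-branch or-chain above only checks that the first column is an integer class id from 0 to 18. A minimal equivalent sketch, not part of the commit (is_valid_class_id is a hypothetical helper):

# Equivalent range check for the first-column class id in a YOLO label line.
def is_valid_class_id(token: str, max_id: int = 18) -> bool:
    return token.isdigit() and int(token) <= max_id

assert is_valid_class_id("0") and is_valid_class_id("18")
assert not is_valid_class_id("19")
assert not is_valid_class_id("-1")   # "-1" fails isdigit()
assert not is_valid_class_id("0.5")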
25  Ai_tottle/train/train.py  (new file)
@@ -0,0 +1,25 @@
+from ultralytics import YOLO
+import torch
+
+
+# check whether CUDA is available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# load the model
+model = YOLO("runs/detect/train6/weights/last.pt").to(device)
+
+# set the new input resolution
+imgsz = 1024  # input image size (adjust to fit your GPU memory)
+
+# train the model, passing in augmentation parameters
+model.train(
+    data="dataset/dataset.yaml",  # your dataset config file
+    epochs=1000,                  # training epochs
+    imgsz=imgsz,                  # use the higher resolution
+    device=[1],                   # use GPU index 1 (adjust if you have several GPUs)
+    hsv_v=0.3,                    # brightness augmentation; helps under varied lighting
+    cos_lr=True,                  # enable the cosine learning-rate schedule
+    batch=-1,                     # auto-tune batch size to fit GPU memory
+)
@@ -1,394 +1,166 @@
-""""
-main()
-   |
-   v
-setup_logger(project)
-   |
-   v
-get_last_model_path(project)
-   |
-   v
-+-----------------------------+
-| last.pt exists | no last.pt |
-+-----------------------------+
-   |                |
-   v                v
-load_last_model()   start_new_training()
-   |                |
-   +-------+--------+
-           |
-           v
-check_dataset(root)
-   |
-   v
-split_dataset(root, ratios)
-   |
-   v
-clean_labels(root)
-   |
-   v
-generate_yaml(dataset_dir)
-   |
-   v
-train_yolo(model, data_yaml)
-   |
-   v
-save last.pt
-   |
-   v
-logger.info("Saved last model path")
-   |
-   v
-write logs/{project}.log
-"""
-
 import os
-import shutil
-import datetime
-import torch
-from ultralytics import YOLO
-import random
-import math
-import stat
-import yaml
-import psycopg2
-from psycopg2 import OperationalError
-from collections import Counter
-import pandas as pd
-import logging
-from tqdm import tqdm
-import miniohelp as miniohelp
+from glob import glob
 from aboutdataset.download_oss import download_and_save_images_from_oss
+from train.let_txt_to_true import process_files_in_folder
+from train.broken import broken_main
+from ultralytics import YOLO
+import torch

-######################################## Logging ########################################
-def setup_logger(project: str):
-    os.makedirs("logs", exist_ok=True)
-    log_file = os.path.join("logs", f"{project}.log")
-
-    logger = logging.getLogger(project)
-    if not logger.handlers:
-        logger.setLevel(logging.INFO)
-        formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
-
-        fh = logging.FileHandler(log_file, encoding="utf-8")
-        fh.setFormatter(formatter)
-        sh = logging.StreamHandler()
-        sh.setFormatter(formatter)
-
-        logger.addHandler(fh)
-        logger.addHandler(sh)
-
-    return logger
-
-def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"):
-    """
-    Parse the last.pt path of the previous training run from the log.
-    Return default_model if none is found.
-    default_model may be a .pt or a .yaml file.
-    """
-    log_file = os.path.join("logs", f"{project}.log")
-    if not os.path.exists(log_file):
-        return default_model
-
-    with open(log_file, "r", encoding="utf-8") as f:
-        lines = f.readlines()
-
-    for line in reversed(lines):
-        if "Saved last model path:" in line:
-            path = line.strip().split("Saved last model path:")[-1].strip()
-            if os.path.exists(path):
-                return path
-
-    return default_model
-
-def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, table_name):
-    """
-    yaml_name='config'
-    where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'"
-    image_dir='images'
-    label_dir='labels'
-    table_name = 'aidataset'
-
-    Returns:
-        (image_dir, label_dir)
-    """
-    download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name)
-    return image_dir, label_dir
-
-####################################### Utility functions #######################################
-def count_labels_by_class(label_dir):
-    class_counter = Counter()
-    for file in os.listdir(label_dir):
-        if file.endswith('.txt'):
-            with open(os.path.join(label_dir, file), 'r') as f:
-                for line in f:
-                    class_id = line.strip().split()[0]
-                    class_counter[class_id] += 1
-    return dict(class_counter)
-
-def evaluate_model_per_class(model_path, dataset_yaml, class_names):
-    model = YOLO(model_path)
-    metrics = model.val(data=dataset_yaml, split='val')
-    class_ids = range(len(metrics.box.p))
-    results = {}
-    for i in class_ids:
-        name = class_names.get(str(i), str(i))
-        results[name] = {
-            "precision": float(metrics.box.p[i]),
-            "recall": float(metrics.box.r[i]),
-            "mAP50": float(metrics.box.map50[i]),
-            "mAP50_95": float(metrics.box.map[i])
-        }
-    return results
-
-def link_database(db_database, db_user, db_password, db_host, db_port, search_query):
-    try:
-        with psycopg2.connect(
-            database=db_database,
-            user=db_user,
-            password=db_password,
-            host=db_host,
-            port=db_port
-        ) as conn:
-            with conn.cursor() as cur:
-                cur.execute(search_query)
-                records = cur.fetchall()
-                return records
-    except OperationalError as e:
-        print(f"Error while connecting to or querying the database: {e}")
-    except Exception as e:
-        print(f"Some other error occurred: {e}")
-
-def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger):
-    search_query = f"SELECT * FROM aidataset WHERE model = '{model}';"
-    records = link_database(db_database, db_user, db_password, db_host, db_port, search_query)
-    if not records:
-        logger.warning("Query returned no data.")
-        return
-
-    os.makedirs('./dataset/images', exist_ok=True)
-    os.makedirs('./dataset/labels', exist_ok=True)
-
-    for r in records:
-        img_path = r[4]
-        label_content = r[5]
-
-        local_img_name = img_path.split('/')[-1]
-        local_img_path = os.path.join('./dataset/images', local_img_name)
-        miniohelp.downFile(img_path, local_img_path)
-
-        txt_name = os.path.splitext(local_img_name)[0] + '.txt'
-        txt_path = os.path.join('./dataset/labels', txt_name)
-        with open(txt_path, 'w', encoding='utf-8') as f:
-            f.write(label_content + '\n')
-
-    logger.info("Data download finished")
-
-def make_writable(file_path):
-    os.chmod(file_path, stat.S_IWRITE)
-
-def process_files_in_folder(folder_path, logger):
-    for root, _, files in os.walk(folder_path):
-        for file_name in files:
-            if file_name.endswith('.txt'):
-                file_path = os.path.join(root, file_name)
-                make_writable(file_path)
-
-                with open(file_path, 'r') as file:
-                    lines = file.readlines()
-
-                processed_lines = []
-                for line in lines:
-                    numbers = line.split()
-                    processed_numbers = []
-                    if numbers[0].isdigit():
-                        processed_numbers.append(numbers[0])
-                    else:
-                        logger.warning(f"Unexpected value in first column: {numbers[0]}")
-                        continue
-
-                    skip_line = False
-                    for number in numbers[1:]:
-                        try:
-                            number = float(number)
-                            if math.isnan(number):
-                                skip_line = True
-                                logger.warning(f"NaN detected in {file_path}: {line}")
-                                break
-                            if number < 0:
-                                number = abs(number)
-                            processed_numbers.append(str(number))
-                        except ValueError:
-                            processed_numbers.append(number)
-
-                    if not skip_line:
-                        processed_line = ' '.join(processed_numbers)
-                        processed_lines.append(processed_line)
-
-                with open(file_path, 'w') as file:
-                    file.write('\n'.join(processed_lines))
-                logger.info(f"Processed {file_path}")
-
-def split_img(img_path, label_path, split_list, new_path, class_names, logger):
-    try:
-        Data = os.path.abspath(new_path)
-        os.makedirs(Data, exist_ok=True)
-        dirs = ['train/images','val/images','test/images','train/labels','val/labels','test/labels']
-        for d in dirs: os.makedirs(os.path.join(Data, d), exist_ok=True)
-    except Exception as e:
-        logger.error(f'Failed to create directory structure: {e}')
-        return
-
-    train, val, test = split_list
-    all_img = os.listdir(img_path)
-    all_img_path = [os.path.join(img_path, img) for img in all_img]
-
-    train_img = random.sample(all_img_path, int(train * len(all_img_path)))
-    train_label = [toLabelPath(img, label_path) for img in train_img]
-    for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'):
-        _copy(train_img[i], os.path.join(Data,'train/images'))
-        _copy(train_label[i], os.path.join(Data,'train/labels'))
-        all_img_path.remove(train_img[i])
-
-    val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path)))
-    val_label = [toLabelPath(img, label_path) for img in val_img]
-    for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'):
-        _copy(val_img[i], os.path.join(Data,'val/images'))
-        _copy(val_label[i], os.path.join(Data,'val/labels'))
-        all_img_path.remove(val_img[i])
-
-    test_img = all_img_path
-    test_label = [toLabelPath(img, label_path) for img in test_img]
-    for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'):
-        _copy(test_img[i], os.path.join(Data,'test/images'))
-        _copy(test_label[i], os.path.join(Data,'test/labels'))
-
-    generate_dataset_yaml(
-        save_path=os.path.join(Data, 'dataset.yaml'),
-        train_path=os.path.join(Data,'train/images'),
-        val_path=os.path.join(Data,'val/images'),
-        test_path=os.path.join(Data,'test/images'),
-        class_names=class_names
-    )
-    logger.info("Dataset split finished")
-
-def _copy(from_path, to_path):
-    try:
-        shutil.copy(from_path, to_path)
-    except Exception as e:
-        print(f"Error while copying file: {e}")
-
-def toLabelPath(img_path, label_path):
-    img = os.path.basename(img_path)
-    label = img.replace('.jpg', '.txt')
-    return os.path.join(label_path, label)
-
-def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names):
-    dataset_yaml = {
-        'train': train_path.replace('\\', '/'),
-        'val': val_path.replace('\\', '/'),
-        'test': test_path.replace('\\', '/'),
-        'nc': len(class_names),
-        'names': list(class_names.values())
-    }
-    with open(save_path, 'w', encoding='utf-8') as f:
-        yaml.dump(dataset_yaml, f, allow_unicode=True)
-
-def delete_folder(folder_path, logger):
-    if os.path.exists(folder_path):
-        shutil.rmtree(folder_path)
-        logger.info(f"Deleted folder: {folder_path}")
-
-####################################### Training #######################################
-def train(project_name, yaml_path, default_model_path, logger):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logger.info(f"Using device: {device}")
-
-    model_path = get_last_model_from_log(project_name, default_model_path)
-    logger.info(f"Loading model: {model_path}")
-    model = YOLO(model_path).to(device)
-
-    current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
-    model.train(
-        data=yaml_path,
-        epochs=200,
-        pretrained=True,
-        patience=50,
-        imgsz=640,
-        device=[0],
-        workers=0,
-        project=project_name,
-        name=current_date,
-    )
-
-    trained_model_path = os.path.join('runs', 'detect', project_name, current_date, 'weights', 'last.pt')
-    if os.path.exists(trained_model_path):
-        logger.info(f"Saved last model path: {trained_model_path}")
-
-####################################### Auto training #######################################
-def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
-               img_path='./dataset/images', label_path='./dataset/labels',
-               new_path='./datasets', split_list=[0.7, 0.2, 0.1],
-               class_names=None, project_name='default_project'):
-    if class_names is None:
-        class_names = {}
-
-    logger = setup_logger(project_name)
-
-    delete_folder('dataset', logger)
-    delete_folder('datasets', logger)
-
-    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id, logger)
-    process_files_in_folder(img_path, logger)
-
-    label_count = count_labels_by_class(label_path)
-    logger.info(f"Label counts: {label_count}")
-
-    split_img(img_path, label_path, split_list, new_path, class_names, logger)
-
-    base_metrics = evaluate_model_per_class('yolo11n.pt', './datasets/dataset.yaml', class_names)
-    logger.info(f"Baseline evaluation before training: {base_metrics}")
-
-    delete_folder('dataset', logger)
-
-    train(project_name, './datasets/dataset.yaml', 'yolo11n.pt', logger)
-
-    logger.info("Training pipeline finished")
-
-def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id, image_dir, label_dir, yaml_name, where_clause, table_name):
-
-    imag_path, label_path = get_img_and_label_paths(yaml_name, where_clause, image_dir, label_dir, table_name)
-    auto_train(
-        db_host=db_host,
-        db_database=db_database,
-        db_user=db_user,
-        db_password=db_password,
-        db_port=db_port,
-        model_id=model_id,
-        imag_path=imag_path,  # fixed here: pass imag_path as a keyword argument
-        label_path=label_path,  # fixed here: pass label_path as a keyword argument
-        new_path='./datasets',
-        split_list=[0.7, 0.2, 0.1],
-        class_names={'0': 'human', '1': 'car'},
-        project_name='my_project'
-    )
-
-####################################### Entry point #######################################
-if __name__ == '__main__':
-    down_and_train(
-        db_host='222.212.85.86',
-        db_database='your_database_name',
-        db_user='postgres',
-        db_password='postgres',
-        db_port='5432',
-        model_id='best.pt',
-        img_path='./dataset/images',   # image path before shuffling
-        label_path='./dataset/labels', # label path before shuffling
-        new_path='./datasets',         # path after shuffling
-        split_list=[0.7, 0.2, 0.1],
-        class_names={'0': 'human', '1': 'car'},
-        project_name='my_project'
-    )
+# ------------------ Download images and labels ------------------
+def download_images_and_labels(
+    config_name,        # OSS config file name, used to read connection info
+    table_name,         # OSS table holding the data to download
+    column_name,        # column in the OSS table containing image URLs
+    search_condition,   # filter condition for the OSS query
+    aim_path,           # local dataset root directory
+    image_dir,          # local directory for downloaded images
+    label_dir           # local directory for downloaded label txt files
+):
+    os.makedirs(aim_path, exist_ok=True)
+    os.makedirs(image_dir, exist_ok=True)
+    os.makedirs(label_dir, exist_ok=True)
+
+    download_and_save_images_from_oss(
+        yaml_name=config_name,
+        where_clause=f"{column_name} = '{search_condition}'",
+        image_dir=image_dir,
+        label_dir=label_dir,
+        table_name=table_name,
+    )
+    return aim_path, image_dir, label_dir
+
+
+# ------------------ Label fixing and dataset shuffling ------------------
+def broken_and_convert_txt_to_yolo_format(
+    aim_path,     # dataset root directory
+    output_path,  # output directory for the shuffled dataset
+    image_dir,    # image directory
+    label_dir,    # label directory
+    class_names   # list of dataset class names
+):
+    process_files_in_folder(label_dir)               # fix labels into YOLO format
+    broken_main(aim_path, output_path, class_names)  # shuffle the dataset and generate dataset.yaml
+    yaml_path = os.path.join(output_path, 'dataset.yaml')
+    return output_path, yaml_path
+
+
+# ------------------ Get the latest .pt model ------------------
+def get_latest_pt(project_dir, pt_path):
+    """
+    Check the given training output directory for the newest .pt model file.
+    Return its path if one exists; otherwise return the given pt_path.
+    """
+    if not os.path.exists(project_dir):
+        print(f"[INFO] Project dir {project_dir} does not exist; using given model {pt_path}")
+        return pt_path
+
+    pt_files = glob(os.path.join(project_dir, "*.pt"))
+    if not pt_files:
+        print(f"[INFO] No .pt files in the directory; using given model {pt_path}")
+        return pt_path
+
+    latest_pt = max(pt_files, key=os.path.getmtime)
+    print(f"[INFO] Latest model found: {latest_pt}")
+    return latest_pt
+
+
+# ------------------ Training ------------------
+def train(
+    yaml_path,    # path to the YOLO dataset config file
+    pt_path,      # initial .pt weights to train from
+    imgsz,        # input image resolution
+    epochs,       # training epochs
+    device,       # list of GPU indices, e.g. [0] or [0,1]
+    hsv_v,        # brightness augmentation factor
+    cos_lr,       # whether to use a cosine learning-rate schedule
+    batch,        # batch size
+    project_dir   # training output directory (weights, logs, etc.)
+):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"[INFO] Using device: {device}")
+
+    pt_path = get_latest_pt(project_dir, pt_path)  # auto-detect the latest .pt file
+    model = YOLO(pt_path).to(device)
+
+    model.train(
+        data=yaml_path,
+        epochs=epochs,
+        imgsz=imgsz,
+        device=device,
+        hsv_v=hsv_v,
+        cos_lr=cos_lr,
+        batch=batch,
+        project=project_dir,
+    )
+
+
+# ------------------ Main pipeline ------------------
+def train_main(
+    # OSS download parameters
+    config_name,       # sql config file name
+    table_name,        # sql table name
+    column_name,       # column name in the sql table
+    search_condition,  # sql filter condition
+    # dataset paths
+    aim_path,          # local dataset root directory
+    image_dir,         # local image directory
+    label_dir,         # local label directory
+    output_path,       # output directory for the shuffled dataset
+    # YOLO training parameters
+    pt_path,           # initial weights file
+    imgsz,             # input image resolution
+    epochs,            # training epochs
+    device,            # list of GPU indices
+    hsv_v,             # brightness augmentation factor
+    cos_lr,            # whether to use a cosine learning-rate schedule
+    batch,             # batch size
+    project_dir,       # training output directory
+    # classes
+    class_names        # list of dataset class names
+):
+    aim_path, image_dir, label_dir = download_images_and_labels(
+        config_name, table_name, column_name, search_condition,
+        aim_path, image_dir, label_dir
+    )
+
+    output_path, yaml_path = broken_and_convert_txt_to_yolo_format(
+        aim_path, output_path, image_dir, label_dir, class_names
+    )
+
+    train(
+        yaml_path=yaml_path,
+        pt_path=pt_path,
+        imgsz=imgsz,
+        epochs=epochs,
+        device=device,
+        hsv_v=hsv_v,
+        cos_lr=cos_lr,
+        batch=batch,
+        project_dir=project_dir
+    )
+
+
+# ------------------ Run ------------------
+if __name__ == "__main__":
+    train_main(
+        config_name="config",
+        table_name="aidataset",
+        column_name="image_url",
+        search_condition="your_search_id",
+        aim_path="./datasets/aidataset_dataset",
+        image_dir="./dataset/aidataset_dataset_images",
+        label_dir="./dataset/aidataset_dataset_labels",
+        output_path="./my_dataset",
+        pt_path="custom_model.pt",
+        imgsz=800,
+        epochs=500,
+        device=[0],
+        hsv_v=0.3,
+        cos_lr=True,
+        batch=8,
+        project_dir="./my_train_runs",
+        class_names=['person','car']
+    )