From 04c0b96f33f7338853be6f6fbd2556b5801c9d72 Mon Sep 17 00:00:00 2001
From: yooooger <761181201@qq.com>
Date: Tue, 11 Nov 2025 09:43:25 +0800
Subject: [PATCH] yoooooger

---
 .../__pycache__/miniohelp.cpython-312.pyc     | Bin 8388 -> 8450 bytes
 .../__pycache__/download_oss.cpython-312.pyc  | Bin 0 -> 5302 bytes
 Ai_tottle/ai_tottle_api.py                    | 211 ++++---
 Ai_tottle/config.yaml                         |   2 +-
 Ai_tottle/train/broken.py                     |  87 +++
 Ai_tottle/train/let_txt_to_true.py            |  85 +++
 Ai_tottle/train/train.py                      |  25 +
 Ai_tottle/yolo_train.py                       | 524 +++++-------------
 8 files changed, 482 insertions(+), 452 deletions(-)
 create mode 100644 Ai_tottle/aboutdataset/__pycache__/download_oss.cpython-312.pyc
 create mode 100644 Ai_tottle/train/broken.py
 create mode 100644 Ai_tottle/train/let_txt_to_true.py
 create mode 100644 Ai_tottle/train/train.py

diff --git a/Ai_tottle/__pycache__/miniohelp.cpython-312.pyc b/Ai_tottle/__pycache__/miniohelp.cpython-312.pyc
index ebdfd25231a2b54ebf13ee203dd9288210e3d3d3..87e13d05cdcd76d9157be71c7d529cc4358e40ab 100644
GIT binary patch
[base85 binary payload omitted]
diff --git a/Ai_tottle/aboutdataset/__pycache__/download_oss.cpython-312.pyc b/Ai_tottle/aboutdataset/__pycache__/download_oss.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17bdff64eb5199b044751c9fd16b47ea152b97a1
GIT binary patch
[base85 binary payload omitted]

diff --git a/Ai_tottle/ai_tottle_api.py b/Ai_tottle/ai_tottle_api.py
index 583f158..9fa29de 100644
--- a/Ai_tottle/ai_tottle_api.py
+++ b/Ai_tottle/ai_tottle_api.py
@@ -7,8 +7,15 @@ from sanic_cors import CORS
 # local imports
 from ai_image import process_images
 from map_find import map_process_images
-from yolo_train import auto_train,query_progress
+from yolo_train import train_main
 from yolo_photo import map_process_images_with_progress
+from pydantic import BaseModel, ValidationError
+from typing import List, Dict
+import threading
+import torch
+import uuid
+from queue import Queue
+
 
 # set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -27,7 +34,7 @@ async def token_and_resource_check(request):
     # --- GPU usage check ---
     try:
         if torch.cuda.is_available():
-            num_gpus = torch.cuda.device_count() 
+            num_gpus = torch.cuda.device_count()
             max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default: 90%
 
             for i in range(num_gpus):
@@ -237,83 +244,137 @@ async def yolo_detect_api(request):
             "message": f"Internal server error: {str(e)}"
         }, status=500)
 
-# YOLO auto_train API
+# ------------------------------ YOLO training APIs ------------------------------
+# Task registry and scheduling state for queued training jobs
+
+MAX_CONCURRENT_JOBS = torch.cuda.device_count() if torch.cuda.is_available() else 1
+tasks: Dict[str, Dict] = {}
+task_queue = Queue()
+active_jobs: List[str] = []
+lock = threading.Lock()
+
+
+# ------------------ Request model ------------------
+class TrainRequest(BaseModel):
+    config_name: str
+    table_name: str
+    column_name: str
+    search_condition: str
+    aim_path: str
+    image_dir: str
+    label_dir: str
+    output_path: str
+    pt_path: str
+    imgsz: int
+    epochs: int
+    device: List[int]
+    hsv_v: float
+    cos_lr: bool
+    batch: int
+    project_dir: str
+    class_names: List[str]
+
+
+# ------------------ Worker ------------------
+def run_training(task_id: str, params: TrainRequest):
+    try:
+        with lock:
+            if task_id not in active_jobs:  # the submitter/scheduler may have reserved the slot already
+                active_jobs.append(task_id)
+            tasks[task_id]["status"] = "running"
+
+        train_main(
+            config_name=params.config_name,
+            table_name=params.table_name,
+            column_name=params.column_name,
+            search_condition=params.search_condition,
+            aim_path=params.aim_path,
+            image_dir=params.image_dir,
+            label_dir=params.label_dir,
+            output_path=params.output_path,
+            pt_path=params.pt_path,
+            imgsz=params.imgsz,
+            epochs=params.epochs,
+            device=params.device,
+            hsv_v=params.hsv_v,
+            cos_lr=params.cos_lr,
+            batch=params.batch,
+            project_dir=params.project_dir,
+            class_names=params.class_names
+        )
+
+        tasks[task_id]["status"] = "finished"
+    except Exception as e:
+        tasks[task_id]["status"] = "failed"
+        tasks[task_id]["error"] = str(e)
+    finally:
+        with lock:
+            if task_id in active_jobs:
+                active_jobs.remove(task_id)
+        schedule_next_job()  # must run outside the lock: threading.Lock is not re-entrant
+
+
+# ------------------ Scheduler ------------------
+def schedule_next_job():
+    with lock:
+        while len(active_jobs) < MAX_CONCURRENT_JOBS and not task_queue.empty():
+            next_id = task_queue.get()
+            params = tasks[next_id]["params"]
+            active_jobs.append(next_id)  # reserve the slot before the worker thread starts
+            t = threading.Thread(target=run_training, args=(next_id, params), daemon=True)
+            t.start()
+
+
+# ------------------ Endpoints ------------------
 @yolo_tile_blueprint.post("/train")
-async def yolo_train_api(request):
-    """
-    auto_train
-    input JSON:
-    {
-        "db_host": str,
-        "db_database": str,
-        "db_user": str,
-        "db_password": str,
-        "db_port": int,
-        "model_id": int,
-        "img_path": str,
-        "label_path": str,
-        "new_path": str,
-        "split_list": List[float],
-        "class_names": Optional[List[str]],
-        "project_name": str
-    }
-    output JSON:
-    return {
-        "status": "success",
-        "message": "Train finished",
-        "project_name": project_name,
-        "label_count": label_count,
-        "base_metrics": base_metrics,
-        "final_metrics": final_metrics
-    }
-    """
+async def submit_train_job(request):
     try:
         data = request.json
-        if not data:
-            return json_response({"status": "error", "message": "data is required"}, status=400)
-        # Do the training in a separate thread to avoid blocking the event loop
-        result = await asyncio.to_thread(
-            auto_train,
-            data
-        )
-        # return the result as JSON response
-        return json_response(result)
-
-    except Exception as e:
-        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
-        return json_response({
-            "status": "error",
-            "message": f"Internal server error: {str(e)}"
-        }, status=500)
-
-# access the training progress
-@yolo_tile_blueprint.get("/progress/<project_name>")
-async def yolo_train_progress(request, project_name):
-    '''
-    input:
-        to query the latest progress: GET /yolo/progress/my_project
-        to query the progress at a specific time: GET /yolo/progress/my_project?run_time=20250902_1012
-    output JSON:
-    {
-        "status": "ok",
-        "run_time": "20250902_1012",
-        "progress": {
-            "epoch": 12,
-            "precision": 0.72,
-            "recall": 0.64,
-            "mAP50": 0.68,
-            "mAP50-95": 0.42
-        }
-    }
-    '''
-    run_time = request.args.get("run_time")  # get the run_time from the query string
-    # query the progress from the database
-    if not run_time:
-        run_time = None  # if not provided, query the latest progress
-
-    result = await asyncio.to_thread(query_progress, project_name, run_time)
-    return json_response(result)
+        params = TrainRequest(**data)
+    except ValidationError as e:
+        return json_response({"success": False, "error": e.errors()}, status=400)
 
+    task_id = str(uuid.uuid4())
+    tasks[task_id] = {"status": "queued", "params": params}
+
+    with lock:
+        if len(active_jobs) < MAX_CONCURRENT_JOBS:
+            active_jobs.append(task_id)  # reserve the slot so rapid submits cannot over-start jobs
+            t = threading.Thread(target=run_training, args=(task_id, params), daemon=True)
+            t.start()
+        else:
+            task_queue.put(task_id)
+            tasks[task_id]["status"] = "waiting"
+
+    return json_response({"success": True, "task_id": task_id, "message": "Task submitted"})
+
+
+@yolo_tile_blueprint.get("/task_status/<task_id:str>")
+async def task_status(request, task_id: str):
+    if task_id not in tasks:
+        return json_response({"success": False, "message": "Task ID not found"}, status=404)
+
+    task_info = tasks[task_id]
+    return json_response({
+        "success": True,
+        "status": task_info["status"],
+        "error": task_info.get("error", None)
+    })
+
+
+@yolo_tile_blueprint.get("/tasks")
+async def all_tasks(request):
+    return json_response({
+        tid: {"status": info["status"]}
+        for tid, info in tasks.items()
+    })
+
+
+@yolo_tile_blueprint.get("/system_status")
+async def system_status(request):
+    gpu_available = torch.cuda.is_available()
+    return json_response({
+        "gpu_available": gpu_available,
+        "max_concurrent": MAX_CONCURRENT_JOBS,
+        "running_jobs": len(active_jobs),
+        "waiting_jobs": task_queue.qsize(),
+        "active_task_ids": active_jobs
+    })
 
 if __name__ == '__main__':
     app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
\ No newline at end of file
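With these endpoints in place, a client submits a job and then polls its status. A minimal sketch, assuming the blueprint is mounted under /yolo (as the old /yolo/progress docstring suggests) and the server runs on the port configured above; all payload values are placeholders:

import time
import requests

BASE = "http://localhost:12366/yolo"  # assumed blueprint prefix

payload = {
    "config_name": "config", "table_name": "aidataset",
    "column_name": "image_url", "search_condition": "your_search_id",
    "aim_path": "./datasets/demo", "image_dir": "./datasets/demo_images",
    "label_dir": "./datasets/demo_labels", "output_path": "./my_dataset",
    "pt_path": "yolo11n.pt", "imgsz": 640, "epochs": 10, "device": [0],
    "hsv_v": 0.3, "cos_lr": True, "batch": 8,
    "project_dir": "./my_train_runs", "class_names": ["person", "car"],
}

task_id = requests.post(f"{BASE}/train", json=payload).json()["task_id"]
while True:
    status = requests.get(f"{BASE}/task_status/{task_id}").json()["status"]
    print(status)
    if status in ("finished", "failed"):
        break
    time.sleep(30)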
diff --git a/Ai_tottle/config.yaml b/Ai_tottle/config.yaml
index 4ea0d31..2fa2eda 100644
--- a/Ai_tottle/config.yaml
+++ b/Ai_tottle/config.yaml
@@ -8,6 +8,6 @@ minio:
 sql:
   host: '222.212.85.86'
   port: 5432
-  dbname: 'postgres'
+  dbname: 'smart_dev'
   user: 'postgres'
   password: 'root'
\ No newline at end of file
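For reference, a minimal sketch of how this sql block can be consumed, assuming PyYAML and psycopg2 (the old pipeline connected with psycopg2; the repository's actual config helper may differ):

import yaml
import psycopg2

with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

sql = cfg["sql"]  # host/port/dbname/user/password for PostgreSQL
conn = psycopg2.connect(
    host=sql["host"], port=sql["port"], dbname=sql["dbname"],
    user=sql["user"], password=sql["password"],
)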
diff --git a/Ai_tottle/train/broken.py b/Ai_tottle/train/broken.py
new file mode 100644
index 0000000..992bfd8
--- /dev/null
+++ b/Ai_tottle/train/broken.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+import random
+from tqdm import tqdm
+import yaml
+
+def split_img(img_path, label_path, split_list, output_path, class_names=[
+    'people',
+    'car',
+    'truck',
+    'bicycle',
+    'tricycle',
+    'ship']):
+    try:
+        # Create the target directory structure
+        for sub in ['images/train', 'images/val', 'images/test',
+                    'labels/train', 'labels/val', 'labels/test']:
+            os.makedirs(os.path.join(output_path, sub), exist_ok=True)
+    except Exception as e:
+        print(f'❌ Failed to create directory structure: {e}')
+        return
+
+    train, val, test = split_list
+    all_imgs = [f for f in os.listdir(img_path) if f.endswith(('.jpg', '.png'))]
+    all_img_paths = [os.path.join(img_path, f) for f in all_imgs]
+
+    # Allocate the training set
+    train_imgs = random.sample(all_img_paths, int(train * len(all_img_paths)))
+    move_set(train_imgs, label_path, os.path.join(output_path, 'images/train'), os.path.join(output_path, 'labels/train'))
+    for f in train_imgs: all_img_paths.remove(f)
+
+    # Allocate the validation set
+    val_imgs = random.sample(all_img_paths, int(val / (val + test) * len(all_img_paths)))
+    move_set(val_imgs, label_path, os.path.join(output_path, 'images/val'), os.path.join(output_path, 'labels/val'))
+    for f in val_imgs: all_img_paths.remove(f)
+
+    # The remainder goes to the test set
+    test_imgs = all_img_paths
+    move_set(test_imgs, label_path, os.path.join(output_path, 'images/test'), os.path.join(output_path, 'labels/test'))
+
+    # Generate dataset.yaml
+    generate_yaml(output_path, class_names)
+
+def move_set(img_list, label_root, dst_img_dir, dst_label_dir):
+    for img_path in tqdm(img_list, desc=f'Copying to {os.path.basename(dst_img_dir)}', ncols=80):
+        base = os.path.splitext(os.path.basename(img_path))[0]
+        label_path = os.path.join(label_root, base + '.txt')
+
+        shutil.copy(img_path, os.path.join(dst_img_dir, os.path.basename(img_path)))
+        if os.path.exists(label_path):
+            shutil.copy(label_path, os.path.join(dst_label_dir, base + '.txt'))
+
+def generate_yaml(dataset_root, class_names):
+    yaml_content = {
+        'train': os.path.join('images/train'),
+        'val': os.path.join('images/val'),
+        'test': os.path.join('images/test'),
+        'nc': len(class_names),
+        'names': class_names
+    }
+    with open(os.path.join(dataset_root, 'dataset.yaml'), 'w') as f:
+        yaml.dump(yaml_content, f, default_flow_style=False)
+    print(f"✅ Generated YAML: {os.path.join(dataset_root, 'dataset.yaml')}")
+
+def broken_main(aim_path, output_path, class_names=[
+    'people',
+    'car',
+    'truck',
+    'bicycle',
+    'tricycle',
+    'ship']):
+    img_path = os.path.join(aim_path, 'images')
+    label_path = os.path.join(aim_path, 'labels')
+    split_ratio = [0.7, 0.2, 0.1]
+    split_img(img_path, label_path, split_ratio, output_path, class_names)
+
+if __name__ == '__main__':
+    broken_main(
+        r"D:\Users\76118\Downloads\stanford_campus_dataset\filtered",
+        r"D:\work\develop\AI\数据集\output",
+        class_names=[
+            'people',
+            'car',
+            'truck',
+            'bicycle',
+            'tricycle',
+        ]
+    )
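Because split_img draws with random.sample from the module-level random generator, the train/val/test split can be made reproducible by seeding that generator before the call. A minimal sketch; the paths and class list are placeholders:

import random
from train.broken import broken_main

random.seed(42)  # fixes the split membership for a given file listing
broken_main(
    "./datasets/raw",    # expects images/ and labels/ subdirectories
    "./datasets/split",  # receives images/{train,val,test}, labels/{...}, dataset.yaml
    class_names=['people', 'car'],
)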
diff --git a/Ai_tottle/train/let_txt_to_true.py b/Ai_tottle/train/let_txt_to_true.py
new file mode 100644
index 0000000..9e51b42
--- /dev/null
+++ b/Ai_tottle/train/let_txt_to_true.py
@@ -0,0 +1,85 @@
+import os
+import stat
+import math
+
+
+def make_writable(file_path):
+    os.chmod(file_path, stat.S_IWRITE)
+
+
+def process_files_in_folder(folder_path):
+    for root, _, files in os.walk(folder_path):
+        for file_name in files:
+            if file_name.endswith(".txt"):
+                file_path = os.path.join(root, file_name)
+
+                # Make sure the file is writable
+                make_writable(file_path)
+
+                # Read the file contents for processing
+                with open(file_path, "r") as file:
+                    lines = file.readlines()
+
+                processed_lines = []
+                for line in lines:
+                    numbers = line.split()
+                    if not numbers:  # skip blank lines instead of crashing on numbers[0]
+                        continue
+                    processed_numbers = []
+
+                    # The first column must be an integer class id in 0..18, never a float
+                    if numbers[0] in {str(i) for i in range(19)}:
+                        processed_numbers.append(numbers[0])
+                    else:
+                        print(f"Unexpected value in first column: {numbers[0]}")
+                        continue
+
+                    # For the remaining columns: flip negatives to positive and drop
+                    # any line containing NaN. Note that re-serializing via float()
+                    # may change the formatting, e.g. "0.40" becomes "0.4".
+                    skip_line = False  # marks whether this line should be dropped
+                    for number in numbers[1:]:
+                        try:
+                            number = float(number)
+                            if math.isnan(number):  # NaN check
+                                skip_line = True
+                                print(
+                                    f"NaN detected in file: {file_path}, line: {line}"
+                                )
+                                break
+                            if number < 0:
+                                number = abs(number)  # flip negatives to positive
+                            processed_numbers.append(str(number))
+                        except ValueError:
+                            processed_numbers.append(number)  # keep non-numeric columns as-is
+
+                    # Keep the line only if it contained no NaN values
+                    if not skip_line:
+                        processed_line = " ".join(processed_numbers)
+                        processed_lines.append(processed_line)
+
+                # Write the processed contents back to the file
+                with open(file_path, "w") as file:
+                    file.write("\n".join(processed_lines))
+                print(f"Finished processing: {file_path}")
+
+
+if __name__ == "__main__":
+    # Guarded so that importing this module (e.g. from yolo_train.py) does not
+    # immediately rewrite a hard-coded folder on import
+    folder_path = r"G:\dataset\PCS\before\labels"
+    process_files_in_folder(folder_path)
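A quick sanity check of the normalization rules on a single label line, illustrating both the sign flip and the float re-serialization mentioned in the comments above (values are made up):

line = "3 -0.12 0.40 0.25 0.30"
numbers = line.split()
assert numbers[0] in {str(i) for i in range(19)}
cleaned = [numbers[0]] + [str(abs(float(n))) for n in numbers[1:]]
print(" ".join(cleaned))  # -> "3 0.12 0.4 0.25 0.3"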
processed_line = " ".join(processed_numbers) + processed_lines.append(processed_line) + + # 将处理后的内容写回文件 + with open(file_path, "w") as file: + file.write("\n".join(processed_lines)) + print(f"Finished processing: {file_path}") + +# 指定文件夹路径 +folder_path = r"G:\dataset\PCS\before\labels" +#run the function +process_files_in_folder(folder_path) diff --git a/Ai_tottle/train/train.py b/Ai_tottle/train/train.py new file mode 100644 index 0000000..757c67d --- /dev/null +++ b/Ai_tottle/train/train.py @@ -0,0 +1,25 @@ +from ultralytics import YOLO +import torch + + +# 检查CUDA是否可用 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# 加载模型 +model = YOLO("runs/detect/train6/weights/last.pt").to(device) + +# 设置新的分辨率 +imgsz = 1024 # 这里将图像尺寸调整为 1280x1280(你可以根据显存调整尺寸) + +# 训练模型,传入增强参数 +model.train( + data="dataset/dataset.yaml", # 你的数据集配置文件 + epochs=1000, # 训练轮次 + imgsz=imgsz, # 使用更高的分辨率 + device=[1], # 使用第一块 GPU(如果有多个 GPU,可以调整) + hsv_v=0.3, # 修改图像亮度的一部分,帮助模型在不同光照条件下表现良好 + cos_lr=True, # 启用余弦学习率调度 + batch = -1, # 自动调整批量大小以适应显存 +) + diff --git a/Ai_tottle/yolo_train.py b/Ai_tottle/yolo_train.py index e654740..c0abf8d 100644 --- a/Ai_tottle/yolo_train.py +++ b/Ai_tottle/yolo_train.py @@ -1,394 +1,166 @@ -"""" -main() - | - v -setup_logger(project) - | - v -get_last_model_path(project) - | - v -+-------------------------+ -| 有 last.pt | 无 last.pt | -+-------------------------+ - | | - v v -load_last_model() start_new_training() - | | - +-------+--------+ - | - v - check_dataset(root) - | - v - split_dataset(root, ratios) - | - v - clean_labels(root) - | - v - generate_yaml(dataset_dir) - | - v - train_yolo(model, data_yaml) - | - v - 保存 last.pt - | - v -logger.info("Saved last model path") - | - v - 写入 logs/{project}.log - - - -""" - import os -import shutil -import datetime -import torch -from ultralytics import YOLO -import random -import math -import stat -import yaml -import psycopg2 -from psycopg2 import OperationalError -from collections import Counter -import pandas as pd -import logging -from tqdm import tqdm -import miniohelp as miniohelp +from glob import glob from aboutdataset.download_oss import download_and_save_images_from_oss +from train.let_txt_to_true import process_files_in_folder +from train.broken import broken_main +from ultralytics import YOLO +import torch -######################################## Logging ######################################## -def setup_logger(project: str): - os.makedirs("logs", exist_ok=True) - log_file = os.path.join("logs", f"{project}.log") - logger = logging.getLogger(project) - if not logger.handlers: - logger.setLevel(logging.INFO) - formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s") +# ------------------ 下载图片和标签 ------------------ +def download_images_and_labels( + config_name, # OSS 配置文件名,用于读取连接信息 + table_name, # OSS 表名,指定下载数据的表 + column_name, # OSS 表中图片 URL 列名 + search_condition, # 筛选条件,用于查询 OSS 数据 + aim_path, # 本地保存数据集根目录 + image_dir, # 本地保存图片的目录 + label_dir # 本地保存标签 txt 的目录 +): + os.makedirs(aim_path, exist_ok=True) + os.makedirs(image_dir, exist_ok=True) + os.makedirs(label_dir, exist_ok=True) - fh = logging.FileHandler(log_file, encoding="utf-8") - fh.setFormatter(formatter) - sh = logging.StreamHandler() - sh.setFormatter(formatter) - - logger.addHandler(fh) - logger.addHandler(sh) - - return logger - -def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"): - """ - 从日志解析上一次训练的 last.pt 路径 - 如果找不到则返回 default_model - 支持 default_model 为 .pt 或 .yaml - """ - 
log_file = os.path.join("logs", f"{project}.log") - if not os.path.exists(log_file): - return default_model - - with open(log_file, "r", encoding="utf-8") as f: - lines = f.readlines() - - for line in reversed(lines): - if "Saved last model path:" in line: - path = line.strip().split("Saved last model path:")[-1].strip() - if os.path.exists(path): - return path - - return default_model - -def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, table_name): - """ - yaml_name='config' - where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'" - image_dir='images' - label_dir='labels' - table_name = 'aidataset' - - Returns: - (image_dir, label_dir) - """ - download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name) - return image_dir, label_dir - -####################################### 工具函数 ####################################### -def count_labels_by_class(label_dir): - class_counter = Counter() - for file in os.listdir(label_dir): - if file.endswith('.txt'): - with open(os.path.join(label_dir, file), 'r') as f: - for line in f: - class_id = line.strip().split()[0] - class_counter[class_id] += 1 - return dict(class_counter) - -def evaluate_model_per_class(model_path, dataset_yaml, class_names): - model = YOLO(model_path) - metrics = model.val(data=dataset_yaml, split='val') - class_ids = range(len(metrics.box.p)) - results = {} - for i in class_ids: - name = class_names.get(str(i), str(i)) - results[name] = { - "precision": float(metrics.box.p[i]), - "recall": float(metrics.box.r[i]), - "mAP50": float(metrics.box.map50[i]), - "mAP50_95": float(metrics.box.map[i]) - } - return results - -def link_database(db_database, db_user, db_password, db_host, db_port, search_query): - try: - with psycopg2.connect( - database=db_database, - user=db_user, - password=db_password, - host=db_host, - port=db_port - ) as conn: - with conn.cursor() as cur: - cur.execute(search_query) - records = cur.fetchall() - return records - except OperationalError as e: - print(f"数据库连接或查询时发生错误: {e}") - except Exception as e: - print(f"发生了其他错误: {e}") - -def down_dataset(db_database, db_user, db_password, db_host, db_port, model, logger): - search_query = f"SELECT * FROM aidataset WHERE model = '{model}';" - records = link_database(db_database, db_user, db_password, db_host, db_port, search_query) - if not records: - logger.warning("没有查询到数据。") - return - - os.makedirs('./dataset/images', exist_ok=True) - os.makedirs('./dataset/labels', exist_ok=True) - - for r in records: - img_path = r[4] - label_content = r[5] - - local_img_name = img_path.split('/')[-1] - local_img_path = os.path.join('./dataset/images', local_img_name) - miniohelp.downFile(img_path, local_img_path) - - txt_name = os.path.splitext(local_img_name)[0] + '.txt' - txt_path = os.path.join('./dataset/labels', txt_name) - with open(txt_path, 'w', encoding='utf-8') as f: - f.write(label_content + '\n') - - logger.info("数据下载完成") - -def make_writable(file_path): - os.chmod(file_path, stat.S_IWRITE) - -def process_files_in_folder(folder_path, logger): - for root, _, files in os.walk(folder_path): - for file_name in files: - if file_name.endswith('.txt'): - file_path = os.path.join(root, file_name) - make_writable(file_path) - - with open(file_path, 'r') as file: - lines = file.readlines() - - processed_lines = [] - for line in lines: - numbers = line.split() - processed_numbers = [] - if numbers[0].isdigit(): - processed_numbers.append(numbers[0]) - else: - logger.warning(f"Unexpected value in first column: 
{numbers[0]}") - continue - - skip_line = False - for number in numbers[1:]: - try: - number = float(number) - if math.isnan(number): - skip_line = True - logger.warning(f"NaN detected in {file_path}: {line}") - break - if number < 0: - number = abs(number) - processed_numbers.append(str(number)) - except ValueError: - processed_numbers.append(number) - - if not skip_line: - processed_line = ' '.join(processed_numbers) - processed_lines.append(processed_line) - - with open(file_path, 'w') as file: - file.write('\n'.join(processed_lines)) - logger.info(f"Processed {file_path}") - -def split_img(img_path, label_path, split_list, new_path, class_names, logger): - try: - Data = os.path.abspath(new_path) - os.makedirs(Data, exist_ok=True) - dirs = ['train/images','val/images','test/images','train/labels','val/labels','test/labels'] - for d in dirs: os.makedirs(os.path.join(Data, d), exist_ok=True) - except Exception as e: - logger.error(f'文件目录创建失败: {e}') - return - - train, val, test = split_list - all_img = os.listdir(img_path) - all_img_path = [os.path.join(img_path, img) for img in all_img] - - train_img = random.sample(all_img_path, int(train * len(all_img_path))) - train_label = [toLabelPath(img, label_path) for img in train_img] - for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'): - _copy(train_img[i], os.path.join(Data,'train/images')) - _copy(train_label[i], os.path.join(Data,'train/labels')) - all_img_path.remove(train_img[i]) - - val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path))) - val_label = [toLabelPath(img, label_path) for img in val_img] - for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'): - _copy(val_img[i], os.path.join(Data,'val/images')) - _copy(val_label[i], os.path.join(Data,'val/labels')) - all_img_path.remove(val_img[i]) - - test_img = all_img_path - test_label = [toLabelPath(img, label_path) for img in test_img] - for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'): - _copy(test_img[i], os.path.join(Data,'test/images')) - _copy(test_label[i], os.path.join(Data,'test/labels')) - - generate_dataset_yaml( - save_path=os.path.join(Data, 'dataset.yaml'), - train_path=os.path.join(Data,'train/images'), - val_path=os.path.join(Data,'val/images'), - test_path=os.path.join(Data,'test/images'), - class_names=class_names + download_and_save_images_from_oss( + yaml_name=config_name, + where_clause=f"{column_name} = '{search_condition}'", + image_dir=image_dir, + label_dir=label_dir, + table_name=table_name, ) - logger.info("数据集划分完成") -def _copy(from_path, to_path): - try: - shutil.copy(from_path, to_path) - except Exception as e: - print(f"复制文件时出错: {e}") + return aim_path, image_dir, label_dir -def toLabelPath(img_path, label_path): - img = os.path.basename(img_path) - label = img.replace('.jpg', '.txt') - return os.path.join(label_path, label) -def generate_dataset_yaml(save_path, train_path, val_path, test_path, class_names): - dataset_yaml = { - 'train': train_path.replace('\\', '/'), - 'val': val_path.replace('\\', '/'), - 'test': test_path.replace('\\', '/'), - 'nc': len(class_names), - 'names': list(class_names.values()) - } - with open(save_path, 'w', encoding='utf-8') as f: - yaml.dump(dataset_yaml, f, allow_unicode=True) +# ------------------ 标签修正与数据打乱 ------------------ +def broken_and_convert_txt_to_yolo_format( + aim_path, # 数据集根目录 + output_path, # 打乱并输出后的数据集目录 + image_dir, # 图片目录 + label_dir, # 标签目录 + class_names # 数据集类别列表 +): + process_files_in_folder(label_dir) # 修正标签为 YOLO 
-def delete_folder(folder_path, logger):
-    if os.path.exists(folder_path):
-        shutil.rmtree(folder_path)
-        logger.info(f"Deleted folder: {folder_path}")
 
-####################################### Training #######################################
-def train(project_name, yaml_path, default_model_path, logger):
+# ------------------ Locate the latest .pt model ------------------
+def get_latest_pt(project_dir, pt_path):
+    """
+    Check the training output directory for the most recent .pt model file.
+    Return its path if one exists, otherwise return the given pt_path.
+    """
+    if not os.path.exists(project_dir):
+        print(f"[INFO] Project dir {project_dir} does not exist; using the given model {pt_path}")
+        return pt_path
+
+    # Search recursively: Ultralytics saves weights under <project>/<run>/weights/
+    pt_files = glob(os.path.join(project_dir, "**", "*.pt"), recursive=True)
+    if not pt_files:
+        print(f"[INFO] No .pt files found; using the given model {pt_path}")
+        return pt_path
+
+    latest_pt = max(pt_files, key=os.path.getmtime)
+    print(f"[INFO] Found latest model: {latest_pt}")
+    return latest_pt
+
+
+# ------------------ Training ------------------
+def train(
+    yaml_path,    # path to the YOLO dataset config file
+    pt_path,      # initial .pt weights to train from
+    imgsz,        # input image resolution
+    epochs,       # number of training epochs
+    device,       # list of GPU indices, e.g. [0] or [0, 1]
+    hsv_v,        # brightness augmentation factor
+    cos_lr,       # whether to use cosine learning-rate scheduling
+    batch,        # batch size
+    project_dir   # training output directory (weights, logs, ...)
+):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logger.info(f"Using device: {device}")
+    # Do not shadow the caller's GPU index list; fall back to CPU only when CUDA is absent
+    run_device = device if torch.cuda.is_available() else "cpu"
+    print(f"[INFO] Using device: {run_device}")
 
-    model_path = get_last_model_from_log(project_name, default_model_path)
-    logger.info(f"Loading model: {model_path}")
-    model = YOLO(model_path).to(device)
+    pt_path = get_latest_pt(project_dir, pt_path)  # auto-detect the latest .pt file
+
+    model = YOLO(pt_path)  # Ultralytics moves the model to run_device during train()
 
-    current_date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
     model.train(
         data=yaml_path,
-        epochs=200,
-        pretrained=True,
-        patience=50,
-        imgsz=640,
+        epochs=epochs,
+        imgsz=imgsz,
+        device=run_device,
+        hsv_v=hsv_v,
+        cos_lr=cos_lr,
+        batch=batch,
+        project=project_dir,
+    )
+
+
+# ------------------ Main pipeline ------------------
+def train_main(
+    # OSS download parameters
+    config_name,       # sql config file name
+    table_name,        # sql table name
+    column_name,       # column name in the sql table
+    search_condition,  # sql filter condition
+    # dataset paths
+    aim_path,     # local dataset root directory (before shuffling)
+    image_dir,    # local directory for downloaded images
+    label_dir,    # local directory for downloaded labels
+    output_path,  # directory for the shuffled, split dataset
+    # YOLO training parameters
+    pt_path,      # initial weights file path
+    imgsz,        # input image resolution
+    epochs,       # number of training epochs
+    device,       # list of GPU indices
+    hsv_v,        # brightness augmentation factor
+    cos_lr,       # whether to use cosine learning-rate scheduling
+    batch,        # batch size
+    project_dir,  # training output directory
+    # classes
+    class_names   # list of dataset class names
+):
+    aim_path, image_dir, label_dir = download_images_and_labels(
+        config_name, table_name, column_name, search_condition,
+        aim_path, image_dir, label_dir
+    )
+
+    output_path, yaml_path = broken_and_convert_txt_to_yolo_format(
+        aim_path, output_path, image_dir, label_dir, class_names
+    )
+
+    train(
+        yaml_path=yaml_path,
+        pt_path=pt_path,
+        imgsz=imgsz,
+        epochs=epochs,
+        device=device,
+        hsv_v=hsv_v,
+        cos_lr=cos_lr,
+        batch=batch,
+        project_dir=project_dir
+    )
+
+
+# ------------------ Entry point ------------------
+if __name__ == "__main__":
+    train_main(
+        config_name="config",
+        table_name="aidataset",
+        column_name="image_url",
+        search_condition="your_search_id",
+        aim_path="./datasets/aidataset_dataset",
+        image_dir="./dataset/aidataset_dataset_images",
+        label_dir="./dataset/aidataset_dataset_labels",
+        output_path="./my_dataset",
+        pt_path="custom_model.pt",
+        imgsz=800,
+        epochs=500,
         device=[0],
-        workers=0,
-        project=project_name,
-        name=current_date,
+        hsv_v=0.3,
+        cos_lr=True,
+        batch=8,
+        project_dir="./my_train_runs",
+        class_names=['person','car']
     )
-
-    trained_model_path = os.path.join('runs', 'detect', project_name, current_date, 'weights', 'last.pt')
-    if os.path.exists(trained_model_path):
-        logger.info(f"Saved last model path: {trained_model_path}")
-
-####################################### Auto training #######################################
-def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
-               img_path='./dataset/images', label_path='./dataset/labels',
-               new_path='./datasets', split_list=[0.7, 0.2, 0.1],
-               class_names=None, project_name='default_project'):
-    if class_names is None:
-        class_names = {}
-
-    logger = setup_logger(project_name)
-
-    delete_folder('dataset', logger)
-    delete_folder('datasets', logger)
-
-    down_dataset(db_database, db_user, db_password, db_host, db_port, model_id, logger)
-    process_files_in_folder(img_path, logger)
-
-    label_count = count_labels_by_class(label_path)
-    logger.info(f"Label counts: {label_count}")
-
-    split_img(img_path, label_path, split_list, new_path, class_names, logger)
-
-    base_metrics = evaluate_model_per_class('yolo11n.pt', './datasets/dataset.yaml', class_names)
-    logger.info(f"Pre-training baseline evaluation: {base_metrics}")
-
-    delete_folder('dataset', logger)
-
-    train(project_name, './datasets/dataset.yaml', 'yolo11n.pt', logger)
-
-    logger.info("Training pipeline finished")
-
-def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id, image_dir, label_dir, yaml_name, where_clause, table_name):
-
-    imag_path, label_path = get_img_and_label_paths(yaml_name, where_clause, image_dir, label_dir, table_name)
-    auto_train(
-        db_host=db_host,
-        db_database=db_database,
-        db_user=db_user,
-        db_password=db_password,
-        db_port=db_port,
-        model_id=model_id,
-        imag_path=imag_path,    # fixed here: make sure imag_path is passed as a keyword argument
-        label_path=label_path,  # fixed here: make sure label_path is passed as a keyword argument
-        new_path='./datasets',
-        split_list=[0.7, 0.2, 0.1],
-        class_names={'0': 'human', '1': 'car'},
-        project_name='my_project'
-    )
-
-####################################### Main entry #######################################
-if __name__ == '__main__':
-    down_and_train(
-        db_host='222.212.85.86',
-        db_database='your_database_name',
-        db_user='postgres',
-        db_password='postgres',
-        db_port='5432',
-        model_id='best.pt',
-        img_path='./dataset/images',   # image path before splitting
-        label_path='./dataset/labels', # label path before splitting
-        new_path='./datasets',         # path after splitting
-        split_list=[0.7, 0.2, 0.1],
-        class_names={'0': 'human', '1': 'car'},
-        project_name='my_project'
-)
\ No newline at end of file
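For a quick end-to-end smoke test of this pipeline without going through the HTTP API, train_main can also be driven directly with minimal settings. A sketch; every path and the search_condition are placeholders, and it assumes the OSS download succeeds for that condition:

from yolo_train import train_main

train_main(
    config_name="config",
    table_name="aidataset",
    column_name="image_url",
    search_condition="demo_id",          # placeholder filter value
    aim_path="./datasets/demo",
    image_dir="./datasets/demo/images",
    label_dir="./datasets/demo/labels",
    output_path="./datasets/demo_split",
    pt_path="yolo11n.pt",                # any small starting checkpoint
    imgsz=640,
    epochs=1,                            # single epoch: just verify the plumbing
    device=[0],                          # train() falls back to CPU when CUDA is absent
    hsv_v=0.3,
    cos_lr=True,
    batch=2,
    project_dir="./runs_smoke",
    class_names=["person", "car"],
)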