From 06bafccb4e2377a669b31f533d7d92a876bd58ca Mon Sep 17 00:00:00 2001 From: liyubo Date: Thu, 13 Nov 2025 10:29:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=B1=E4=BA=AB=E7=9B=AE=E5=BD=95AI=E8=AF=86?= =?UTF-8?q?=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- predict/predict_yolo11seg.py | 281 +++++++++++- util/smb.py | 859 +++++++++++++++++++++++++++++++++++ util/yolo2pix_new.py | 243 ++++++++++ yolo_api.py | 47 ++ 4 files changed, 1423 insertions(+), 7 deletions(-) create mode 100644 util/smb.py create mode 100644 util/yolo2pix_new.py diff --git a/predict/predict_yolo11seg.py b/predict/predict_yolo11seg.py index 7c0a0fe..1f2e6a2 100644 --- a/predict/predict_yolo11seg.py +++ b/predict/predict_yolo11seg.py @@ -17,7 +17,9 @@ import matplotlib.pyplot as plt from ultralytics import YOLO from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder -from util.yolo2pix import yoloseg_to_grid_cells_fixed_v5, draw_grid_on_image +from util.yolo2pix_new import * +from util.smb import * +import threading # 定义红白蓝颜色 (BGR格式) RED = (0, 0, 255) @@ -183,6 +185,71 @@ class YOLOSegmentationInference: print(f"图片预处理失败: {e}") return None, None + def perform_inference_share_dir(self, image, image_path, conf_threshold: float = 0.25, + iou_threshold: float = 0.5) -> InferenceResult: + """ + 执行推理 + + Args: + image_path: 图片数据 + image_path: 图片路径 + conf_threshold: 置信度阈值 + iou_threshold: IOU阈值 + + Returns: + 推理结果 + """ + result = InferenceResult(image_path) + + try: + if self.model is None: + raise ValueError("模型未加载,请先调用load_model()") + + # 转换为RGB格式 + original_image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + result.original_image = original_image_rgb + + # 执行推理 + print(f"正在处理图片: {os.path.basename(image_path)}") + start_time = time.time() + + # 使用YOLO模型进行推理 + # predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0] + predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold)[0] + + result.inference_time = time.time() - start_time + + # 处理结果 + if predictions.masks is not None: + # 处理掩码 + masks = predictions.masks.data.cpu().numpy() + + # 处理边界框 + boxes = predictions.boxes.data.cpu().numpy() + + # 处理类别和置信度 + class_ids = predictions.boxes.cls.cpu().numpy().astype(int) + scores = predictions.boxes.conf.cpu().numpy() + + # 获取类别名称 + class_names = [self.model.names[i] for i in class_ids] + + # 存储结果 + result.masks = masks + result.boxes = boxes + result.classes = class_ids + result.scores = scores + result.class_names = class_names + + print(f"检测到 {len(masks)} 个对象,推理时间: {result.inference_time:.3f} 秒") + + return result + + except Exception as e: + print(f"推理失败: {e}") + return result + def perform_inference(self, image_path: str, conf_threshold: float = 0.25, iou_threshold: float = 0.5) -> InferenceResult: """ @@ -323,7 +390,7 @@ class YOLOSegmentationInference: try: base_name = os.path.splitext(os.path.basename(result.image_path))[0] - output_dir = output_dir + "-" + base_name + # output_dir = output_dir + "/" + base_name # 创建输出目录 os.makedirs(output_dir, exist_ok=True) @@ -334,14 +401,15 @@ class YOLOSegmentationInference: "result_dir": output_dir } # 保存结果图片 - result_path = os.path.join(output_dir, f"{base_name}_result.jpg") + result_path = os.path.join(output_dir, f"{base_name}.jpg") result_image_bgr = cv2.cvtColor(result.result_image, cv2.COLOR_RGB2BGR) cv2.imwrite(result_path, result_image_bgr) print(f"结果图片已保存: {result_path}") # 保存单独的掩码文件 if save_mask and result.masks is not None and len(result.masks) > 0: - mask_dir = os.path.join(output_dir, "masks") + # mask_dir = os.path.join(output_dir, "masks") + mask_dir = output_dir os.makedirs(mask_dir, exist_ok=True) for i in range(len(result.masks)): @@ -356,7 +424,8 @@ class YOLOSegmentationInference: # 保存YOLO格式的标签文件 if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0: - label_dir = os.path.join(output_dir, "labels") + # label_dir = os.path.join(output_dir, "labels") + label_dir = output_dir os.makedirs(label_dir, exist_ok=True) label_path = os.path.join(label_dir, f"{base_name}.txt") @@ -463,6 +532,45 @@ class YOLOSegmentationInference: # self.show_results(result) return result + + def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None, + conf_threshold: float = 0.25, iou_threshold: float = 0.5, + save_mask: bool = False, save_label: bool = False, show: bool = True, + result_save: [] = None) -> InferenceResult: + """ + 处理单张图片 + + Args: + image_path: 图片路径 + output_dir: 输出目录,如果为None则不保存 + conf_threshold: 置信度阈值 + iou_threshold: IOU阈值 + save_mask: 是否保存单独的掩码文件 + save_label: 是否保存YOLO格式的标签文件 + show: 是否显示结果 + + Returns: + 推理结果 + """ + # 执行推理 + config = get_conf(image_path, user_name, pwd) + scanner = get_scanner(image_path, user_name=user_name, pwd=pwd) + image = scanner.read_img_file(image_path) + result = self.perform_inference_share_dir(image, image_path, conf_threshold, iou_threshold) + + # 绘制结果 + if result.masks is not None and len(result.masks) > 0: + self.draw_results(result, conf_threshold) + + # 保存结果 + if output_dir is not None: + self.save_results(result, output_dir, save_mask, save_label, result_save) + + # # 显示结果 + # if show: + # self.show_results(result) + + return result def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None, conf_threshold: float = 0.25, iou_threshold: float = 0.5, @@ -528,6 +636,83 @@ class YOLOSegmentationInference: print(f"处理目录失败: {e}") return results + def process_image_directory_share_dir(self, input_dir, user_name, pwd, output_dir: Optional[str] = None, + conf_threshold: float = 0.25, iou_threshold: float = 0.5, + save_mask: bool = False, save_label: bool = False, show: bool = False, + result_save: [] = None) -> List[ + InferenceResult]: + """ + 处理目录中的所有图片 + + Args: + input_dir: 输入目录 + output_dir: 输出目录,如果为None则不保存 + conf_threshold: 置信度阈值 + iou_threshold: IOU阈值 + save_mask: 是否保存单独的掩码文件 + save_label: 是否保存YOLO格式的标签文件 + show: 是否显示结果 + + Returns: + 推理结果列表 + """ + results = [] + + try: + # 检查目录是否存在 + config = get_conf(zip_url=input_dir, user_name=user_name, pwd=pwd) + scanner = get_scanner(zip_url=input_dir, user_name=user_name, pwd=pwd); + if not scanner.directory_exists(input_dir): + print(f"错误: {input_dir} 不是有效的目录") + return results + + # 获取所有图片文件 + image_files = scanner.get_smb_images(input_dir) + + if not image_files: + print(f"在目录 {input_dir} 中未找到图片文件") + return results + + print(f"找到 {len(image_files)} 个图片文件") + + # 处理每张图片 + for image_path in image_files: + result = self.process_single_image_share_dir( + image_path=image_path, + user_name=user_name, + pwd=pwd, + output_dir=output_dir, + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + save_mask=save_mask, + save_label=save_label, + show=show, + result_save=result_save + ) + results.append(result) + + # 推送识别数据到共享目录 + pile_dict = get_pile_dict(image_path, user_name, pwd) + process_dir(pile_dict, output_dir) + # 找到 图像类 文件夹 + found_paths = scanner.find_folders_by_name( + share_path=config['share'], + folder_name='图像类' + ) + if len(found_paths) > 0 : + tmpConf = get_conf(found_paths[0], user_name, pwd) + scanner.upload_directory(output_dir, config['share'], remote_dir=tmpConf['dir']+"_识别") + else : + print(f"错误: 远程共享目录 找不到【图像类】目录") + + return results + + except PermissionError: + print(f"权限错误: 无法访问目录 {input_dir}") + return results + except Exception as e: + print(f"处理目录失败: {e}") + return results def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0.25, save_json=False): zip_save_path = "dataset/zip_file" @@ -596,8 +781,8 @@ def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0. input_path = zip_local_dir_save result_save = [] - conf_threshold = 0.25, - iou_threshold = 0.5, + conf_threshold = 0.25 + iou_threshold = 0.5 save_mask = True, save_label = True, show = True @@ -673,3 +858,85 @@ def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0. os.remove(zip_dir_path) return file_save_dir, "success" + + + + +def predict_images_share_dir(pt_name, zip_url, user_name, pwd, output_dir="predictions", conf_threshold=0.25, save_json=False): + # 本地测试模式 - 请根据实际情况修改以下路径 + # local_model_path = r"D:\project\verification\ultralytics-main\model\script\seg\pt\test.pttest.pt" + local_model_path = r"../pt_save/road_crack.pt" + local_output_dir = output_dir + # zip_url = "meta_data/ai_train_platform/train.zip" + + try: + # 加载模型 + print(f"正在加载模型: {local_model_path}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"模型已加载到 {device}") + except Exception as e: + print(f"处理目录失败: {e}") + + inference = YOLOSegmentationInference( + model_path=local_model_path, + device=device + ) + + # 加载模型 + if not inference.load_model(): + return + + + + # zip_url = r"D:\project\verification\ultralytics-main\model\script\seg\test_seg_pic" + + result_save = [] + conf_threshold = 0.25 + iou_threshold = 0.5 + save_mask = False + save_label = True + show = True + + # 查找指定文件夹 图像类 + config = get_conf(zip_url, user_name, pwd) + scanner = get_scanner(zip_url, user_name, pwd) + found_paths = scanner.find_folders_by_name( + share_path=config['share'], + folder_name='图像类' + ) + + target_path = "" + report_data_path = "" + if len(found_paths) > 0: + # 处理目录 + report_data_path = found_paths[0] + tmpConfig = get_conf(report_data_path, user_name, pwd) + found_paths = scanner.find_folders_by_name( + share_path=config['share'], + folder_name='Images', + start_dir=tmpConfig['dir'] + ) + if len(found_paths) > 0: + target_path = found_paths[0] + # inference.process_image_directory_share_dir( + # input_dir=target_path, + # user_name=user_name, + # pwd=pwd, + # output_dir=output_dir, + # conf_threshold=conf_threshold, + # iou_threshold=iou_threshold, + # save_mask=save_mask, + # save_label=save_label, + # show=show, + # result_save=result_save + # ) + + # 创建并启动线程 + thread1 = threading.Thread(target=inference.process_image_directory_share_dir, args=(target_path,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save)) + + # 启动线程 + thread1.start() + else: + print(f"错误: 输入 {zip_url} 不是有效的文件或目录") + + return f"{report_data_path}_识别", "success" \ No newline at end of file diff --git a/util/smb.py b/util/smb.py new file mode 100644 index 0000000..a47ce26 --- /dev/null +++ b/util/smb.py @@ -0,0 +1,859 @@ +import os +from smbclient import ( + register_session, + listdir, + scandir, + stat, + makedirs, # 递归创建目录 + open_file +) +from datetime import datetime +import numpy as np +import cv2 +import pandas as pd +import io + +class SMBScanner: + def __init__(self, ip, username, password, domain=''): + self.ip = ip + self.username = username + self.password = password + self.domain = domain + + def connect(self): + """连接 SMB 共享""" + try: + register_session( + self.ip, + username=self.username, + password=self.password + ) + print(f"成功连接到 {self.ip}") + return True + except Exception as e: + print(f"连接失败: {e}") + return False + + def directory_exists(self, full_path): + """ + 检查目录是否存在 + + Args: + full_path: 全路径 + + Returns: + bool: 目录是否存在 + """ + if not self.connect(): + return False + + try: + # 尝试获取目录信息 + dir_stat = stat(full_path) + return True + except Exception as e: + print(f"未知错误: {e}") + return False + + def read_excel(self, smb_path, sheet_name=0): + """读取Excel文件""" + if not self.connect(): + return False + + try: + with open_file(smb_path, mode='rb') as smb_file: + file_content = smb_file.read() + + excel_data = io.BytesIO(file_content) + df = pd.read_excel(excel_data, sheet_name=sheet_name) + return df + + except Exception as e: + print(f"读取Excel失败: {e}") + return None + + def process_all_rows(self, df): + """ + 处理所有行数据 + """ + if df is None or df.empty: + print("没有数据可处理") + return + + print("开始处理每行数据:") + print("=" * 60) + + results = [] + + for row_number, (index, row) in enumerate(df.iterrows(), 1): + print(f"\n处理第 {row_number} 行:") + print("-" * 40) + + # 显示行数据 + for col_name in df.columns: + value = row[col_name] + print(f" {col_name}: {value}") + + # 处理逻辑(根据实际需求修改) + processed_row = { + 'row_number': row_number, + 'original_index': index, + 'data': row.to_dict(), + 'summary': f"处理了 {len(df.columns)} 个字段" + } + + results.append(processed_row) + + # 进度显示 + if row_number % 10 == 0 or row_number == len(df): + print(f"\n 进度: {row_number}/{len(df)} ({row_number/len(df)*100:.1f}%)") + + print("\n" + "=" * 60) + print(f"处理完成!共处理 {len(results)} 行数据") + + return results + + def get_smb_images(self, full_path): + """SMB 图片文件获取""" + image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff'] + image_files = [] + + try: + for entry in scandir(full_path): + if entry.is_file(): + _, ext = os.path.splitext(entry.name) + if ext.lower() in image_extensions: + image_files.append(entry.path) + elif entry.is_dir(): + imgs = self.get_smb_images(entry.path) + image_files.extend(imgs) + except Exception as e: + print(f"错误: {e}") + + return image_files + + def build_full_path(self, share_path, file_path): + """构建完整的 SMB 路径""" + # 清理路径中的多余斜杠 + share_path = share_path.strip('\\') + file_path = file_path.lstrip('\\') + return f"\\\\{self.ip}\\{share_path}\\{file_path}" + + def read_txt_by_line(self, full_path): + """逐行读取,适合大文件""" + if not self.connect(): + return None + + print(f"读取 TXT 文件: {full_path}") + + try: + with open_file(full_path, mode='rb') as file_obj: + content_bytes = file_obj.read() + + # 使用 StringIO 逐行处理 + text_content = content_bytes.decode('utf-8', errors='ignore') + string_io = io.StringIO(text_content) + + lines = [] + line_number = 0 + + while True: + line = string_io.readline() + if not line: # 读到文件末尾 + break + + line_number += 1 + line = line.strip() + # print(f"行 {line_number}: {line}") + lines.append(line) + + print(f"总共读取 {line_number} 行") + return lines + + except Exception as e: + print(f"读取文件时出错: {e}") + return None + + def read_img_file(self, full_path): + """读取文件并返回 OpenCV 图像""" + if not self.connect(): + return None + + print(f"读取文件: {full_path}") + + file_obj = None + try: + # 以二进制模式读取文件 + file_obj = open_file(full_path, mode='rb') + content = b"" + + # 分块读取文件内容 + while True: + chunk = file_obj.read(8192) # 8KB 块 + if not chunk: + break + content += chunk + + print(f"成功读取 {len(content)} 字节") + + # 解码图像 + if len(content) == 0: + print("文件为空") + return None + + image_array = np.frombuffer(content, np.uint8) + image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + + if image is None: + print("图像解码失败 - 可能不是有效的图像文件") + return None + + print(f"图像解码成功: {image.shape}") + return image + + except Exception as e: + print(f"读取文件失败: {e}") + return None + finally: + if file_obj: + file_obj.close() + + def writeFile(self, share_path, file_path, data, chunk_size=8192): + """写入文件到 SMB 共享""" + if not self.connect(): + return False + + full_path = self.build_full_path(share_path, file_path) + file_obj = None + + try: + # 确保目录存在 + dir_path = os.path.dirname(full_path) + try: + makedirs(dir_path, exist_ok=True) + except: + pass # 目录可能已存在 + + file_obj = open_file(full_path, mode='wb') + + if isinstance(data, bytes): + total_size = len(data) + written = 0 + + for i in range(0, total_size, chunk_size): + chunk = data[i:i + chunk_size] + file_obj.write(chunk) + written += len(chunk) + print(f"写入进度: {written}/{total_size} 字节 ({written/total_size*100:.1f}%)") + + elif hasattr(data, '__iter__'): + total_written = 0 + for chunk in data: + if isinstance(chunk, str): + chunk = chunk.encode('utf-8') + file_obj.write(chunk) + total_written += len(chunk) + print(f"已写入: {total_written} 字节") + else: + file_obj.write(bytes(data)) + + print(f"文件写入完成: {full_path}") + return True + + except Exception as e: + print(f"写入文件失败: {e}") + return False + finally: + if file_obj: + file_obj.close() + + def writeImageToFile(self, share_path, file_path, image, image_format='.jpg', quality=95): + """将 OpenCV 图像写入 SMB 文件""" + if not self.connect(): + return False + + full_path = f"{file_path}{image_format}" + file_obj = None + + try: + if image_format.lower() == '.jpg': + encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] + success, encoded_image = cv2.imencode(image_format, image, encode_params) + else: + success, encoded_image = cv2.imencode(image_format, image) + + if not success: + print("图像编码失败") + return False + + image_data = encoded_image.tobytes() + return self.writeFile(share_path, f"{file_path}{image_format}", image_data) + + except Exception as e: + print(f"写入图像失败: {e}") + return False + + def _ensure_remote_directory(self, share_name, remote_dir): + """确保远程目录存在""" + if not remote_dir: + return + + try: + # 构建完整远程路径 + full_remote_path = self.build_full_path(share_name, remote_dir) + + # 使用 makedirs 递归创建目录(如果不存在) + makedirs(full_remote_path, exist_ok=True) + print(f"确保远程目录存在: {remote_dir}") + + except Exception as e: + print(f"创建远程目录失败: {e}") + raise + + def upload_directory(self, local_dir, share_name, remote_dir="", overwrite=True): + """ + 将本地目录推送到远程共享目录 + """ + if not self.connect(): + return False + + print(f"开始上传目录: {local_dir} -> {share_name}/{remote_dir}") + + if not os.path.exists(local_dir): + print(f"本地目录不存在: {local_dir}") + return False + + try: + # 确保远程目录存在 + self._ensure_remote_directory(share_name, remote_dir) + + # 递归上传目录内容 + success = self._upload_directory_recursive(local_dir, share_name, remote_dir, overwrite) + + if success: + print("目录上传完成") + else: + print("目录上传过程中出现错误") + + return success + + except Exception as e: + print(f"上传目录失败: {e}") + return False + + def _upload_directory_recursive(self, local_path, share_name, remote_path, overwrite): + """递归上传目录内容""" + try: + success = True + + for item_name in os.listdir(local_path): + local_item_path = os.path.join(local_path, item_name) + remote_item_path = f"{remote_path}/{item_name}" if remote_path else item_name + + if os.path.isdir(local_item_path): + # 处理子目录 + print(f"上传子目录: {item_name}") + + # 确保远程子目录存在 + self._ensure_remote_directory(share_name, remote_item_path) + + # 递归上传子目录 + sub_success = self._upload_directory_recursive(local_item_path, share_name, remote_item_path, overwrite) + if not sub_success: + success = False + + else: + # 上传文件 + file_success = self._upload_single_file(local_item_path, share_name, remote_item_path, overwrite) + if not file_success: + success = False + + return success + + except Exception as e: + print(f"上传目录内容失败 {local_path}: {e}") + return False + + def _upload_single_file(self, local_file_path, share_name, remote_file_path, overwrite): + """上传单个文件""" + file_obj = None + try: + # 构建远程完整路径 + full_remote_path = self.build_full_path(share_name, remote_file_path) + + # 检查文件是否已存在 + if not overwrite: + try: + stat(full_remote_path) + print(f"文件已存在,跳过: {remote_file_path}") + return True + except FileNotFoundError: + # 文件不存在,继续上传 + pass + + # 上传文件 + print(f"上传文件: {os.path.basename(local_file_path)}") + + # 读取本地文件 + with open(local_file_path, 'rb') as local_file: + local_content = local_file.read() + + # 写入远程文件 + with open_file(full_remote_path, mode='wb') as remote_file: + remote_file.write(local_content) + + file_size = len(local_content) + print(f"文件上传成功: {remote_file_path} ({file_size} 字节)") + return True + + except Exception as e: + print(f"上传文件失败 {local_file_path}: {e}") + return False + + def upload_file(self, local_file_path, share_name, remote_file_path, overwrite=True): + """ + 上传单个文件到远程共享目录 + """ + if not self.connect(): + return False + + print(f"上传文件: {local_file_path} -> {share_name}/{remote_file_path}") + + file_obj = None + try: + # 构建远程完整路径 + full_remote_path = self.build_full_path(share_name, remote_file_path) + + # 检查文件是否已存在 + if not overwrite: + try: + stat(full_remote_path) + print(f"文件已存在,跳过: {remote_file_path}") + return True + except FileNotFoundError: + # 文件不存在,继续上传 + pass + + # 以二进制模式读取本地文件 + with open(local_file_path, 'rb') as local_file: + content = b"" + + # 分块读取文件内容 + while True: + chunk = local_file.read(8192) # 8KB 块 + if not chunk: + break + content += chunk + + print(f"成功读取 {len(content)} 字节") + + if len(content) == 0: + print("文件为空") + return False + + # 写入远程文件 + with open_file(full_remote_path, mode='wb') as remote_file: + remote_file.write(content) + + print(f"文件上传成功") + return True + + except Exception as e: + print(f"上传文件失败: {e}") + return False + + def find_folders_by_name(self, share_path, folder_name, start_dir="", max_depth=10): + """专门查找文件夹""" + return self.find_items_by_name( + share_path=share_path, + target_name=folder_name, + item_type="folder", + start_dir=start_dir, + max_depth=max_depth + ) + + def find_files_by_name(self, share_path, file_name, start_dir="", max_depth=10): + """专门查找文件""" + return self.find_items_by_name( + share_path=share_path, + target_name=file_name, + item_type="file", + start_dir=start_dir, + max_depth=max_depth + ) + + def find_items_by_name(self, share_path, target_name, item_type="both", start_dir="", max_depth=10): + """ + 递归查找指定名称的文件夹和/或文件 + + Args: + share_path: 共享名称 + target_name: 目标名称(支持通配符 * 和 ?) + item_type: 查找类型 - "folder", "file", "both" + start_dir: 起始目录 + max_depth: 最大搜索深度 + + Returns: + list: 找到的完整路径列表 + """ + if not self.connect(): + return [] + + found_paths = [] + start_path = self.build_full_path(share_path, start_dir) + + try: + self._search_recursive( + share_path=share_path, + current_path=start_path, + target_name=target_name, + item_type=item_type, + found_paths=found_paths, + current_depth=0, + max_depth=max_depth + ) + except Exception as e: + print(f"搜索过程中出错: {e}") + + return found_paths + + def _search_recursive(self, share_path, current_path, target_name, item_type, found_paths, current_depth, max_depth): + """递归搜索文件夹和文件""" + if current_depth > max_depth: + return + + try: + for entry in scandir(current_path): + try: + # 检查文件夹 + if entry.is_dir(): + if self._is_match(entry.name, target_name) and item_type in ["both", "folder"]: + found_paths.append(entry.path) + print(f"找到目标文件夹: {entry.path}") + + # 递归搜索子目录 + self._search_recursive( + share_path=share_path, + current_path=entry.path, + target_name=target_name, + item_type=item_type, + found_paths=found_paths, + current_depth=current_depth + 1, + max_depth=max_depth + ) + + # 检查文件 + elif entry.is_file(): + if self._is_match(entry.name, target_name) and item_type in ["both", "file"]: + found_paths.append(entry.path) + print(f"找到目标文件: {entry.path}") + + except Exception as e: + print(f"处理条目 {entry.path} 时出错: {e}") + + except Exception as e: + print(f"搜索目录 {current_path} 时出错: {e}") + + def _is_match(self, name, pattern): + """ + 检查名称是否匹配模式(支持简单通配符) + + Args: + name: 实际名称 + pattern: 匹配模式(支持 * 和 ?) + + Returns: + bool: 是否匹配 + """ + # 如果没有通配符,直接比较 + if '*' not in pattern and '?' not in pattern: + return name.lower() == pattern.lower() + + # 通配符匹配 + import fnmatch + return fnmatch.fnmatch(name.lower(), pattern.lower()) + + def list_directory(self, share_path, dir, recursive=False, max_depth=3): + """列出目录内容""" + if not self.connect(): + return [] + + try: + full_path = f"\\\\{self.ip}\\{share_path}\\{dir}" + print(f"开始遍历: {full_path}") + result = [] + self._walk_directory(full_path, recursive, max_depth, 0, result) + + except Exception as e: + print(f"遍历失败: {e}") + + return result + + def _walk_directory(self, path, recursive, max_depth, current_depth, result): + """递归遍历目录""" + if current_depth > max_depth: + return + + try: + for entry in scandir(path): + try: + file_stat = stat(entry.path) + indent = " " * current_depth + # 创建条目信息字典 + item = { + 'name': entry.name, + 'path': entry.path, + 'depth': current_depth, + 'indent': indent, + 'is_dir': entry.is_dir(), + 'size': file_stat.st_size if not entry.is_dir() else 0, + 'modified_time': datetime.fromtimestamp(file_stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S') + } + + if entry.is_dir(): + # print(f"{indent}文件夹:{entry.name}/") + result.append(item) + if recursive and current_depth < max_depth: + sub_items = self._walk_directory( + entry.path, + recursive, + max_depth, + current_depth + 1 + ) + result.extend(sub_items) + else: + file_size = self._format_size(file_stat.st_size) + mod_time = datetime.fromtimestamp( + file_stat.st_mtime + ).strftime('%Y-%m-%d %H:%M:%S') + # print(f"{indent}文件:{entry.name} [{file_size}] [{mod_time}]") + item['formatted_size'] = file_size + result.append(item) + + except Exception as e: + print(f"{indent} 无法访问: {entry.name} - {e}") + + except Exception as e: + print(f"无法读取目录 {path}: {e}") + + return result + + def _format_size(self, size_bytes): + """格式化文件大小""" + if size_bytes == 0: + return "0 B" + + size_names = ["B", "KB", "MB", "GB", "TB"] + i = 0 + while size_bytes >= 1024 and i < len(size_names) - 1: + size_bytes /= 1024.0 + i += 1 + + return f"{size_bytes:.1f} {size_names[i]}" + + def get_file_info(self, share_path, file_path): + """获取文件详细信息""" + if not self.connect(): + return None + + try: + full_path = f"\\\\{self.ip}\\{share_path}\\{file_path}" + file_stat = stat(full_path) + + return { + 'name': os.path.basename(file_path), + 'path': full_path, + 'size': file_stat.st_size, + 'size_formatted': self._format_size(file_stat.st_size), + 'create_time': datetime.fromtimestamp(file_stat.st_ctime), + 'modify_time': datetime.fromtimestamp(file_stat.st_mtime), + 'access_time': datetime.fromtimestamp(file_stat.st_atime), + 'is_dir': False # 需要额外判断 + } + except Exception as e: + print(f"获取文件信息失败: {e}") + return None + + def display_image(self, image, window_name="Image"): + """ + 显示图像 + + Args: + image: OpenCV图像 + window_name: 窗口名称 + """ + # 创建窗口 + cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) + + # 调整窗口大小适应屏幕 + screen_width = 1920 # 可根据实际屏幕调整 + screen_height = 1080 + + img_height, img_width = image.shape[:2] + + # 计算缩放比例 + scale = min(screen_width / img_width, screen_height / img_height, 1.0) + + if scale < 1.0: + new_width = int(img_width * scale) + new_height = int(img_height * scale) + image = cv2.resize(image, (new_width, new_height)) + print(f"图像已缩放: {img_width}x{img_height} -> {new_width}x{new_height}") + + # 显示图像 + cv2.imshow(window_name, image) + print("图像显示中... 按任意键关闭窗口") + + # 等待按键 + cv2.waitKey(0) + cv2.destroyAllWindows() + print("窗口已关闭") + +# 从传入的路径中提取ip,共享目录,目标访问目录 +def get_conf(zip_url, user_name, pwd) : + zip_url = zip_url.replace('\\\\', '/') + zip_url = zip_url.replace('\\', '/') + if zip_url.startswith("/"): + zip_url = zip_url.replace('/', '', 1) + + parts = zip_url.split('/') + if len(parts) < 2 : + print(f"传入的共享目录格式错误: {zip_url}") + return "", "fail" + + dir = '' + if len(parts) > 2: + new_parts = parts[2:] + dir = '/'.join(new_parts) + + # 配置信息 + config = { + 'ip': parts[0], + 'username': user_name, + 'password': pwd, + 'domain': '', # 工作组留空 + 'share': parts[1], + 'dir': dir + } + + return config + + +def get_scanner(zip_url, user_name, pwd) : + config = get_conf(zip_url, user_name, pwd) + + # 创建扫描器 + scanner = SMBScanner( + ip=config['ip'], + username=config['username'], + password=config['password'], + domain=config['domain'] + ) + return scanner + +# filename -> 桩号 +def get_pile_dict(dir,user_name,pwd) : + config = get_conf(dir, user_name, pwd) + scanner = get_scanner(dir, user_name=user_name, pwd=pwd) + found_paths = scanner.find_files_by_name( + share_path=config['share'], + file_name='fileindex.txt' + ) + print(f"\n找到 {len(found_paths)} 个 'fileindex.txt' 文件:") + for i, path in enumerate(found_paths, 1): + print(f"{i}. {path}") + + lines = scanner.read_txt_by_line(full_path=found_paths[0]) + + pile_dict = {} + for i, line in enumerate(lines, 1): + parts = line.strip().split("->") + if len(parts)>=4: + pile_dict[parts[3]]=parts[1] # filename -> 桩号 + + return pile_dict + +def main(): + # 配置信息 + config = { + 'ip': '192.168.110.114', + 'username': 'administrator', + 'password': 'abc@1234', + 'domain': '', # 工作组留空 + 'share': 'share_File', + 'dir': '西南计算机' + } + + # 创建扫描器 + scanner = SMBScanner( + ip=config['ip'], + username=config['username'], + password=config['password'], + domain=config['domain'] + ) + + # 遍历共享目录 + # scanner.list_directory( + # share_path=config['share'], + # dir=config['dir'], + # recursive=True, # 递归遍历 + # max_depth=9 # 最大深度 + # ) + + # 读取文件 + # full_path = scanner.build_full_path( + # share_path=config['share'], + # file_path= f"{config['dir']}/AA县/报送数据/图像类/CD45500155A/Images/20250508131651/01/20250508-131712-644.jpg" + # ) + # image = scanner.read_img_file(full_path=full_path) + + # scanner.display_image(image) + + # # 写入文件 + # scanner.writeImageToFile( + # share_path=config['share'], + # file_path= f"{config['dir']}/AA县/报送数据/图像类_识别/CD45500155A/Images/20250508131651/01/20250508-131712-644.jpg", + # image=image + # ) + + # # 查找指定文件夹 报送数据 + # found_paths = scanner.find_folders_by_name( + # share_path=config['share'], + # folder_name='报送数据' + # ) + # print(f"\n找到 {len(found_paths)} 个 '报送数据' 文件夹:") + # for i, path in enumerate(found_paths, 1): + # print(f"{i}. {path}") + + + # # 查找指定目录中的所有图片 + # full_path = scanner.build_full_path(share_path=config['share'], file_path='西南计算机\\AA县\\报送数据') + # imgPaths = scanner.get_smb_images(full_path) + # for i, path in enumerate(imgPaths, 1): + # print(f"{i}. {path}") + + # # 读取excel + # full_path = scanner.build_full_path(share_path=config['share'], file_path='西南计算机\\AA县\\24年年报.xlsx') + # df = scanner.read_excel(full_path) + # scanner.process_all_rows(df) + + + # 读取txt + # found_paths = scanner.find_files_by_name( + # share_path=config['share'], + # file_name='fileindex.txt' + # ) + # print(f"\n找到 {len(found_paths)} 个 'fileindex.txt' 文件:") + # for i, path in enumerate(found_paths, 1): + # print(f"{i}. {path}") + + # lines = scanner.read_txt_by_line(full_path=found_paths[0]) + # for i, line in enumerate(lines, 1): + # print(f"{i}. {line}") + + output_dir = "D:/devForBdzlWork/ai-train_platform/predictions" + scanner.upload_directory(output_dir, config['share'], remote_dir="西南计算机/AA县/报送数据_识别") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/util/yolo2pix_new.py b/util/yolo2pix_new.py new file mode 100644 index 0000000..820b5e5 --- /dev/null +++ b/util/yolo2pix_new.py @@ -0,0 +1,243 @@ +import os +import zipfile +import shutil +import cv2 +import numpy as np +from collections import defaultdict +import smb + +# ---------------- 常量 ---------------- +CELL_AREA = 0.0036 # 每格面积 (平方米) +GRID_WIDTH = 108 # 网格像素宽 +GRID_HEIGHT = 102 # 网格像素高 +COVER_RATIO = 0.01 # mask 覆盖比例阈值 + +# ---------------- 路面类别映射 ---------------- +CLASS_MAP_ASPHALT = { + "龟裂":0,"块状裂缝":1,"纵向裂缝":2,"横向裂缝":3,"沉陷":4,"车辙":5,"波浪拥包":6,"坑槽":7,"松散":8,"泛油":9,"修补":10 +} +CLASS_MAP_CEMENT = { + "破碎板":0,"裂缝":1,"板角断裂":2,"错台":3,"拱起":4,"边角剥落":5,"接缝料损坏":6,"坑洞":7,"唧泥":8,"露骨":9,"修补":10 +} +CLASS_MAP_GRAVEL = { + "坑槽":0,"沉陷":1,"车辙":2,"波浪搓板":3 +} + +# ---------------- 工具函数 ---------------- +def num_to_coord(num, cols, cell_w, cell_h): + n = num - 1 + r, c = divmod(n, cols) + x1, y1 = c * cell_w, r * cell_h + x2, y2 = x1 + cell_w, y1 + cell_h + return x1, y1, x2, y2 + +def draw_grid_on_image(image_path, grid_cells, cell_size=(GRID_WIDTH, GRID_HEIGHT), save_path=None): + image = cv2.imread(image_path) + if image is None: return + h, w = image.shape[:2] + cell_w, cell_h = cell_size + cols = w // cell_w + overlay = image.copy() + for cname, nums in grid_cells.items(): + color = (np.random.randint(64,255),np.random.randint(64,255),np.random.randint(64,255)) + for num in nums: + x1,y1,x2,y2 = num_to_coord(num, cols, cell_w, cell_h) + cv2.rectangle(overlay,(x1,y1),(x2,y2),color,-1) + cv2.addWeighted(overlay,0.4,image,0.6,0,image) + for i in range(0, w, cell_w): + cv2.line(image,(i,0),(i,h),(100,100,100),1) + for j in range(0, h, cell_h): + cv2.line(image,(0,j),(w,j),(100,100,100),1) + if save_path: cv2.imwrite(save_path,image) + return image + +def detect_road_type_from_content(label_file): + """根据标签内容判断路面类型""" + try: + with open(label_file,'r',encoding='utf-8') as f: + content = f.read() + except: + return "gravel" + for kw in CLASS_MAP_ASPHALT.keys(): + if kw in content: return "asphalt" + for kw in CLASS_MAP_CEMENT.keys(): + if kw in content: return "cement" + for kw in CLASS_MAP_GRAVEL.keys(): + if kw in content: return "gravel" + return "gravel" + +def yoloseg_to_grid(image_path,label_file,cover_ratio=COVER_RATIO): + """将YOLO-Seg标签转换成格子编号和类别""" + road_type = detect_road_type_from_content(label_file) + if road_type=="asphalt": class_map = CLASS_MAP_ASPHALT + elif road_type=="cement": class_map = CLASS_MAP_CEMENT + else: class_map = CLASS_MAP_GRAVEL + class_names = list(class_map.keys()) + + img = cv2.imread(image_path) + if img is None: return "", {} + h, w = img.shape[:2] + cols = max(1, w//GRID_WIDTH) + rows = max(1, h//GRID_HEIGHT) + + result_lines = [] + all_class_cells = {} + with open(label_file,'r',encoding='utf-8') as f: + for line in f: + parts = line.strip().split() + if len(parts)<5: continue + cls_id = int(parts[0]) + coords = [float(x) for x in parts[1:]] + if len(coords)%2!=0: coords=coords[:-1] + if len(coords)<6: continue + poly = np.array(coords,dtype=np.float32).reshape(-1,2) + poly[:,0]*=w + poly[:,1]*=h + mask = np.zeros((h,w),dtype=np.uint8) + cv2.fillPoly(mask,[poly.astype(np.int32)],255) + covered_cells=[] + for r in range(rows): + for c in range(cols): + x1,y1 = c*GRID_WIDTH, r*GRID_HEIGHT + x2,y2 = min(w,x1+GRID_WIDTH), min(h,y1+GRID_HEIGHT) + region = mask[y1:y2, x1:x2] + if np.count_nonzero(region)/region.size>cover_ratio: + covered_cells.append(r*cols+c+1) + if not covered_cells: continue + cname = class_names[cls_id] if cls_id0 else 1) *100 # 简化为100%或者0 + summary_data.append((pile_no, DR, counts, road_type)) + + # 写桩号问题列表.txt + if summary_data: + road_type = summary_data[0][3] + out_file = os.path.join(dir,"桩号问题列表.txt") + header = generate_header(road_type) + with open(out_file,'w',encoding='utf-8') as f: + f.write(header+'\n') + for pile_no,DR,counts,rt in summary_data: + row = [pile_no,"3.6",f"{DR:.2f}"] + if road_type=="asphalt": + keys = list(CLASS_MAP_ASPHALT.keys()) + elif road_type=="cement": + keys = list(CLASS_MAP_CEMENT.keys()) + else: + keys = list(CLASS_MAP_GRAVEL.keys()) + for k in keys: + row.append(f"{counts.get(k,0):.2f}") + f.write(','.join(row)+'\n') + print(f"✅ 输出完成: {out_file}") + +# ---------------- 主函数 ---------------- +def process_zip(zip_path,pile_map_file,output_dir="output",cell_area=CELL_AREA,grid_width=GRID_WIDTH,grid_height=GRID_HEIGHT): + if not os.path.exists(zip_path): + raise FileNotFoundError(f"{zip_path} 不存在") + os.makedirs(output_dir,exist_ok=True) + # 解压 + with zipfile.ZipFile(zip_path,'r') as zip_ref: + zip_ref.extractall(output_dir) + + # 读取桩号映射 + pile_dict = {} + with open(pile_map_file,'r',encoding='utf-8') as f: + for line in f: + parts = line.strip().split("->") + if len(parts)>=4: + pile_dict[parts[3]]=parts[1] # filename -> 桩号 + + # 遍历图片 + summary_data = [] + for root,_,files in os.walk(output_dir): + for file in files: + if file.lower().endswith((".jpg",".png",".jpeg",".bmp")) : + image_path = os.path.join(root,file) + label_file = os.path.splitext(image_path)[0]+".txt" + if not os.path.exists(label_file): + print(f"⚠️ 找不到标签: {label_file}") + continue + out_txt, class_cells, road_type = yoloseg_to_grid(image_path,label_file) + # 写每张图独立 _grid.txt + grid_txt_path = os.path.splitext(image_path)[0]+"_grid.txt" + with open(grid_txt_path,'w',encoding='utf-8') as f: + f.write(out_txt) + # 生成网格可视化 + draw_grid_on_image(image_path,class_cells,save_path=os.path.splitext(image_path)[0]+"_grid.jpg") + # 统计各类面积 + counts = {k:len(v)*cell_area for k,v in class_cells.items()} + total_area = sum(counts.values()) + # 桩号 + pile_no = pile_dict.get(file,"未知") + # 破损率 DR (%) = total_area / 总面积 + DR = total_area/ (total_area if total_area>0 else 1) *100 # 简化为100%或者0 + summary_data.append((pile_no, DR, counts, road_type)) + + # 写桩号问题列表.txt + if summary_data: + road_type = summary_data[0][3] + out_file = os.path.join(output_dir,"桩号问题列表.txt") + header = generate_header(road_type) + with open(out_file,'w',encoding='utf-8') as f: + f.write(header+'\n') + for pile_no,DR,counts,rt in summary_data: + row = [pile_no,"3.6",f"{DR:.2f}"] + if road_type=="asphalt": + keys = list(CLASS_MAP_ASPHALT.keys()) + elif road_type=="cement": + keys = list(CLASS_MAP_CEMENT.keys()) + else: + keys = list(CLASS_MAP_GRAVEL.keys()) + for k in keys: + row.append(f"{counts.get(k,0):.2f}") + f.write(','.join(row)+'\n') + print(f"✅ 输出完成: {out_file}") + +# ---------------- 示例调用 ---------------- +if __name__=="__main__": + # zip_path = "D:/devForBdzlWork/ai-train_platform/predict/inferenceResult.zip" # 输入 ZIP 文件 + # pile_map_file = "D:/devForBdzlWork/ai-train_platform/predict/pile_map.txt" # 图片名 -> 桩号 + # process_zip(zip_path=zip_path,pile_map_file=pile_map_file,output_dir="output") + + output_dir = "D:/devForBdzlWork/ai-train_platform/predictions/1" + pile_dict = smb.get_pile_dict("192.168.110.114/share_File/西南计算机", "administrator", "abc@1234") + process_dir(pile_dict, output_dir) diff --git a/yolo_api.py b/yolo_api.py index d9ff60c..77e0ab6 100644 --- a/yolo_api.py +++ b/yolo_api.py @@ -178,6 +178,7 @@ from pathlib import Path from download_train import download_train from predict.predict_yolo11seg import predict_images +from predict.predict_yolo11seg import predict_images_share_dir from query_process_status import get_process_status @@ -487,6 +488,52 @@ async def query_train_task(request: Request): +# 接收前端实时流,进行任务推理-共享目录 +@app.post("/ai/project/inference4ShareDir") +async def start_inference_share_dir(request): + try: + # 解析并验证请求数据 + request_json = request.json + task_id = request_json["task_id"] + pt_name = request_json["pt_name"] + zip_url = request_json["zip_url"] + user_name = request_json["user_name"] + pwd = request_json["pwd"] + time_ns = time.time_ns() + # pt_name = f"{time_ns}-{task_id}.pt" + # model_path=r"pt_save\best.pt" + print(f"task_id {task_id}") + + if user_name == "": + user_name = "administrator" + if pwd == "": + pwd = "abc@1234" + + output_dir = f"predictions/{task_id}" + inference_zip_url,message=predict_images_share_dir(pt_name, zip_url, user_name, pwd, output_dir=output_dir, conf_threshold=0.25, save_json=False) + if inference_zip_url: + return response.json({ + "status": "success", + "task_id": task_id, + "inference_zip_url":inference_zip_url, + "message": "predict request successfully" + }) + else: + return response.json({ + "status": "fail", + "task_id": task_id, + "inference_zip_url":inference_zip_url, + "message": message + }) + except ValueError as e: + print(f"Validation error: {str(e)}") + return response.json({"status": "error", "message": str(e)}, status=400) + except Exception as e: + print(f"Unexpected error: {str(e)}") + return response.json({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500) + + + # 接收前端实时流,进行任务推理 @app.post("/ai/project/inference") async def start_inference(request):