diff --git a/middleware/recognition_task.py b/middleware/recognition_task.py
new file mode 100644
index 0000000..afce276
--- /dev/null
+++ b/middleware/recognition_task.py
@@ -0,0 +1,109 @@
+import psycopg2
+from psycopg2.extras import RealDictCursor
+import json
+from typing import Dict, List, Union, Optional
+from dataclasses import dataclass, asdict
+from datetime import datetime
+import re
+
+
+@dataclass
+class RecognitionTask:
+    """A recognition task."""
+
+    task_name: Optional[str] = None
+    model_id: Optional[int] = None
+    model_name: Optional[str] = None
+    model_version_id: Optional[str] = None
+    created_by: Optional[str] = None
+    status: Optional[int] = None
+
+    id: Optional[int] = None
+    exec_msg: Optional[str] = None
+    created_at: Optional[datetime] = None
+    result_url: Optional[str] = None
+    source_url: Optional[str] = None
+    task_id: Optional[str] = None
+    resource_record_id: Optional[int] = None
+
+
+class RecognitionTaskDAO:
+    def __init__(self, db_params: Dict[str, str]):
+        """
+        Initialize the database connection.
+
+        Args:
+            db_params: connection parameters, containing:
+                - dbname: database name
+                - user: user name
+                - password: password
+                - host: host address
+                - port: port number
+        """
+        self.db_params = db_params
+
+    def update_recognition_task(self, model: RecognitionTask) -> bool:
+        """
+        Update an existing recognition task.
+
+        Args:
+            model: the recognition task to update
+
+        Returns:
+            True if the update succeeded, False otherwise.
+        """
+        if not isinstance(model, RecognitionTask):
+            raise ValueError("Invalid task type: expected RecognitionTask")
+
+        data = self._to_db_format(model)
+
+        query = """
+            UPDATE bz_recognition_tasks SET
+                status = %(status)s
+            WHERE task_id = %(task_id)s
+        """
+
+        try:
+            with psycopg2.connect(**self.db_params) as conn:
+                with conn.cursor() as cur:
+                    cur.execute(query, data)
+                    conn.commit()
+                    return True
+        except psycopg2.Error as e:
+            print(f"Database update error: {e}")
+            return False
+
+    def _to_db_format(self, task: RecognitionTask) -> Dict:
+        """Convert a RecognitionTask object to database format."""
+        return {
+            "id": task.id,
+            "task_name": task.task_name,
+            "model_id": task.model_id,
+            "model_name": task.model_name,
+            "model_version_id": task.model_version_id,
+            "status": task.status,
+            "exec_msg": task.exec_msg,
+            "created_at": task.created_at,
+            "created_by": task.created_by,
+            "result_url": task.result_url,
+            "source_url": task.source_url,
+            "task_id": task.task_id,
+            "resource_record_id": task.resource_record_id
+        }
+
+    def _from_db_format(self, db_data: Dict) -> RecognitionTask:
+        """Convert database format to a RecognitionTask object."""
+        return RecognitionTask(
+            id=db_data.get("id"),
+            task_name=db_data.get("task_name", ""),
+            model_id=db_data.get("model_id", 0),
+            model_name=db_data.get("model_name", ""),
+            model_version_id=db_data.get("model_version_id", ""),
+            status=db_data.get("status", 0),
+            exec_msg=db_data.get("exec_msg"),
+            created_at=db_data.get("created_at"),
+            created_by=db_data.get("created_by", ""),
+            result_url=db_data.get("result_url"),
+            source_url=db_data.get("source_url"),
+            task_id=db_data.get("task_id"),
+            resource_record_id=db_data.get("resource_record_id")
+        )
\ No newline at end of file
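For reviewers, a minimal usage sketch of the new DAO (not part of the patch; all connection values are placeholders). Only `status` and `task_id` are consumed by the UPDATE statement, so a sparse `RecognitionTask` suffices:

```python
from middleware.recognition_task import RecognitionTask, RecognitionTaskDAO

# Placeholder connection parameters -- adjust to the real deployment.
db_params = {
    "dbname": "ai_platform",
    "user": "postgres",
    "password": "secret",
    "host": "127.0.0.1",
    "port": "5432",
}

dao = RecognitionTaskDAO(db_params)
# Mark a task finished; status=2 mirrors the value used in predict_yolo11seg.py.
ok = dao.update_recognition_task(RecognitionTask(task_id="77", status=2))
print("updated" if ok else "update failed")
```

Since psycopg2's named-parameter style ignores unused keys in the mapping, passing the full `_to_db_format` dict to the status-only UPDATE is safe.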
diff --git a/predict/predict_yolo11seg.py b/predict/predict_yolo11seg.py
index f7ef9bd..9cdeb50 100644
--- a/predict/predict_yolo11seg.py
+++ b/predict/predict_yolo11seg.py
@@ -4,8 +4,9 @@ import zipfile
 from os.path import exists
 
 import torch
-
+import gc
 import os
+import psutil
 import cv2
 import numpy as np
 import time
@@ -106,6 +107,7 @@ class YOLOSegmentationInference:
         self.device = device
         self.model = None
         self.class_names = []
+        self.clean_memory_per_count = 500
 
         # Color map (used to distinguish classes)
         self.colors = [
@@ -537,7 +539,7 @@ class YOLOSegmentationInference:
     def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None,
                                        conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                                        save_mask: bool = False, save_label: bool = False, show: bool = True,
-                                       result_save: [] = None) -> InferenceResult:
+                                       result_save: Optional[list] = None) -> None:
         """
         Process a single image.
 
@@ -571,7 +573,7 @@ class YOLOSegmentationInference:
         # if show:
         #     self.show_results(result)
 
-        return result
+        del result
 
     def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None,
                                 conf_threshold: float = 0.25, iou_threshold: float = 0.5,
@@ -640,8 +642,7 @@ class YOLOSegmentationInference:
    def process_image_directory_share_dir_circle(self, task_id, current_time, input_dir_list, user_name, pwd,
                                                  output_dir: Optional[str] = None, conf_threshold: float = 0.25,
                                                  iou_threshold: float = 0.5, save_mask: bool = False, save_label: bool = False, show: bool = False,
-                                                 result_save: [] = None) -> List[
-        InferenceResult]:
+                                                 result_save: Optional[list] = None) -> None:
         for input_dir in input_dir_list :
             self.process_image_directory_share_dir(task_id,current_time,input_dir,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save)
             del_file_shutil(output_dir)
@@ -649,8 +650,7 @@ class YOLOSegmentationInference:
     def process_image_directory_share_dir(self, task_id, current_time, input_dir, user_name, pwd,
                                           output_dir: Optional[str] = None,
                                           conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                                           save_mask: bool = False, save_label: bool = False, show: bool = False,
-                                          result_save: [] = None) -> List[
-        InferenceResult]:
+                                          result_save: Optional[list] = None) -> None:
         """
         Process all images in a directory.
@@ -666,7 +666,7 @@ class YOLOSegmentationInference:
 
         Returns:
             List of inference results
         """
-        results = []
+
         tmp_output_dir = output_dir + "\\" + datetime.now().strftime("%Y%m%d%H%M%S")
         print(f"Processing shared directory: {input_dir} - {tmp_output_dir}")
@@ -676,19 +676,19 @@ class YOLOSegmentationInference:
             scanner = get_scanner(zip_url=input_dir, user_name=user_name, pwd=pwd);
             if not scanner.directory_exists(input_dir):
                 print(f"Error: {input_dir} is not a valid directory")
-                return results
+                return
 
             # Collect all image files
             image_files = scanner.get_smb_images(input_dir)
 
             if not image_files:
                 print(f"No image files found in directory {input_dir}")
-                return results
+                return
 
             print(f"Found {len(image_files)} image files")
 
             # Process each image
-            for image_path in image_files:
+            for idx, image_path in enumerate(image_files):
                 result = self.process_single_image_share_dir(
                     image_path=image_path,
                     user_name=user_name,
@@ -701,7 +700,21 @@ class YOLOSegmentationInference:
                     show=show,
                     result_save=result_save
                 )
-                results.append(result)
+
+                if idx % self.clean_memory_per_count == 0 :
+                    print(f"idx = {idx} --- cleaning memory every {self.clean_memory_per_count} images")
+                    process = psutil.Process(os.getpid())
+                    memory_info = process.memory_info()
+                    mem_usage = memory_info.rss / 1024 / 1024  # MB
+                    print(f"Memory usage before cleanup: {mem_usage:.2f} MB")
+
+                    if self.device == 'cuda':
+                        torch.cuda.empty_cache()
+                    gc.collect()
+
+                    memory_info = process.memory_info()
+                    mem_usage = memory_info.rss / 1024 / 1024  # MB
+                    print(f"Memory usage after cleanup: {mem_usage:.2f} MB")
 
             # Push recognition results to the shared directory
             tmpConfig = get_conf(input_dir, user_name, pwd)
@@ -728,14 +741,11 @@ class YOLOSegmentationInference:
                 status=2
             )
             dao.update_recognition_task(recognition_task)
-            return results
 
         except PermissionError:
             print(f"Permission error: cannot access directory {input_dir}")
-            return results
         except Exception as e:
             print(f"Failed to process directory: {e}")
-            return results
 
 def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0.25, save_json=False):
     zip_save_path = "dataset/zip_file"
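The periodic cleanup added above is worth reading in isolation. A standalone sketch of the same idea (names are mine, not the patch's): report process RSS via psutil, release cached CUDA allocator blocks, then force a GC pass:

```python
import gc
import os

import psutil
import torch


def report_and_clean(device: str = "cpu") -> float:
    """Print current RSS, free the CUDA cache if applicable, run gc; return MB after."""
    process = psutil.Process(os.getpid())
    before = process.memory_info().rss / 1024 / 1024  # resident set size, MB
    print(f"Memory usage before cleanup: {before:.2f} MB")
    if device == "cuda":
        # Returns cached allocator blocks to the driver; does not free live tensors.
        torch.cuda.empty_cache()
    gc.collect()
    after = process.memory_info().rss / 1024 / 1024
    print(f"Memory usage after cleanup: {after:.2f} MB")
    return after
```

Note that `gc.collect()` and `empty_cache()` only reclaim unreferenced objects; dropping the per-image `result` (as the patch does with `del result` and by removing the accumulating `results` list) is what actually bounds memory growth here.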
diff --git a/util/smb_tool.py b/util/smb_tool.py
index 49f6a52..6a37bc4 100644
--- a/util/smb_tool.py
+++ b/util/smb_tool.py
@@ -74,15 +74,51 @@ class SMBScanner:
             print(f"Failed to read Excel: {e}")
             return None
 
-    def process_all_rows(self, df):
+    def process_all_rows(self, df, columns=None):
         """
-        Process all rows.
+        Process all rows, optionally restricted to specific columns.
+
+        Parameters:
+        -----------
+        df : pandas.DataFrame
+            The data frame to process
+        columns : list or None, optional
+            Column names to process; None means all columns
+
+        Returns:
+        --------
+        results : list
+            One entry per row, containing row metadata and the processed data
         """
         if df is None or df.empty:
             print("No data to process")
-            return
+            return []
 
-        print("Start processing rows:")
+        # Validate the column parameter
+        if columns is None:
+            # Use all columns
+            selected_columns = df.columns.tolist()
+        else:
+            # Check that the requested columns exist
+            valid_columns = []
+            invalid_columns = []
+
+            for col in columns:
+                if col in df.columns:
+                    valid_columns.append(col)
+                else:
+                    invalid_columns.append(col)
+
+            if invalid_columns:
+                print(f"Warning: these columns do not exist and will be ignored: {invalid_columns}")
+
+            if not valid_columns:
+                print("Error: no valid columns to process")
+                return []
+
+            selected_columns = valid_columns
+
+        print(f"Start processing data, selected columns: {selected_columns}")
         print("=" * 60)
 
         results = []
@@ -91,17 +127,18 @@ class SMBScanner:
             # print(f"\nProcessing row {row_number}:")
             # print("-" * 40)
 
-            # Display row data
-            for col_name in df.columns:
+            # Keep only the selected columns
+            selected_data = {}
+            for col_name in selected_columns:
                 value = row[col_name]
+                selected_data[col_name] = value
                 # print(f"  {col_name}: {value}")
 
             # Processing logic (adapt to actual needs)
             processed_row = {
                 'row_number': row_number,
                 'original_index': index,
-                'data': row.to_dict(),
-                'summary': f"processed {len(df.columns)} fields"
+                'data': selected_data  # only the selected columns
             }
 
             results.append(processed_row)
@@ -111,7 +148,7 @@ class SMBScanner:
                 print(f"\n progress: {row_number}/{len(df)} ({row_number/len(df)*100:.1f}%)")
 
         # print("\n" + "=" * 60)
-        print(f"Done! Processed {len(results)} rows")
+        print(f"Done! Processed {len(results)} rows, {len(selected_columns)} fields")
 
         return results
 
@@ -767,6 +804,9 @@ def get_scanner(zip_url, user_name, pwd) :
 def get_road_dict(dir,user_name,pwd) :
     config = get_conf(dir, user_name, pwd)
     scanner = get_scanner(dir, user_name=user_name, pwd=pwd)
+
+    road_dict_for_width = get_road_dict_for_width(config, scanner)
+
     found_paths = scanner.find_files_by_name(
         share_path=config['share'],
         file_name='每公里指标明细表*.xls',
@@ -780,7 +820,7 @@ def get_road_dict(dir,user_name,pwd) :
     road_dict = {}
     if len(found_paths) > 0 :
         df = scanner.read_excel(found_paths[0])
-        rows = scanner.process_all_rows(df)
+        rows = scanner.process_all_rows(df, columns=['线路编码','区划代码','识别宽度(米)','方向(上行/下行)','技术等级','路面类型(沥青/水泥/砂石)','起桩号(米)','止桩号(米)'])
         for i, row in enumerate(rows, 1):
             data = row['data']
             if pd.notna(data['线路编码']) :
@@ -788,6 +828,9 @@ def get_road_dict(dir,user_name,pwd) :
                 if data['方向(上行/下行)'] == '下行' :
                     up_or_down = 'B'
                 key = f"{data['线路编码']}{str(int(data['区划代码']))}{up_or_down}"
+                width = road_dict_for_width.get(key)
+                if width :
+                    data['识别宽度(米)'] = width
                 if road_dict.get(key) :
                     road_dict[key].append(row)
                 else :
@@ -795,6 +838,29 @@ def get_road_dict(dir,user_name,pwd) :
     return road_dict
 
+def get_road_dict_for_width(config, scanner):
+    found_paths_for_width = scanner.find_files_by_name(
+        share_path=config['share'],
+        file_name='nl_yy_glzb*.xls',
+        start_dir=config['dir'],
+        max_depth=4
+    )
+    road_dict_for_width = {}
+    if len(found_paths_for_width) > 0 :
+        df = scanner.read_excel(found_paths_for_width[0])
+        rows = scanner.process_all_rows(df, columns=['线路编码','区划代码','方向(上行/下行)','识别宽度(米)'])
+        for i, row in enumerate(rows, 1):
+            data = row['data']
+            if pd.notna(data['线路编码']) :
+                up_or_down = 'A'
+                if data['方向(上行/下行)'] == '下行' :
+                    up_or_down = 'B'
+                key = f"{data['线路编码']}{str(int(data['区划代码']))}{up_or_down}"
+                width = data.get('识别宽度(米)')
+                if width :
+                    road_dict_for_width[key] = width
+    return road_dict_for_width
+
 # filename -> stake number
 def get_pile_dict(dir,user_name,pwd) :
     config = get_conf(dir, user_name, pwd)
@@ -1030,8 +1096,42 @@ def main():
     #     flag = scanner.directory_exists(path)
     #     print(f"flag={flag}")
 
-    str = '\\\\192.168.110.114\\share_File\\西南计算机/北碚报送数据/3.25/图像类\\C071500109B\\Images_识别/77/20251120172316/77.zip'
-    print(f"standardized_path(str)={standardized_path(str)}")
+    # str = '\\\\192.168.110.114\\share_File\\西南计算机/北碚报送数据/3.25/图像类\\C071500109B\\Images_识别/77/20251120172316/77.zip'
+    # print(f"standardized_path(str)={standardized_path(str)}")
+
+    road_dict_for_width = {}
+    df = pd.read_excel('D:/devForBdzlWork/ai-train_platform/predictions/C005500155A/nl_yy_glzb (2)_识别宽度数据读取.xls')
+    rows = scanner.process_all_rows(df, columns=['线路编码','区划代码','方向(上行/下行)','识别宽度(米)'])
+    for i, row in enumerate(rows, 1):
+        data = row['data']
+        if pd.notna(data['线路编码']) :
+            up_or_down = 'A'
+            if data['方向(上行/下行)'] == '下行' :
+                up_or_down = 'B'
+            key = f"{data['线路编码']}{str(int(data['区划代码']))}{up_or_down}"
+            width = data.get('识别宽度(米)')
+            if width :
+                road_dict_for_width[key] = width
+    print('read excel')
+
+    road_dict = {}
+    df = pd.read_excel('D:/devForBdzlWork/ai-train_platform/predictions/C005500155A/每公里指标明细表-农村公路技术状况评定结果表-500155.xls')
+    rows = scanner.process_all_rows(df, columns=['线路编码','区划代码','识别宽度(米)','方向(上行/下行)','技术等级','路面类型(沥青/水泥/砂石)','起桩号(米)','止桩号(米)'])
+    for i, row in enumerate(rows, 1):
+        data = row['data']
+        if pd.notna(data['线路编码']) :
+            up_or_down = 'A'
+            if data['方向(上行/下行)'] == '下行' :
+                up_or_down = 'B'
+            key = f"{data['线路编码']}{str(int(data['区划代码']))}{up_or_down}"
+            width = road_dict_for_width.get(key)
+            if width :
+                data['识别宽度(米)'] = width
+            if road_dict.get(key) :
+                road_dict[key].append(row)
+            else :
+                road_dict[key] = [row]
+    print('read excel')
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
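The composite key (线路编码 + 区划代码 + direction suffix, 'A' for 上行/up and 'B' for 下行/down) is now built in three places: get_road_dict, get_road_dict_for_width, and main. A hypothetical helper — not in the patch — makes the convention explicit:

```python
from typing import Optional

import pandas as pd


def build_road_key(data: dict) -> Optional[str]:
    """Compose the lookup key shared by get_road_dict / get_road_dict_for_width."""
    if pd.isna(data['线路编码']):
        return None
    up_or_down = 'B' if data['方向(上行/下行)'] == '下行' else 'A'
    # 区划代码 comes out of Excel as a float, hence the int() round-trip.
    return f"{data['线路编码']}{int(data['区划代码'])}{up_or_down}"


# e.g. {'线路编码': 'C005', '区划代码': 500155.0, '方向(上行/下行)': '上行'}
# -> 'C005500155A', matching directory names such as predictions/C005500155A.
print(build_road_key({'线路编码': 'C005', '区划代码': 500155.0, '方向(上行/下行)': '上行'}))
```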
diff --git a/util/yolo2pix_new.py b/util/yolo2pix_new.py
index 92c9362..ec9ebe4 100644
--- a/util/yolo2pix_new.py
+++ b/util/yolo2pix_new.py
@@ -536,7 +536,7 @@ def format_number_to_k_code(number):
     return f"K{integer_part:04d}+{decimal_part}"
 
 # ---------------- Main routine - shared directory ----------------
-def process_dir(road_dict,pile_dict,dir="output",cell_area=CELL_AREA,grid_width=GRID_WIDTH,grid_height=GRID_HEIGHT):
+def process_dir(road_dict,pile_dict,dir="output",cell_area=CELL_AREA,cell_width=CELL_WIDTH,cell_height=CELL_HEIGHT,grid_width=GRID_WIDTH,grid_height=GRID_HEIGHT):
     os.makedirs(dir,exist_ok=True)
     # Unzip
     # Read the stake-number mapping
@@ -556,7 +556,7 @@ def process_dir(road_dict,pile_dict,dir="output",cell_area=CELL_AREA,grid_width=
         with open(grid_txt_path,'w',encoding='utf-8') as f:
             f.write(out_txt)
         # Generate the grid visualization
-        draw_grid_on_image(image_path,class_cells,cell_size=(GRID_WIDTH, GRID_HEIGHT),save_path=os.path.splitext(image_path)[0]+"_grid.jpg")
+        # draw_grid_on_image(image_path,class_cells,cell_size=(GRID_WIDTH, GRID_HEIGHT),save_path=os.path.splitext(image_path)[0]+"_grid.jpg")
         # Aggregate per-class areas
         counts = {k:[len(v[0])*cell_area, v[1][0], v[1][1]] for k,v in class_cells.items()}
         # total_area = sum(counts.values())
@@ -586,7 +586,7 @@
         process_damage_txt(road_dict, pile_dict, dir, summary_data, image_path, current_time)
 
         # Damage detail list (病害明细列表.xlsx)
-        img_file_path = process_damage_detail_excel(road_dict, pile_dict, dir, cell_area, summary_data, image_path)
+        img_file_path = process_damage_detail_excel(road_dict, pile_dict, dir, cell_area, cell_width, cell_height, summary_data, image_path)
 
         # Composite detail sheet (综合明细表.xlsx)
         process_damage_composite_excel(road_dict, pile_dict, summary_data, image_path, current_time, img_file_path)
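The remaining hunks below thread `cell_width`/`cell_height` through to `process_damage_detail_excel`. The point of the change is a unit fix: per-axis extents must scale by a cell edge (metres), not by the cell area (m²). A small illustration with assumed cell dimensions — the real CELL_WIDTH / CELL_HEIGHT come from the grid configuration in yolo2pix_new.py:

```python
CELL_WIDTH = 0.1   # metres per cell along the length axis (assumption)
CELL_HEIGHT = 0.1  # metres per cell along the width axis (assumption)
CELL_AREA = CELL_WIDTH * CELL_HEIGHT  # 0.01 m^2 per cell

# Hypothetical damage record: extent in cells along each axis, plus its area.
cells_long, cells_wide, area_m2 = 12, 3, 0.30

length_m = cells_long * CELL_WIDTH   # 1.2 m  (old code: cells_long * CELL_AREA = 0.12 "m")
width_m = cells_wide * CELL_HEIGHT   # 0.3 m  (old code: cells_wide * CELL_AREA = 0.03 "m")
print(length_m, width_m, area_m2)
```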
@@ -650,7 +650,7 @@ def process_damage_detail_txt(road_dict, pile_dict, dir, summary_data, image_pat
     print(f"Output complete: {out_file}")
 
-def process_damage_detail_excel(road_dict, pile_dict, dir, cell_area, summary_data, image_path):
+def process_damage_detail_excel(road_dict, pile_dict, dir, cell_area, cell_width, cell_height, summary_data, image_path):
     print("Writing: 病害明细列表.xlsx")
     os.makedirs(f"{dir}/excel", exist_ok=True)
     headers = ['序号','路线编码','方向','桩号','路面类型','病害名称','程度','长度(m)','宽度(m)','面积(㎡)','横向位置','备注']
@@ -666,7 +666,7 @@ def process_damage_detail_excel(road_dict, pile_dict, dir, cell_area, summary_da
     for data in summary_data:
         damage_data = data[2]
         for attr_name, attr_value in damage_data.items():
-            excel_data = [excel_index, road_code, up_or_down, f"K000{data[0]}", ROAD_TYPE_EN_TO_CN.get(road_type), attr_name, '', attr_value[1]*cell_area, attr_value[2]*cell_area, attr_value[0], '', '']
+            excel_data = [excel_index, road_code, up_or_down, f"K000{data[0]}", ROAD_TYPE_EN_TO_CN.get(road_type), attr_name, '', attr_value[1]*cell_width, attr_value[2]*cell_height, attr_value[0], '', '']
             data_list.append(excel_data)
 
     all_data = [headers] + data_list
@@ -961,10 +961,14 @@ if __name__=="__main__":
     # calc_cell_area, calc_grid_width, calc_grid_height = calc_grid_param(2048, 4096, 3.6, 2)
     # print(f"calc_cell_area={calc_cell_area}, calc_grid_width={calc_grid_width}, calc_grid_height={calc_grid_height}")
 
-    output_dir = "D:/devForBdzlWork/ai-train_platform/predictions/CV78500155B"
+    output_dir = "D:/devForBdzlWork/ai-train_platform/predictions/C234500155A"
     pile_dict = get_pile_dict(output_dir)
     road_dict = get_road_dict(output_dir)
     process_dir(road_dict, pile_dict, output_dir)
+    # arr = [44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 68, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]
+    # for a in arr :
+    #     print(f"a = {a} x = {a % 37} y = {int(a / 37)}")
+
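The commented-out loop at the end decodes flat cell indices into grid coordinates, apparently for a grid 37 cells wide (an assumption tied to that particular debug data). As a sketch, with a hypothetical helper:

```python
def cell_to_xy(index: int, cols: int = 37) -> tuple:
    """Map a row-major flat cell index to (x, y) grid coordinates."""
    return index % cols, index // cols


assert cell_to_xy(44) == (7, 1)  # 44 % 37 = 7, 44 // 37 = 1
assert cell_to_xy(81) == (7, 2)  # same column, next row -- a contiguous vertical run
```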