import yaml from dataclasses import dataclass from typing import List, Dict, Optional local_config = "yml/local_yolo_config.yml" # local_config = r"D:\project\AI-PYTHON\Ai_tottle\yml\local_yolo_config.yml" @dataclass class YmlLocalFunc: model_func_id: int # 对应 PostgreSQL 表的 model_func_Id 字段 filter_cls: List[int] # 过滤的类别索引列表 func_description: str # 功能描述 yolo_version: str # YOLO 模型版本 path: str # 模型路径 cls_index: List[int] # 类别索引列表 cls: List[str] # 类别名称(中文) cls_en: List[str] # 类别名称(英文) cls_description: str # 类别描述 cls_en_dict: Dict[int, str] = None # 英文类别字典(索引 -> 名称) cls_dict: Dict[int, str] = None # 中文类别字典(索引 -> 名称) cls_str_dict: Dict[str, str] = None # 中文类别字典(索引 -> 名称) # filter_cls_en_dict: Dict[int, str] = None # 根据过滤字段,配置字典 # filter_cls_dict: Dict[int, str] = None # 根据过滤字段,配置字典 cls_zn_to_eh_dict: Optional[Dict[str, str]] = None # 新增:英文 -> 中文的映射字典 filtered_cls_dict: Optional[Dict[int, str]] = None # 过滤后的中文类别字典(索引 -> 名称) filtered_cls_en_dict: Optional[Dict[int, str]] = None # 过滤后的英文类别字典(索引 -> 名称) def __post_init__(self): """初始化后自动生成类别字典""" if self.cls_en_dict is None: self.cls_en_dict = {idx: name for idx, name in zip(self.cls_index, self.cls_en)} if self.cls_dict is None: self.cls_dict = {idx: name for idx, name in zip(self.cls_index, self.cls)} if self.cls_str_dict is None: self.cls_str_dict = {idx: name for idx, name in zip(self.cls, self.cls_en)} valid_filter_cls = [idx for idx in self.filter_cls if idx in self.cls_index] # 生成过滤后的类别字典 self.filtered_cls_dict = { idx: self.cls[self.cls_index.index(idx)] for idx in valid_filter_cls } self.filtered_cls_en_dict = { idx: self.cls_en[self.cls_index.index(idx)] for idx in valid_filter_cls } valid_pairs = [] for idx in self.filter_cls: if idx in self.cls_index: en_name = self.cls_en[self.cls_index.index(idx)] zh_name = self.cls[self.cls_index.index(idx)] valid_pairs.append((zh_name,en_name)) # 2. 构建字典,处理多个英文名对应同一中文名的情况 self.cls_zn_to_eh_dict = {} for en_name, zh_name in valid_pairs: if zh_name not in self.cls_zn_to_eh_dict.values(): # 如果中文名还未在字典的值中,直接添加 self.cls_zn_to_eh_dict[en_name] = zh_name def read_local_func_config() -> List[YmlLocalFunc]: """读取 YAML 文件并解析为 YmlLocalFunc 对象列表""" with open(local_config, "r", encoding="utf-8") as file: config_data = yaml.safe_load(file) # 解析 local_func 列表 local_funcs = [] for func_data in config_data["local_func"]: # 直接初始化 YmlLocalFunc(字段名需与 YAML 完全匹配) local_func = YmlLocalFunc(**func_data) local_funcs.append(local_func) return local_funcs def get_local_func_by_id(model_func_id: int) -> Optional[YmlLocalFunc]: """通过 model_func_id 获取对应的 YmlLocalFunc 对象""" local_funcs = read_local_func_config() for func in local_funcs: if func.model_func_id == model_func_id: return func return None # 如果未找到,返回 None # 基于方法id和输入的category(filter_cls) 做过滤 def get_local_func_by_id_and_category(model_func_id: int,category:list[int]) -> Optional[YmlLocalFunc]: """通过 model_func_id 获取对应的 YmlLocalFunc 对象""" local_funcs = read_local_func_config() for func in local_funcs: if func.model_func_id == model_func_id: func.filter_cls=category valid_filter_cls = [idx for idx in func.filter_cls if idx in func.cls_index] # 生成过滤后的类别字典 func.filtered_cls_dict = { idx: func.cls[func.cls_index.index(idx)] for idx in valid_filter_cls } func.filtered_cls_en_dict = { idx: func.cls_en[func.cls_index.index(idx)] for idx in valid_filter_cls } valid_pairs = [] for idx in func.filter_cls: if idx in func.cls_index: en_name = func.cls_en[func.cls_index.index(idx)] zh_name = func.cls[func.cls_index.index(idx)] valid_pairs.append((zh_name, en_name)) # 2. 构建字典,处理多个英文名对应同一中文名的情况 func.cls_zn_to_eh_dict = {} for en_name, zh_name in valid_pairs: if zh_name not in func.cls_zn_to_eh_dict.values(): # 如果中文名还未在字典的值中,直接添加 func.cls_zn_to_eh_dict[en_name] = zh_name return func return None # 如果未找到,返回 None # result = read_local_func_config() # # print()