132 lines
5.3 KiB
Python
132 lines
5.3 KiB
Python
import yaml
|
||
from dataclasses import dataclass
|
||
from typing import List, Dict, Optional
|
||
|
||
local_config = "yml/local_yolo_config.yml"
|
||
# local_config = r"D:\project\AI-PYTHON\Ai_tottle\yml\local_yolo_config.yml"
|
||
|
||
|
||
@dataclass
|
||
class YmlLocalFunc:
|
||
model_func_id: int # 对应 PostgreSQL 表的 model_func_Id 字段
|
||
filter_cls: List[int] # 过滤的类别索引列表
|
||
func_description: str # 功能描述
|
||
yolo_version: str # YOLO 模型版本
|
||
path: str # 模型路径
|
||
cls_index: List[int] # 类别索引列表
|
||
cls: List[str] # 类别名称(中文)
|
||
cls_en: List[str] # 类别名称(英文)
|
||
cls_description: str # 类别描述
|
||
cls_en_dict: Dict[int, str] = None # 英文类别字典(索引 -> 名称)
|
||
cls_dict: Dict[int, str] = None # 中文类别字典(索引 -> 名称)
|
||
cls_str_dict: Dict[str, str] = None # 中文类别字典(索引 -> 名称)
|
||
# filter_cls_en_dict: Dict[int, str] = None # 根据过滤字段,配置字典
|
||
# filter_cls_dict: Dict[int, str] = None # 根据过滤字段,配置字典
|
||
cls_zn_to_eh_dict: Optional[Dict[str, str]] = None # 新增:英文 -> 中文的映射字典
|
||
filtered_cls_dict: Optional[Dict[int, str]] = None # 过滤后的中文类别字典(索引 -> 名称)
|
||
filtered_cls_en_dict: Optional[Dict[int, str]] = None # 过滤后的英文类别字典(索引 -> 名称)
|
||
|
||
def __post_init__(self):
|
||
"""初始化后自动生成类别字典"""
|
||
if self.cls_en_dict is None:
|
||
self.cls_en_dict = {idx: name for idx, name in zip(self.cls_index, self.cls_en)}
|
||
if self.cls_dict is None:
|
||
self.cls_dict = {idx: name for idx, name in zip(self.cls_index, self.cls)}
|
||
if self.cls_str_dict is None:
|
||
self.cls_str_dict = {idx: name for idx, name in zip(self.cls, self.cls_en)}
|
||
|
||
valid_filter_cls = [idx for idx in self.filter_cls if idx in self.cls_index]
|
||
|
||
# 生成过滤后的类别字典
|
||
self.filtered_cls_dict = {
|
||
idx: self.cls[self.cls_index.index(idx)]
|
||
for idx in valid_filter_cls
|
||
}
|
||
|
||
self.filtered_cls_en_dict = {
|
||
idx: self.cls_en[self.cls_index.index(idx)]
|
||
for idx in valid_filter_cls
|
||
}
|
||
|
||
|
||
|
||
valid_pairs = []
|
||
for idx in self.filter_cls:
|
||
if idx in self.cls_index:
|
||
en_name = self.cls_en[self.cls_index.index(idx)]
|
||
zh_name = self.cls[self.cls_index.index(idx)]
|
||
valid_pairs.append((zh_name,en_name))
|
||
|
||
# 2. 构建字典,处理多个英文名对应同一中文名的情况
|
||
self.cls_zn_to_eh_dict = {}
|
||
for en_name, zh_name in valid_pairs:
|
||
if zh_name not in self.cls_zn_to_eh_dict.values():
|
||
# 如果中文名还未在字典的值中,直接添加
|
||
self.cls_zn_to_eh_dict[en_name] = zh_name
|
||
|
||
|
||
def read_local_func_config() -> List[YmlLocalFunc]:
|
||
"""读取 YAML 文件并解析为 YmlLocalFunc 对象列表"""
|
||
with open(local_config, "r", encoding="utf-8") as file:
|
||
config_data = yaml.safe_load(file)
|
||
|
||
# 解析 local_func 列表
|
||
local_funcs = []
|
||
for func_data in config_data["local_func"]:
|
||
# 直接初始化 YmlLocalFunc(字段名需与 YAML 完全匹配)
|
||
local_func = YmlLocalFunc(**func_data)
|
||
local_funcs.append(local_func)
|
||
|
||
return local_funcs
|
||
|
||
|
||
|
||
def get_local_func_by_id(model_func_id: int) -> Optional[YmlLocalFunc]:
|
||
"""通过 model_func_id 获取对应的 YmlLocalFunc 对象"""
|
||
local_funcs = read_local_func_config()
|
||
for func in local_funcs:
|
||
if func.model_func_id == model_func_id:
|
||
return func
|
||
return None # 如果未找到,返回 None
|
||
|
||
# 基于方法id和输入的category(filter_cls) 做过滤
|
||
def get_local_func_by_id_and_category(model_func_id: int,category:list[int]) -> Optional[YmlLocalFunc]:
|
||
"""通过 model_func_id 获取对应的 YmlLocalFunc 对象"""
|
||
local_funcs = read_local_func_config()
|
||
for func in local_funcs:
|
||
if func.model_func_id == model_func_id:
|
||
func.filter_cls=category
|
||
valid_filter_cls = [idx for idx in func.filter_cls if idx in func.cls_index]
|
||
|
||
# 生成过滤后的类别字典
|
||
func.filtered_cls_dict = {
|
||
idx: func.cls[func.cls_index.index(idx)]
|
||
for idx in valid_filter_cls
|
||
}
|
||
|
||
func.filtered_cls_en_dict = {
|
||
idx: func.cls_en[func.cls_index.index(idx)]
|
||
for idx in valid_filter_cls
|
||
}
|
||
|
||
valid_pairs = []
|
||
for idx in func.filter_cls:
|
||
if idx in func.cls_index:
|
||
en_name = func.cls_en[func.cls_index.index(idx)]
|
||
zh_name = func.cls[func.cls_index.index(idx)]
|
||
valid_pairs.append((zh_name, en_name))
|
||
|
||
# 2. 构建字典,处理多个英文名对应同一中文名的情况
|
||
func.cls_zn_to_eh_dict = {}
|
||
for en_name, zh_name in valid_pairs:
|
||
if zh_name not in func.cls_zn_to_eh_dict.values():
|
||
# 如果中文名还未在字典的值中,直接添加
|
||
func.cls_zn_to_eh_dict[en_name] = zh_name
|
||
return func
|
||
return None # 如果未找到,返回 None
|
||
|
||
|
||
|
||
# result = read_local_func_config()
|
||
# # print()
|