ai_project_v1/middleware/read_yolo_config.py

132 lines
5.3 KiB
Python
Raw Normal View History

import yaml
from dataclasses import dataclass
from typing import List, Dict, Optional
local_config = "yml/local_yolo_config.yml"
# local_config = r"D:\project\AI-PYTHON\Ai_tottle\yml\local_yolo_config.yml"
@dataclass
class YmlLocalFunc:
model_func_id: int # 对应 PostgreSQL 表的 model_func_Id 字段
filter_cls: List[int] # 过滤的类别索引列表
func_description: str # 功能描述
yolo_version: str # YOLO 模型版本
path: str # 模型路径
cls_index: List[int] # 类别索引列表
cls: List[str] # 类别名称(中文)
cls_en: List[str] # 类别名称(英文)
cls_description: str # 类别描述
cls_en_dict: Dict[int, str] = None # 英文类别字典(索引 -> 名称)
cls_dict: Dict[int, str] = None # 中文类别字典(索引 -> 名称)
cls_str_dict: Dict[str, str] = None # 中文类别字典(索引 -> 名称)
# filter_cls_en_dict: Dict[int, str] = None # 根据过滤字段,配置字典
# filter_cls_dict: Dict[int, str] = None # 根据过滤字段,配置字典
cls_zn_to_eh_dict: Optional[Dict[str, str]] = None # 新增:英文 -> 中文的映射字典
filtered_cls_dict: Optional[Dict[int, str]] = None # 过滤后的中文类别字典(索引 -> 名称)
filtered_cls_en_dict: Optional[Dict[int, str]] = None # 过滤后的英文类别字典(索引 -> 名称)
def __post_init__(self):
"""初始化后自动生成类别字典"""
if self.cls_en_dict is None:
self.cls_en_dict = {idx: name for idx, name in zip(self.cls_index, self.cls_en)}
if self.cls_dict is None:
self.cls_dict = {idx: name for idx, name in zip(self.cls_index, self.cls)}
if self.cls_str_dict is None:
self.cls_str_dict = {idx: name for idx, name in zip(self.cls, self.cls_en)}
valid_filter_cls = [idx for idx in self.filter_cls if idx in self.cls_index]
# 生成过滤后的类别字典
self.filtered_cls_dict = {
idx: self.cls[self.cls_index.index(idx)]
for idx in valid_filter_cls
}
self.filtered_cls_en_dict = {
idx: self.cls_en[self.cls_index.index(idx)]
for idx in valid_filter_cls
}
valid_pairs = []
for idx in self.filter_cls:
if idx in self.cls_index:
en_name = self.cls_en[self.cls_index.index(idx)]
zh_name = self.cls[self.cls_index.index(idx)]
valid_pairs.append((zh_name,en_name))
# 2. 构建字典,处理多个英文名对应同一中文名的情况
self.cls_zn_to_eh_dict = {}
for en_name, zh_name in valid_pairs:
if zh_name not in self.cls_zn_to_eh_dict.values():
# 如果中文名还未在字典的值中,直接添加
self.cls_zn_to_eh_dict[en_name] = zh_name
def read_local_func_config() -> List[YmlLocalFunc]:
"""读取 YAML 文件并解析为 YmlLocalFunc 对象列表"""
with open(local_config, "r", encoding="utf-8") as file:
config_data = yaml.safe_load(file)
# 解析 local_func 列表
local_funcs = []
for func_data in config_data["local_func"]:
# 直接初始化 YmlLocalFunc字段名需与 YAML 完全匹配)
local_func = YmlLocalFunc(**func_data)
local_funcs.append(local_func)
return local_funcs
def get_local_func_by_id(model_func_id: int) -> Optional[YmlLocalFunc]:
"""通过 model_func_id 获取对应的 YmlLocalFunc 对象"""
local_funcs = read_local_func_config()
for func in local_funcs:
if func.model_func_id == model_func_id:
return func
return None # 如果未找到,返回 None
# 基于方法id和输入的categoryfilter_cls 做过滤
def get_local_func_by_id_and_category(model_func_id: int,category:list[int]) -> Optional[YmlLocalFunc]:
"""通过 model_func_id 获取对应的 YmlLocalFunc 对象"""
local_funcs = read_local_func_config()
for func in local_funcs:
if func.model_func_id == model_func_id:
func.filter_cls=category
valid_filter_cls = [idx for idx in func.filter_cls if idx in func.cls_index]
# 生成过滤后的类别字典
func.filtered_cls_dict = {
idx: func.cls[func.cls_index.index(idx)]
for idx in valid_filter_cls
}
func.filtered_cls_en_dict = {
idx: func.cls_en[func.cls_index.index(idx)]
for idx in valid_filter_cls
}
valid_pairs = []
for idx in func.filter_cls:
if idx in func.cls_index:
en_name = func.cls_en[func.cls_index.index(idx)]
zh_name = func.cls[func.cls_index.index(idx)]
valid_pairs.append((zh_name, en_name))
# 2. 构建字典,处理多个英文名对应同一中文名的情况
func.cls_zn_to_eh_dict = {}
for en_name, zh_name in valid_pairs:
if zh_name not in func.cls_zn_to_eh_dict.values():
# 如果中文名还未在字典的值中,直接添加
func.cls_zn_to_eh_dict[en_name] = zh_name
return func
return None # 如果未找到,返回 None
# result = read_local_func_config()
# # print()