ai_project_v1/middleware/read_yolo_config.py

132 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import yaml
from dataclasses import dataclass
from typing import List, Dict, Optional
local_config = "yml/local_yolo_config.yml"
# local_config = r"D:\project\AI-PYTHON\Ai_tottle\yml\local_yolo_config.yml"
@dataclass
class YmlLocalFunc:
model_func_id: int # 对应 PostgreSQL 表的 model_func_Id 字段
filter_cls: List[int] # 过滤的类别索引列表
func_description: str # 功能描述
yolo_version: str # YOLO 模型版本
path: str # 模型路径
cls_index: List[int] # 类别索引列表
cls: List[str] # 类别名称(中文)
cls_en: List[str] # 类别名称(英文)
cls_description: str # 类别描述
cls_en_dict: Dict[int, str] = None # 英文类别字典(索引 -> 名称)
cls_dict: Dict[int, str] = None # 中文类别字典(索引 -> 名称)
cls_str_dict: Dict[str, str] = None # 中文类别字典(索引 -> 名称)
# filter_cls_en_dict: Dict[int, str] = None # 根据过滤字段,配置字典
# filter_cls_dict: Dict[int, str] = None # 根据过滤字段,配置字典
cls_zn_to_eh_dict: Optional[Dict[str, str]] = None # 新增:英文 -> 中文的映射字典
filtered_cls_dict: Optional[Dict[int, str]] = None # 过滤后的中文类别字典(索引 -> 名称)
filtered_cls_en_dict: Optional[Dict[int, str]] = None # 过滤后的英文类别字典(索引 -> 名称)
def __post_init__(self):
"""初始化后自动生成类别字典"""
if self.cls_en_dict is None:
self.cls_en_dict = {idx: name for idx, name in zip(self.cls_index, self.cls_en)}
if self.cls_dict is None:
self.cls_dict = {idx: name for idx, name in zip(self.cls_index, self.cls)}
if self.cls_str_dict is None:
self.cls_str_dict = {idx: name for idx, name in zip(self.cls, self.cls_en)}
valid_filter_cls = [idx for idx in self.filter_cls if idx in self.cls_index]
# 生成过滤后的类别字典
self.filtered_cls_dict = {
idx: self.cls[self.cls_index.index(idx)]
for idx in valid_filter_cls
}
self.filtered_cls_en_dict = {
idx: self.cls_en[self.cls_index.index(idx)]
for idx in valid_filter_cls
}
valid_pairs = []
for idx in self.filter_cls:
if idx in self.cls_index:
en_name = self.cls_en[self.cls_index.index(idx)]
zh_name = self.cls[self.cls_index.index(idx)]
valid_pairs.append((zh_name,en_name))
# 2. 构建字典,处理多个英文名对应同一中文名的情况
self.cls_zn_to_eh_dict = {}
for en_name, zh_name in valid_pairs:
if zh_name not in self.cls_zn_to_eh_dict.values():
# 如果中文名还未在字典的值中,直接添加
self.cls_zn_to_eh_dict[en_name] = zh_name
def read_local_func_config() -> List[YmlLocalFunc]:
"""读取 YAML 文件并解析为 YmlLocalFunc 对象列表"""
with open(local_config, "r", encoding="utf-8") as file:
config_data = yaml.safe_load(file)
# 解析 local_func 列表
local_funcs = []
for func_data in config_data["local_func"]:
# 直接初始化 YmlLocalFunc字段名需与 YAML 完全匹配)
local_func = YmlLocalFunc(**func_data)
local_funcs.append(local_func)
return local_funcs
def get_local_func_by_id(model_func_id: int) -> Optional[YmlLocalFunc]:
"""通过 model_func_id 获取对应的 YmlLocalFunc 对象"""
local_funcs = read_local_func_config()
for func in local_funcs:
if func.model_func_id == model_func_id:
return func
return None # 如果未找到,返回 None
# 基于方法id和输入的categoryfilter_cls 做过滤
def get_local_func_by_id_and_category(model_func_id: int,category:list[int]) -> Optional[YmlLocalFunc]:
"""通过 model_func_id 获取对应的 YmlLocalFunc 对象"""
local_funcs = read_local_func_config()
for func in local_funcs:
if func.model_func_id == model_func_id:
func.filter_cls=category
valid_filter_cls = [idx for idx in func.filter_cls if idx in func.cls_index]
# 生成过滤后的类别字典
func.filtered_cls_dict = {
idx: func.cls[func.cls_index.index(idx)]
for idx in valid_filter_cls
}
func.filtered_cls_en_dict = {
idx: func.cls_en[func.cls_index.index(idx)]
for idx in valid_filter_cls
}
valid_pairs = []
for idx in func.filter_cls:
if idx in func.cls_index:
en_name = func.cls_en[func.cls_index.index(idx)]
zh_name = func.cls[func.cls_index.index(idx)]
valid_pairs.append((zh_name, en_name))
# 2. 构建字典,处理多个英文名对应同一中文名的情况
func.cls_zn_to_eh_dict = {}
for en_name, zh_name in valid_pairs:
if zh_name not in func.cls_zn_to_eh_dict.values():
# 如果中文名还未在字典的值中,直接添加
func.cls_zn_to_eh_dict[en_name] = zh_name
return func
return None # 如果未找到,返回 None
# result = read_local_func_config()
# # print()