import json
import re
from contextlib import closing, contextmanager
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    TypeVar,
    Union,
)

import psycopg2
from psycopg2.extras import RealDictCursor

_T = TypeVar("_T")


@dataclass
class ModelClassInfo:
    """A single class entry of a model: its index and names."""

    index: int
    name: str
    english_name: Optional[str] = None
    description: Optional[str] = None


@dataclass
class ClassConfig:
    """Class-related configuration of a model."""

    filter_indices: List[int]
    class_indices: List[int]
    classes: List[ModelClassInfo]


@dataclass
class ModelInfo:
    """Basic model identity: id, YOLO version and weights path."""

    id: int
    yolo_version: str
    model_path: str
    func_description: Optional[str] = None


@dataclass
class ModelMetadata:
    """Bookkeeping metadata for a model configuration."""

    total_classes: int
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


@dataclass
class ModelData:
    """Flattened model record including derived class lookup tables."""

    id: int
    yolo_version: str
    model_path: str
    repeat_dis: float
    func_description: Optional[str]
    filter_indices: List[int]
    class_indices: List[int]
    conf: float
    classes: List[ModelClassInfo]
    total_classes: int
    # Key/value types of the lookup tables below are not fixed by this module.
    cls_names: Dict
    filtered_cls_en_dict: Dict
    cls_en_dict: Dict
    filtered_cls_dict: Dict
    cls_dict: Dict
    cls_str_dict: Dict
    cls_zn_to_eh_dict: Dict
    allowed_classes: List
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


@dataclass
class MqttData:
    """MQTT broker connection settings."""

    mqtt_id: int
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    mqtt_username: str
    mqtt_pass: str
    mqtt_description: str
    org_code: str
    mqtt_type: str


@dataclass
class Device:
    """A device with its serial number, organisation code and position."""

    dname: str
    sn: str
    orgcode: int
    lat: float
    lng: float
    height: float


@dataclass
class ModelConfiguration:
    """Aggregate of model info, class config and metadata."""

    model_info: ModelInfo
    class_config: ClassConfig
    metadata: ModelMetadata


@dataclass
class Dataset:
    """A training resource; only ``id`` and ``resource_original_path`` come from the DB."""

    id: int
    resource_original_path: Optional[str]
    # The remaining fields are initialised to None / "" by the DAO and
    # presumably filled in later by the caller.
    pic_name: Optional[str]
    local_path: Optional[str]
    label_name: Optional[str]
    label_content: str
    label_txt_path: Optional[str]


@dataclass
class Labels:
    """One annotation record joined to its resource assignment."""

    id: int
    resource_original_path: Optional[str]
    # The label columns come from a LEFT JOIN, so they are None for
    # resources that have no annotation record.
    resource_id: Optional[int]
    label_set_id: Optional[int]
    label_ids: Optional[int]  # filled from the single `label_id` column
    # NOTE(review): annotated as str; if the column is json/jsonb the driver
    # may return a decoded Python object instead — confirm the column type.
    annotation_data: Optional[str]


@dataclass
class Label_Yaml:
    """A label definition (id, name, English name) plus a caller-set ordering."""

    id: int
    id_order: int
    name: str
    e_name: str


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``datetime`` objects as ISO-8601 strings."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


class ModelConfigDAO:
    """Data-access object for model configuration and training-data tables.

    Every public method opens its own connection, runs its statement inside a
    single transaction (committed on success, rolled back on error) and closes
    the connection before returning. Database errors (``psycopg2.Error``) are
    printed and reported through the return value (``False`` / ``None``)
    instead of being raised.
    """

    def __init__(self, db_params: Dict[str, Any]):
        """Store connection parameters; no connection is opened here.

        Args:
            db_params: Keyword arguments passed to ``psycopg2.connect``,
                e.g. ``dbname``, ``user``, ``password``, ``host``, ``port``.
        """
        self.db_params = db_params

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @contextmanager
    def _cursor(self, cursor_factory=None) -> Iterator[Any]:
        """Yield a cursor on a fresh connection inside one transaction.

        psycopg2's ``with connection:`` only ends the transaction; it does not
        close the connection. ``closing`` releases it explicitly so each call
        does not leave a server connection open.
        """
        with closing(psycopg2.connect(**self.db_params)) as conn:
            with conn:  # commit on normal exit, rollback on exception
                with conn.cursor(cursor_factory=cursor_factory) as cur:
                    yield cur

    def _execute_write(self, query: str, params: Any, action: str) -> bool:
        """Execute a write statement; return True on success, False on DB error.

        ``action`` is interpolated into the error message
        (``"Database {action} error: ..."``).
        """
        try:
            with self._cursor() as cur:
                cur.execute(query, params)
        except psycopg2.Error as e:
            print(f"Database {action} error: {e}")
            return False
        return True

    def _fetch_all(
        self,
        query: str,
        params: Any,
        row_mapper: Callable[[Dict[str, Any]], _T],
    ) -> Optional[List[_T]]:
        """Run a SELECT and map each dict row; return None on DB error."""
        try:
            with self._cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(query, params)
                rows = cur.fetchall()
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return [row_mapper(row) for row in rows]

    # ------------------------------------------------------------------
    # Model configuration
    # ------------------------------------------------------------------

    def insert_config(self, config: ModelConfiguration) -> bool:
        """Insert a new row into ``ai_model`` for the given configuration.

        Args:
            config: The model configuration to insert.

        Returns:
            True if the insert succeeded, False on a database error.

        Raises:
            ValueError: If ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        data = self._config_to_db_format(config)
        query = """
            INSERT INTO ai_model (
                model_id, filter_cls, func_description, yolo_version,
                path, cls_index, cls, cls_en, cls_description
            ) VALUES (
                %(model_id)s, %(filter_cls)s, %(func_description)s,
                %(yolo_version)s, %(path)s, %(cls_index)s,
                %(cls)s, %(cls_en)s, %(cls_description)s
            )
        """
        return self._execute_write(query, data, "insert")

    def update_config(self, config: ModelConfiguration) -> bool:
        """Update the ``ai_model`` row whose ``model_id`` matches the config.

        Args:
            config: The model configuration holding the new values.

        Returns:
            True if the statement succeeded (including when no row matched),
            False on a database error.

        Raises:
            ValueError: If ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        data = self._config_to_db_format(config)
        query = """
            UPDATE ai_model SET
                filter_cls = %(filter_cls)s,
                func_description = %(func_description)s,
                yolo_version = %(yolo_version)s,
                path = %(path)s,
                cls_index = %(cls_index)s,
                cls = %(cls)s,
                cls_en = %(cls_en)s,
                cls_description = %(cls_description)s
            WHERE model_id = %(model_id)s
        """
        return self._execute_write(query, data, "update")

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict[str, Any]:
        """Flatten a configuration into the named parameters of the ai_model queries.

        ``cls_description`` joins the non-empty class descriptions with ", ".
        """
        classes = config.class_config.classes
        return {
            "model_id": config.model_info.id,
            "filter_cls": config.class_config.filter_indices,
            "func_description": config.model_info.func_description,
            "yolo_version": config.model_info.yolo_version,
            "path": config.model_info.model_path,
            "cls_index": config.class_config.class_indices,
            "cls": [cls_info.name for cls_info in classes],
            "cls_en": [cls_info.english_name for cls_info in classes],
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in classes])
            ),
        }

    # ------------------------------------------------------------------
    # Training data
    # ------------------------------------------------------------------

    def get_datasets(self, bz_training_task_id: int) -> Optional[List[Dataset]]:
        """Return the resources assigned to the projects of a training task's datasets.

        Follows training task -> training datasets -> dataset/project
        relations -> project resource assignments.

        Args:
            bz_training_task_id: ``bz_training_task.id`` to look up.

        Returns:
            One ``Dataset`` per resource-assignment row (an empty list if the
            task has none), or None on a database error. Because the joins
            are LEFT JOINs, rows whose chain has no resource carry None in
            ``id``/``resource_original_path``.
        """
        query = """
            SELECT bpra.id, bpra.resource_original_path
            FROM bz_training_dataset a
            LEFT JOIN bz_training_task b ON b.id = a.trainingtaskid
            LEFT JOIN bz_datasets c ON c.id = a.datasetid
            LEFT JOIN bz_dataset_project_relations d ON d.data_set_id = c.id
            LEFT JOIN bz_project_resource_assignments bpra
                ON bpra.project_id = d.project_id
            WHERE b.id = %s
        """
        return self._fetch_all(
            query, (bz_training_task_id,), self._db_row_to_dataset
        )

    def get_labels(self, bz_training_task_id: int) -> Optional[List[Labels]]:
        """Return the annotation records of a training task's resources.

        Same join chain as ``get_datasets`` plus ``bz_annotation_record``
        joined on ``task_assignment_id``; there is one row per
        (resource, annotation record) pair.

        Args:
            bz_training_task_id: ``bz_training_task.id`` to look up.

        Returns:
            A list of ``Labels`` (label columns are None for resources without
            annotations), or None on a database error.
        """
        query = """
            SELECT bpra.id, bpra.resource_original_path,
                   bar.resource_id, bar.label_set_id, bar.label_id,
                   bar.annotation_data
            FROM bz_training_dataset a
            LEFT JOIN bz_training_task b ON b.id = a.trainingtaskid
            LEFT JOIN bz_datasets c ON c.id = a.datasetid
            LEFT JOIN bz_dataset_project_relations d ON d.data_set_id = c.id
            LEFT JOIN bz_project_resource_assignments bpra
                ON bpra.project_id = d.project_id
            LEFT JOIN bz_annotation_record bar
                ON bar.task_assignment_id = bpra.id
            WHERE b.id = %s
        """
        return self._fetch_all(
            query, (bz_training_task_id,), self._db_row_to_labels
        )

    def get_label_yaml(self, bz_training_task_id: int) -> Optional[List[Label_Yaml]]:
        """Return the label definitions used by a training task's annotations.

        Selects ``id``, ``name`` and ``e_name`` from ``bz_labels`` for every
        distinct ``label_id`` found in the task's annotation records.

        Args:
            bz_training_task_id: ``bz_training_task.id`` to look up.

        Returns:
            A list of ``Label_Yaml`` with ``id_order`` set to -1, or None on a
            database error.
        """
        query = """
            SELECT id, name, e_name
            FROM bz_labels
            WHERE id IN (
                SELECT DISTINCT(bar.label_id) AS id
                FROM bz_training_dataset a
                LEFT JOIN bz_training_task b ON b.id = a.trainingtaskid
                LEFT JOIN bz_datasets c ON c.id = a.datasetid
                LEFT JOIN bz_dataset_project_relations d
                    ON d.data_set_id = c.id
                LEFT JOIN bz_project_resource_assignments bpra
                    ON bpra.project_id = d.project_id
                LEFT JOIN bz_annotation_record bar
                    ON bar.task_assignment_id = bpra.id
                WHERE b.id = %s
            )
        """
        return self._fetch_all(
            query, (bz_training_task_id,), self._db_row_to_label_yaml
        )

    def insert_train_pid(self, task_id, train_pid) -> bool:
        """Record a training process id for a task in ``bz_train_record``.

        ``create_time`` is set by the database via ``now()``.

        Args:
            task_id: The training task id.
            train_pid: The process id of the training run.

        Returns:
            True if the insert succeeded, False on a database error.
        """
        insert_sql = """
            INSERT INTO bz_train_record (task_id, train_pid, create_time)
            VALUES (%s, %s, now())
        """
        return self._execute_write(insert_sql, (task_id, train_pid), "insert")

    # ------------------------------------------------------------------
    # Row mappers
    # ------------------------------------------------------------------

    def _db_row_to_dataset(self, row: Dict[str, Any]) -> Dataset:
        """Build a ``Dataset`` from a row; file/label fields start empty."""
        return Dataset(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            pic_name=None,
            local_path=None,
            label_name=None,
            label_content="",
            label_txt_path=None,
        )

    def _db_row_to_labels(self, row: Dict[str, Any]) -> Labels:
        """Build a ``Labels`` record from a ``get_labels`` row."""
        return Labels(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            resource_id=row["resource_id"],
            label_set_id=row["label_set_id"],
            label_ids=row["label_id"],
            annotation_data=row["annotation_data"],
        )

    def _db_row_to_label_yaml(self, row: Dict[str, Any]) -> Label_Yaml:
        """Build a ``Label_Yaml`` from a ``bz_labels`` row; ``id_order`` is -1 (unset)."""
        return Label_Yaml(
            id=row["id"],
            id_order=-1,
            name=row["name"],
            e_name=row["e_name"],
        )