import json
import re
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, Iterator, List, Union, Optional

import psycopg2
from psycopg2.extras import RealDictCursor


@dataclass
class ModelClassInfo:
    """A single detection class of a model (index plus display names)."""

    index: int
    name: str
    english_name: Optional[str] = None
    description: Optional[str] = None


@dataclass
class ClassConfig:
    """Class-related part of a ModelConfiguration."""

    filter_indices: List[int]
    class_indices: List[int]
    classes: List[ModelClassInfo]


@dataclass
class ModelInfo:
    """Identity and location of a model."""

    id: int
    yolo_version: str
    model_path: str
    func_description: Optional[str] = None


@dataclass
class ModelMetadata:
    """Bookkeeping information attached to a ModelConfiguration."""

    total_classes: int
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


@dataclass
class ModelData:
    """Flattened model configuration as returned by ModelConfigDAO.get_config."""

    id: int
    yolo_version: str
    model_path: str
    engine_path: str
    so_path: str
    repeat_dis: float
    repeat_time: float
    func_description: Optional[str]
    filter_indices: List[int]
    class_indices: List[int]
    conf: float
    classes: List[ModelClassInfo]
    total_classes: int
    # Raw ``cls`` array of the ai_model row.
    cls_names: List[str]
    # class index -> ``cls_en`` entry, restricted to the validated filter indices.
    filtered_cls_en_dict: Dict[int, str]
    # class index -> ``cls_en`` entry, for every class index.
    cls_en_dict: Dict[int, str]
    # class index -> ``cls`` entry, restricted to the validated filter indices.
    filtered_cls_dict: Dict[int, str]
    # class index -> ``cls`` entry, for every class index.
    cls_dict: Dict[int, str]
    # ``cls`` entry -> ``cls_en`` entry (positional pairing of the two arrays).
    cls_str_dict: Dict[str, str]
    # ``cls`` entry -> ``cls_en`` entry for the filtered classes, de-duplicated by value.
    cls_zn_to_eh_dict: Dict[str, str]
    # ``cls`` entries of the validated filter indices.
    allowed_classes: List[str]
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


@dataclass
class MqttData:
    """One row of ai_mqtt_config.

    ``org_code`` and ``mqtt_type`` default to None because the lookup by id
    (ModelConfigDAO.get_mqtt_config) does not select those columns.
    """

    mqtt_id: int
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    mqtt_username: str
    mqtt_pass: str
    mqtt_description: str
    org_code: Optional[str] = None
    mqtt_type: Optional[str] = None


@dataclass
class Device:
    """One row of the device table."""

    dname: str
    sn: str
    orgcode: int
    lat: float
    lng: float
    height: float


@dataclass
class ModelConfiguration:
    """Structured configuration accepted by insert_config / update_config."""

    model_info: ModelInfo
    class_config: ClassConfig
    metadata: ModelMetadata


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serialises datetime objects as ISO-8601 strings."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


class ModelConfigDAO:
    """Data-access object for the ai_model tables and related lookups."""

    def __init__(self, db_params: Dict[str, str]):
        """Store the connection parameters and make sure ai_model exists.

        Args:
            db_params: keyword arguments for ``psycopg2.connect``:
                dbname, user, password, host, port.

        Raises:
            psycopg2.Error: if the table/trigger bootstrap fails.
        """
        self.db_params = db_params
        self._create_table_if_not_exists()

    @contextmanager
    def _connection(self) -> Iterator[Any]:
        """Yield a connection that is committed/rolled back AND closed.

        psycopg2's own ``with connection`` block only ends the transaction;
        it leaves the connection open, so it is closed explicitly here.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _fetch_one(self, query: str, params: tuple) -> Optional[Dict]:
        """Run a SELECT and return the first row as a dict, or None.

        Database errors are printed and reported as None.
        """
        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, params)
                    return cur.fetchone()
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None

    def _execute_write(self, query: str, params, action: str) -> bool:
        """Run a write statement; return True once it has been committed.

        ``action`` ("insert" / "update") only selects the wording of the
        printed error message.
        """
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(query, params)
            return True
        except psycopg2.Error as e:
            print(f"Database {action} error: {e}")
            return False

    def _create_table_if_not_exists(self):
        """Create the ai_model table and its updated_at trigger if missing.

        NOTE(review): get_config also reads ai_model.scope, engine_path and
        so_path and joins ai_model_list, none of which are created here;
        they are expected to exist already in the target database.

        Raises:
            psycopg2.Error: re-raised after being printed.
        """
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS ai_model (
            model_id INTEGER PRIMARY KEY,
            filter_cls INTEGER[],
            func_description TEXT,
            yolo_version TEXT ,
            path TEXT ,
            cls_index INTEGER[] ,
            cls TEXT[] ,
            cls_en TEXT[] ,
            cls_description TEXT,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        );

        CREATE OR REPLACE FUNCTION update_updated_at_column()
        RETURNS TRIGGER AS $$
        BEGIN
            NEW.updated_at = CURRENT_TIMESTAMP;
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;

        DROP TRIGGER IF EXISTS update_model_config_modtime ON ai_model;
        CREATE TRIGGER update_model_config_modtime
        BEFORE UPDATE ON ai_model
        FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
        """
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(create_table_sql)
        except psycopg2.Error as e:
            print(f"Error creating table: {e}")
            raise

    def insert_config(self, config: ModelConfiguration) -> bool:
        """Insert a new row into ai_model.

        Args:
            config: the configuration to insert.

        Returns:
            True on success, False if the database rejected the statement.

        Raises:
            ValueError: if ``config`` is not a ModelConfiguration.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        query = """
        INSERT INTO ai_model (
            model_id, filter_cls, func_description, yolo_version, path,
            cls_index, cls, cls_en, cls_description
        ) VALUES (
            %(model_id)s, %(filter_cls)s, %(func_description)s,
            %(yolo_version)s, %(path)s, %(cls_index)s, %(cls)s,
            %(cls_en)s, %(cls_description)s
        )
        """
        return self._execute_write(
            query, self._config_to_db_format(config), "insert"
        )

    def update_config(self, config: ModelConfiguration) -> bool:
        """Update the ai_model row whose model_id matches ``config``.

        Args:
            config: the configuration carrying the new column values.

        Returns:
            True if the statement ran without a database error (even when no
            row matched), False otherwise.

        Raises:
            ValueError: if ``config`` is not a ModelConfiguration.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        query = """
        UPDATE ai_model
        SET filter_cls = %(filter_cls)s,
            func_description = %(func_description)s,
            yolo_version = %(yolo_version)s,
            path = %(path)s,
            cls_index = %(cls_index)s,
            cls = %(cls)s,
            cls_en = %(cls_en)s,
            cls_description = %(cls_description)s
        WHERE model_id = %(model_id)s
        """
        return self._execute_write(
            query, self._config_to_db_format(config), "update"
        )

    def get_config(
        self, model_id: int, filter_indices: Optional[List[int]] = None
    ) -> Optional[ModelData]:
        """Load the model bound to a function id and flatten it to ModelData.

        Args:
            model_id: matched against ``ai_model_list.func_id`` (the model
                function id), not against ``ai_model.model_id`` directly.
            filter_indices: class indices to keep; when empty or omitted the
                indices are parsed from the row's ``py_func`` text instead.

        Returns:
            The parsed ModelData, or None if no row matched or the query
            failed.
        """
        query = """
        SELECT
            aml.func_id,
            aml.model_func_id,
            aml.model_id,
            aml.confidence,
            aml.py_func,
            aml.repeat_dis,
            aml.repeat_time,
            am.scope,
            am.yolo_version,
            am.PATH,
            am.engine_path,
            am.so_path,
            am.cls_index,
            am.cls,
            am.cls_en,
            am.cls_description,
            am.created_at,
            am.updated_at
        FROM ai_model am, ai_model_list aml
        WHERE aml.func_id = %s AND aml.model_id = am.model_id
        """
        row = self._fetch_one(query, (model_id,))
        return self._db_row_to_config(row, filter_indices) if row else None

    def insert_request_log(self, task_id, sn, org_code, requset_json, request) -> bool:
        """Append one row to ai_request_log, stamped with the current time.

        Args:
            task_id, sn, org_code, requset_json, request: stored verbatim in
                the columns of the same name.

        Returns:
            True on success, False if the database rejected the statement.
        """
        query = """
        insert into ai_request_log(task_id,sn,org_code,requset_json,request,create_time)values(%s,%s,%s,%s,%s,now())
        """
        params = (task_id, sn, org_code, requset_json, request,)
        return self._execute_write(query, params, "insert")

    def get_device(self, sn: str, orgcode: str) -> Optional[Device]:
        """Look up a device by serial number and organisation code.

        Args:
            sn: device serial number.
            orgcode: organisation code.

        Returns:
            The matching Device, or None if not found or the query failed.
        """
        query = """
        select dname,sn,orgcode,lat,lng,height
        from device
        where sn=%s and orgcode=%s
        """
        row = self._fetch_one(query, (sn, orgcode,))
        return self._db_row_to_device_config(row) if row else None

    def get_mqtt_config(self, mqtt_id: int) -> Optional[MqttData]:
        """Look up an MQTT configuration by its primary key.

        Args:
            mqtt_id: value of ``ai_mqtt_config.id``.

        Returns:
            The matching MqttData (``org_code`` / ``mqtt_type`` are None
            because this query does not select them), or None if not found
            or the query failed.
        """
        query = """
        select id,mqtt_ip,mqtt_port,mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where id=%s
        """
        row = self._fetch_one(query, (mqtt_id,))
        return self._db_row_to_mqtt_config(row) if row else None

    def get_pic_mqtt_config_by_orgcode(self, org_code: str, mqtt_type: str) -> Optional[MqttData]:
        """Look up an MQTT configuration by organisation code and type.

        Args:
            org_code: organisation code.
            mqtt_type: value of the ``mqtt_type`` column.

        Returns:
            The first matching MqttData, or None if not found or the query
            failed.
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,
               mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where org_code=%s and mqtt_type=%s
        """
        row = self._fetch_one(query, (org_code, mqtt_type,))
        return self._db_row_to_mqtt_config_by_orgcode(row) if row else None

    def get_mqtt_config_by_orgcode(self, org_code: str, sn: str, mqtt_type: str) -> Optional[MqttData]:
        """Look up an MQTT configuration by organisation code, device and type.

        Args:
            org_code: organisation code.
            sn: device serial number.
            mqtt_type: value of the ``mqtt_type`` column.

        Returns:
            The first matching MqttData, or None if not found or the query
            failed.
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,
               mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where org_code=%s and sn=%s and mqtt_type=%s
        """
        row = self._fetch_one(query, (org_code, sn, mqtt_type,))
        return self._db_row_to_mqtt_config_by_orgcode(row) if row else None

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict:
        """Flatten a ModelConfiguration into the ai_model column mapping."""
        classes = config.class_config.classes
        return {
            "model_id": config.model_info.id,
            "filter_cls": config.class_config.filter_indices,
            "func_description": config.model_info.func_description,
            "yolo_version": config.model_info.yolo_version,
            "path": config.model_info.model_path,
            "cls_index": config.class_config.class_indices,
            "cls": [cls_info.name for cls_info in classes],
            "cls_en": [cls_info.english_name for cls_info in classes],
            # Only non-empty descriptions are stored, joined into one string.
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in classes])
            ),
        }

    def _db_row_to_device_config(self, row: Dict) -> Device:
        """Convert a device row into a Device, coercing coordinates to float."""
        return Device(
            dname=row["dname"],
            sn=row["sn"],
            orgcode=row["orgcode"],
            lat=float(row["lat"]),
            lng=float(row["lng"]),
            height=float(row["height"]),
        )

    def _db_row_to_mqtt_config(self, row: Dict) -> MqttData:
        """Convert an ai_mqtt_config row into MqttData.

        ``org_code`` and ``mqtt_type`` are filled in only when the row
        happens to carry them; otherwise they stay None.
        """
        return MqttData(
            mqtt_id=row["id"],
            mqtt_ip=row["mqtt_ip"],
            mqtt_port=int(row["mqtt_port"]),
            mqtt_topic=row["mqtt_topic"],
            mqtt_username=row["mqtt_username"],
            mqtt_pass=row["mqtt_pass"],
            mqtt_description=row["description"],
            org_code=row.get("org_code"),
            mqtt_type=row.get("mqtt_type"),
        )

    def _db_row_to_mqtt_config_by_orgcode(self, row: Dict) -> MqttData:
        """Convert an ai_mqtt_config row that includes org_code and mqtt_type."""
        return MqttData(
            mqtt_id=row["id"],
            mqtt_ip=row["mqtt_ip"],
            mqtt_port=int(row["mqtt_port"]),
            mqtt_topic=row["mqtt_topic"],
            mqtt_username=row["mqtt_username"],
            mqtt_pass=row["mqtt_pass"],
            mqtt_description=row["description"],
            org_code=row["org_code"],
            mqtt_type=row["mqtt_type"],
        )

    @staticmethod
    def _parse_description(desc: Optional[str]) -> List[str]:
        """Split a comma-separated description string into trimmed items."""
        if not desc:
            return []
        return [part.strip() for part in desc.split(',') if part.strip()]

    def _db_row_to_config(
        self, row: Dict, filter_cls: Optional[List[int]] = None
    ) -> ModelData:
        """Convert a joined ai_model / ai_model_list row into ModelData.

        Args:
            row: dict-like row as selected by get_config.
            filter_cls: class indices to keep; when empty or None they are
                parsed from ``row['py_func']`` (text of the form
                ``"...=0:name;1:name;..."``).

        Returns:
            The populated ModelData.
        """
        descriptions = self._parse_description(row.get('cls_description'))
        func_descriptions = self._parse_description(row.get('func_description'))

        # NULL array columns are treated as empty lists.
        cls_names = row['cls'] or []
        cls_en = row['cls_en'] or []
        cls_index = row['cls_index'] or []

        # Build the per-class info list.
        classes = []
        for idx in cls_index:
            if idx >= len(cls_names):
                continue
            class_info = ModelClassInfo(
                index=idx,
                name=cls_names[idx],
                english_name=cls_en[idx] if idx < len(cls_en) else None,
            )
            # Prefer the class description, fall back to the function one.
            if descriptions and idx < len(descriptions):
                class_info.description = descriptions[idx]
            elif func_descriptions and idx < len(func_descriptions):
                class_info.description = func_descriptions[idx]
            classes.append(class_info)

        # Extract every "<number>:<label>" pair; only the numbers are used.
        # A NULL py_func yields no pairs instead of crashing re.findall.
        text = row.get('py_func') or ''
        matches = re.findall(r'(\d+):([^;]+)', text)
        filter_indices = [int(num) for num, _ in matches]
        if filter_cls:
            filter_indices = list(filter_cls)

        # First position of every class index inside cls_index (same result
        # as cls_index.index(idx), without the repeated linear scans).
        position: Dict[int, int] = {}
        for pos, idx in enumerate(cls_index):
            position.setdefault(idx, pos)

        # Keep only the requested indices that really are class indices.
        valid_filter_cls = [idx for idx in filter_indices if idx in position]

        filtered_cls_en_dict = {idx: cls_en[position[idx]] for idx in valid_filter_cls}
        cls_en_dict = {idx: cls_en[position[idx]] for idx in cls_index}
        filtered_cls_dict = {idx: cls_names[position[idx]] for idx in valid_filter_cls}
        cls_dict = {idx: cls_names[position[idx]] for idx in cls_index}
        cls_str_dict = dict(zip(cls_names, cls_en))

        # cls entry -> cls_en entry for the filtered classes; an entry is
        # skipped when its cls_en value is already present in the mapping.
        cls_zn_to_eh_dict: Dict[str, str] = {}
        for idx in valid_filter_cls:
            zh_name = cls_names[position[idx]]
            en_name = cls_en[position[idx]]
            if en_name not in cls_zn_to_eh_dict.values():
                cls_zn_to_eh_dict[zh_name] = en_name

        # Resolve through the validated indices so that an unknown filter
        # index is ignored instead of raising IndexError.
        allowed_classes = [cls_names[position[idx]] for idx in valid_filter_cls]

        return ModelData(
            id=row['model_id'],
            yolo_version=row['yolo_version'],
            model_path=row['path'],
            engine_path=row['engine_path'],
            so_path=row['so_path'],
            func_description=row.get('scope'),
            filter_indices=filter_indices,
            repeat_dis=row["repeat_dis"],
            repeat_time=row.get('repeat_time'),
            class_indices=cls_index,
            conf=row['confidence'],
            classes=classes,
            total_classes=len(classes),
            created_at=row.get('created_at'),
            updated_at=row.get('updated_at'),
            cls_names=cls_names,
            filtered_cls_en_dict=filtered_cls_en_dict,
            cls_en_dict=cls_en_dict,
            filtered_cls_dict=filtered_cls_dict,
            cls_dict=cls_dict,
            cls_str_dict=cls_str_dict,
            cls_zn_to_eh_dict=cls_zn_to_eh_dict,
            allowed_classes=allowed_classes,
        )


# Sample data
sample_data = ModelConfiguration(
    model_info=ModelInfo(
        id=100002,
        yolo_version="11",
        model_path="pt/GDCL.pt",
        func_description="挖掘机,自卸卡车,压路机,移动式起重机,固定式起重机,轮式装载机,混凝土搅拌车,反铲式装载机,推土机,平地机"
    ),
    class_config=ClassConfig(
        filter_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        class_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        classes=[
            ModelClassInfo(index=0, name="Excavator", english_name="挖掘机", description="挖掘机"),
            ModelClassInfo(index=1, name="Flatbed", english_name="卡车"),
            ModelClassInfo(index=2, name="Truck", english_name="卡车"),
            ModelClassInfo(index=3, name="DumpTruck", english_name="卡车"),
            ModelClassInfo(index=4, name="Roller", english_name="压路机"),
            ModelClassInfo(index=5, name="CraneMobile", english_name="起重机"),
            ModelClassInfo(index=6, name="CraneTruck", english_name="起重机"),
            ModelClassInfo(index=7, name="CraneTower", english_name="塔吊"),
            ModelClassInfo(index=8, name="CraneFixed", english_name="塔吊"),
            ModelClassInfo(index=9, name="LoaderWheel", english_name="塔吊"),
            ModelClassInfo(index=10, name="CementTanker", english_name="水泥罐车"),
            ModelClassInfo(index=11, name="ConcreteMixer", english_name="水泥罐车"),
            ModelClassInfo(index=12, name="MixerTruck", english_name="水泥罐车"),
            ModelClassInfo(index=13, name="BackhoeLoader", english_name="挖土机"),
            ModelClassInfo(index=14, name="LoaderWheel", english_name="推土机"),
            ModelClassInfo(index=15, name="Grader", english_name="地坪机"),
            ModelClassInfo(index=16, name="PileDriver", english_name="打桩机")
        ]
    ),
    metadata=ModelMetadata(total_classes=17)
)

# Example usage
if __name__ == "__main__":
    # Database connection settings (replace with real parameters).
    # NOTE(review): credentials and a host address are hard-coded here;
    # consider loading them from the environment or a config file.
    DB_CONFIG = {
        "dbname": "smart_dev_123",
        "user": "postgres",
        "password": "root",
        "host": "8.137.54.85",
        "port": "5060"
    }

    # Create the DAO (this also bootstraps the ai_model table).
    dao = ModelConfigDAO(DB_CONFIG)

    # To insert the sample configuration instead, call
    # dao.insert_config(sample_data) and check the boolean result.

    # Query a configuration by function id.
    print("\n查询配置...")
    model_id = 100091
    config = dao.get_config(model_id)

    if config:
        # Dump the structured result (custom encoder handles datetime).
        print(json.dumps(asdict(config), indent=2, ensure_ascii=False, cls=DateTimeEncoder))

        # Access individual fields.
        print("\n模型路径:", config.model_path)
        print("过滤类别:", config.filter_indices)
        print("第一个类别:", asdict(config.classes[0]))
        print("创建时间:", config.created_at)
        print("更新时间:", config.updated_at)
    else:
        print(f"未找到ID为 {model_id} 的模型配置")