ai_project_v1/middleware/query_model.py

676 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import psycopg2
from psycopg2.extras import RealDictCursor
import json
from typing import Dict, List, Union, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import re
@dataclass
class ModelClassInfo:
    """A single detection class of a model.

    Fields, in constructor order:
        index: the class label index.
        name: class name; written to / read from ``ai_model.cls``.
        english_name: matching ``ai_model.cls_en`` entry, or None.
        description: optional free-text description, or None.
    """

    index: int
    name: str
    english_name: Union[str, None] = None
    description: Union[str, None] = None
@dataclass
class ClassConfig:
    """Class selection of a model.

    Fields, in constructor order:
        filter_indices: class indices to keep (stored in ``ai_model.filter_cls``).
        class_indices: all class indices of the model (``ai_model.cls_index``).
        classes: per-class details; presumably in the same order as
            ``class_indices`` (the sample data is) -- confirm for other callers.
    """

    filter_indices: List[int]
    class_indices: List[int]
    classes: List[ModelClassInfo]
@dataclass
class ModelInfo:
    """Identity and file location of a model as written to ``ai_model``."""

    id: int  # stored as ai_model.model_id
    yolo_version: str
    model_path: str  # stored as ai_model.path
    func_description: Union[str, None] = None
@dataclass
class ModelMetadata:
    """Bookkeeping attached to a model configuration.

    Fields, in constructor order:
        total_classes: number of classes the model defines.
        created_at / updated_at: optional timestamps, None when unknown.
    """

    total_classes: int
    created_at: Union[datetime, None] = None
    updated_at: Union[datetime, None] = None
@dataclass
class ModelData:
    """Flattened configuration of one model function, ready for inference code.

    Built by ``ModelConfigDAO._db_row_to_config`` from a joined
    ``ai_model`` / ``ai_model_list`` row. The ``*_dict`` lookup fields are keyed
    by class label index unless noted otherwise.
    """

    id: int
    yolo_version: str
    model_path: str
    engine_path: str
    so_path: str
    repeat_dis: float
    repeat_time: float
    func_description: Optional[str]
    filter_indices: List[int]  # requested filter indices (may include unknown ones)
    class_indices: List[int]  # ai_model.cls_index
    conf: float  # ai_model_list.confidence
    classes: List[ModelClassInfo]
    total_classes: int
    cls_names: List[str]  # raw ai_model.cls array
    filtered_cls_en_dict: Dict[int, str]  # filtered index -> cls_en entry
    cls_en_dict: Dict[int, str]  # every index -> cls_en entry
    filtered_cls_dict: Dict[int, str]  # filtered index -> cls entry
    cls_dict: Dict[int, str]  # every index -> cls entry
    cls_str_dict: Dict[str, str]  # cls entry -> cls_en entry (keyed by name, not index)
    cls_zn_to_eh_dict: Dict[str, str]  # filtered cls entry -> cls_en entry, first per cls_en value
    allowed_classes: List[str]  # cls entries of the valid filtered indices
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
@dataclass
class MqttData:
    """Connection settings for one MQTT broker row of ``ai_mqtt_config``.

    ``org_code`` and ``mqtt_type`` default to None so an instance can be built
    from a row that does not carry those columns (e.g. the plain by-id lookup)
    instead of raising ``TypeError`` for missing required arguments.
    """

    mqtt_id: int
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    mqtt_username: str
    mqtt_pass: str
    mqtt_description: str
    org_code: Optional[str] = None
    mqtt_type: Optional[str] = None
@dataclass
class Device:
    """One row of the ``device`` table as returned by ``ModelConfigDAO.get_device``.

    Fields mirror the selected columns: dname, sn, orgcode, lat, lng, height.
    """

    dname: str
    sn: str
    orgcode: int
    # Numeric position fields; the DAO coerces them with float().
    lat: float
    lng: float
    height: float
@dataclass
class ModelConfiguration:
    """Aggregate accepted by ``ModelConfigDAO.insert_config`` / ``update_config``.

    Groups the model identity, its class selection and its metadata.
    """

    model_info: ModelInfo
    class_config: ClassConfig
    metadata: ModelMetadata
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that renders ``datetime`` values as ISO-8601 strings.

    Every other unsupported type is delegated to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
class ModelConfigDAO:
    """PostgreSQL data-access object for model, device and MQTT configuration.

    Every public method runs through ``_run_query``, which opens a short-lived
    connection and always closes it. Read methods print database errors and
    return None; write methods print them and return False.
    """

    # "<digits>:<label>" pairs inside ai_model_list.py_func; a label runs up to
    # the next ASCII ';'. Compiled once instead of on every row conversion.
    _PY_FUNC_PAIR_RE = re.compile(r'(\d+):([^;]+)')

    def __init__(self, db_params: Dict[str, str]):
        """Store connection parameters and make sure ``ai_model`` exists.

        Args:
            db_params: keyword arguments for ``psycopg2.connect``, e.g.
                ``dbname``, ``user``, ``password``, ``host``, ``port``.

        Raises:
            psycopg2.Error: if the schema bootstrap fails.
        """
        self.db_params = db_params
        self._create_table_if_not_exists()

    def _run_query(self, query: str, params=None, *, fetch_one: bool = False,
                   as_dict: bool = False):
        """Execute one statement in its own transaction, then close the connection.

        In psycopg2, ``with conn`` only commits (or rolls back on error); it does
        not close the connection, so ``close()`` is called explicitly.

        Args:
            query: SQL text using psycopg2 ``%s`` / ``%(name)s`` placeholders.
            params: parameter sequence or mapping, or None.
            fetch_one: return ``cursor.fetchone()`` instead of None.
            as_dict: use ``RealDictCursor`` so the fetched row is dict-like.

        Returns:
            The fetched row (None if no row) when ``fetch_one``; otherwise None.

        Raises:
            psycopg2.Error: connection and execution errors are propagated.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:  # commit on success, roll back on exception
                factory = RealDictCursor if as_dict else None
                with conn.cursor(cursor_factory=factory) as cur:
                    cur.execute(query, params)
                    return cur.fetchone() if fetch_one else None
        finally:
            conn.close()

    def _create_table_if_not_exists(self):
        """Create ``ai_model`` and its ``updated_at`` trigger when missing.

        NOTE(review): ``get_config`` also reads ``engine_path``, ``so_path`` and
        ``scope`` from ``ai_model``; those columns are not created here, so a
        table bootstrapped only by this method cannot serve ``get_config``.

        Raises:
            psycopg2.Error: printed and re-raised so construction fails loudly.
        """
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS ai_model (
            model_id INTEGER PRIMARY KEY,
            filter_cls INTEGER[],
            func_description TEXT,
            yolo_version TEXT ,
            path TEXT ,
            cls_index INTEGER[] ,
            cls TEXT[] ,
            cls_en TEXT[] ,
            cls_description TEXT,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        );
        CREATE OR REPLACE FUNCTION update_updated_at_column()
        RETURNS TRIGGER AS $$
        BEGIN
            NEW.updated_at = CURRENT_TIMESTAMP;
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
        DROP TRIGGER IF EXISTS update_model_config_modtime ON ai_model;
        CREATE TRIGGER update_model_config_modtime
        BEFORE UPDATE ON ai_model
        FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
        """
        try:
            self._run_query(create_table_sql)
        except psycopg2.Error as e:
            print(f"Error creating table: {e}")
            raise

    def insert_config(self, config: ModelConfiguration) -> bool:
        """Insert a new ``ai_model`` row built from ``config``.

        Args:
            config: the model configuration to insert.

        Returns:
            True on success, False after printing a database error.

        Raises:
            ValueError: if ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")
        data = self._config_to_db_format(config)
        query = """
        INSERT INTO ai_model (
            model_id, filter_cls, func_description,
            yolo_version, path, cls_index, cls, cls_en, cls_description
        ) VALUES (
            %(model_id)s, %(filter_cls)s, %(func_description)s,
            %(yolo_version)s, %(path)s, %(cls_index)s, %(cls)s, %(cls_en)s, %(cls_description)s
        )
        """
        try:
            self._run_query(query, data)
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False
        return True

    def update_config(self, config: ModelConfiguration) -> bool:
        """Overwrite the ``ai_model`` row whose ``model_id`` matches ``config``.

        Args:
            config: the model configuration to write.

        Returns:
            True when the UPDATE ran without a database error -- also when no
            row matched, since the affected row count is not checked. False
            after printing a database error.

        Raises:
            ValueError: if ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")
        data = self._config_to_db_format(config)
        query = """
        UPDATE ai_model SET
            filter_cls = %(filter_cls)s,
            func_description = %(func_description)s,
            yolo_version = %(yolo_version)s,
            path = %(path)s,
            cls_index = %(cls_index)s,
            cls = %(cls)s,
            cls_en = %(cls_en)s,
            cls_description = %(cls_description)s
        WHERE model_id = %(model_id)s
        """
        try:
            self._run_query(query, data)
        except psycopg2.Error as e:
            print(f"Database update error: {e}")
            return False
        return True

    def get_config(self, model_id: int,
                   filter_indices: Optional[List[int]] = None) -> Optional[ModelData]:
        """Load one model function joined with its model and flatten it.

        Args:
            model_id: matched against ``ai_model_list.func_id`` (the model
                *function* id), not ``ai_model.model_id``.
            filter_indices: class indices to keep. When falsy, they are parsed
                from ``ai_model_list.py_func``.

        Returns:
            ``ModelData`` for the first matching row. None if no row matched or
            a database error occurred (the error is printed).
        """
        query = """
        SELECT
            aml.func_id,
            aml.model_func_id,
            aml.model_id,
            aml.confidence,
            aml.py_func,
            aml.repeat_dis,
            aml.repeat_time,
            am.scope,
            am.yolo_version,
            am.PATH,
            am.engine_path,
            am.so_path,
            am.cls_index,
            am.cls,
            am.cls_en,
            am.cls_description,
            am.created_at,
            am.updated_at
        FROM
            ai_model am,
            ai_model_list aml
        WHERE
            aml.func_id = %s
            AND aml.model_id = am.model_id
        """
        try:
            row = self._run_query(query, (model_id,), fetch_one=True, as_dict=True)
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_config(row, filter_indices) if row else None

    def insert_request_log(self, task_id, sn, org_code, requset_json, request) -> bool:
        """Append one row to ``ai_request_log``; ``create_time`` is set to ``now()``.

        Args:
            task_id, sn, org_code, request: stored as-is in the same-named columns.
            requset_json: stored in the ``requset_json`` column (spelled as in
                the schema).

        Returns:
            True on success, False after printing a database error.
        """
        query = """
        insert into ai_request_log(task_id,sn,org_code,requset_json,request,create_time)values(%s,%s,%s,%s,%s,now())
        """
        try:
            self._run_query(query, (task_id, sn, org_code, requset_json, request,))
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False
        return True

    def get_device(self, sn: str, orgcode: str) -> Optional[Device]:
        """Look up a device by serial number and organisation code.

        Args:
            sn: device (dock) serial number.
            orgcode: organisation code.

        Returns:
            The first matching ``Device``, or None if none matched or a database
            error occurred (the error is printed).
        """
        query = """
        select dname,sn,orgcode,lat,lng,height from device where sn=%s and orgcode=%s
        """
        try:
            row = self._run_query(query, (sn, orgcode,), fetch_one=True, as_dict=True)
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_device_config(row) if row else None

    def get_mqtt_config(self, mqtt_id: int) -> Optional[MqttData]:
        """Look up one ``ai_mqtt_config`` row by its ``id``.

        ``org_code`` and ``mqtt_type`` are selected as well; the table has them
        (the org-code lookups below read them), and ``MqttData`` needs them.

        Args:
            mqtt_id: primary key of the MQTT config row.

        Returns:
            ``MqttData`` for the row, or None if not found or a database error
            occurred (the error is printed).
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,mqtt_topic,mqtt_username,mqtt_pass,description from ai_mqtt_config where id=%s
        """
        try:
            row = self._run_query(query, (mqtt_id,), fetch_one=True, as_dict=True)
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_mqtt_config(row) if row else None

    def get_pic_mqtt_config_by_orgcode(self, org_code: str, mqtt_type: str) -> Optional[MqttData]:
        """Look up an MQTT config by organisation code and MQTT type.

        Args:
            org_code: organisation code.
            mqtt_type: value of the ``mqtt_type`` column to match.

        Returns:
            The first matching ``MqttData``, or None if none matched or a
            database error occurred (the error is printed).
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,
        mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where org_code=%s and mqtt_type=%s
        """
        try:
            row = self._run_query(query, (org_code, mqtt_type,), fetch_one=True, as_dict=True)
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_mqtt_config_by_orgcode(row) if row else None

    def get_mqtt_config_by_orgcode(self, org_code: str, sn: str, mqtt_type: str) -> Optional[MqttData]:
        """Look up an MQTT config by organisation code, device sn and MQTT type.

        Args:
            org_code: organisation code.
            sn: device serial number.
            mqtt_type: value of the ``mqtt_type`` column to match.

        Returns:
            The first matching ``MqttData``, or None if none matched or a
            database error occurred (the error is printed).
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,
        mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where org_code=%s and sn=%s and mqtt_type=%s
        """
        try:
            row = self._run_query(query, (org_code, sn, mqtt_type,), fetch_one=True, as_dict=True)
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_mqtt_config_by_orgcode(row) if row else None

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict:
        """Flatten ``config`` into the named parameters of the ai_model SQL.

        NOTE(review): ``cls_description`` drops classes without a description
        before joining, so positions in the stored string stop lining up with
        class positions when some are missing, while ``_db_row_to_config``
        reads them back positionally.
        """
        classes = config.class_config.classes
        return {
            "model_id": config.model_info.id,
            "filter_cls": config.class_config.filter_indices,
            "func_description": config.model_info.func_description,
            "yolo_version": config.model_info.yolo_version,
            "path": config.model_info.model_path,
            "cls_index": config.class_config.class_indices,
            "cls": [cls_info.name for cls_info in classes],
            "cls_en": [cls_info.english_name for cls_info in classes],
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in classes])
            ),
        }

    def _db_row_to_device_config(self, row: Dict) -> Device:
        """Build a ``Device`` from a ``device`` row, coercing position fields to float."""
        return Device(
            dname=row["dname"],
            sn=row["sn"],
            orgcode=row["orgcode"],
            lat=float(row["lat"]),
            lng=float(row["lng"]),
            height=float(row["height"]),
        )

    def _db_row_to_mqtt_config(self, row: Dict) -> MqttData:
        """Build ``MqttData`` from an ``ai_mqtt_config`` row.

        ``org_code`` / ``mqtt_type`` are read with ``.get`` so rows lacking those
        columns still convert (they become None) instead of raising.
        """
        return MqttData(
            mqtt_id=row["id"],
            mqtt_ip=row["mqtt_ip"],
            mqtt_port=int(row["mqtt_port"]),
            mqtt_topic=row["mqtt_topic"],
            mqtt_username=row["mqtt_username"],
            mqtt_pass=row["mqtt_pass"],
            mqtt_description=row["description"],
            org_code=row.get("org_code"),
            mqtt_type=row.get("mqtt_type"),
        )

    def _db_row_to_mqtt_config_by_orgcode(self, row: Dict) -> MqttData:
        """Build ``MqttData`` from an ``ai_mqtt_config`` row that includes org_code/mqtt_type."""
        return MqttData(
            mqtt_id=row["id"],
            mqtt_ip=row["mqtt_ip"],
            mqtt_port=int(row["mqtt_port"]),
            mqtt_topic=row["mqtt_topic"],
            mqtt_username=row["mqtt_username"],
            mqtt_pass=row["mqtt_pass"],
            mqtt_description=row["description"],
            org_code=row["org_code"],
            mqtt_type=row["mqtt_type"],
        )

    @staticmethod
    def _split_descriptions(desc: Optional[str]) -> List[str]:
        """Split a comma-separated string into trimmed, non-empty parts ([] for None/'')."""
        if not desc:
            return []
        return [part.strip() for part in desc.split(',') if part.strip()]

    @classmethod
    def _parse_py_func_indices(cls, text: Optional[str]) -> List[int]:
        """Return the class indices listed in a ``py_func`` spec such as ``"x=0:a;1:b"``.

        A NULL/empty spec yields an empty list instead of raising TypeError.
        """
        return [int(num) for num, _label in cls._PY_FUNC_PAIR_RE.findall(text or '')]

    def _build_class_infos(self, row: Dict, cls_index: List[int], cls_names: List[str],
                           cls_en_names: List[str]) -> List[ModelClassInfo]:
        """Build one ``ModelClassInfo`` per position that has a class name.

        Position ``pos`` of ``cls_index`` / ``cls`` / ``cls_en`` describes the same
        class, so names and descriptions are looked up by position, not label.
        """
        descriptions = self._split_descriptions(row.get('cls_description'))
        # Fallback source; get_config's SELECT does not include func_description,
        # so this is empty for rows coming from get_config.
        func_descriptions = self._split_descriptions(row.get('func_description'))
        classes = []
        for pos, idx in enumerate(cls_index):
            if pos >= len(cls_names):
                break  # positions only grow, so no later entry has a name either
            info = ModelClassInfo(
                index=idx,
                name=cls_names[pos],
                english_name=cls_en_names[pos] if pos < len(cls_en_names) else None,
            )
            if pos < len(descriptions):
                info.description = descriptions[pos]
            elif pos < len(func_descriptions):
                info.description = func_descriptions[pos]
            classes.append(info)
        return classes

    def _db_row_to_config(self, row: Dict, filter_cls: Optional[List[int]] = None) -> ModelData:
        """Convert a joined ``ai_model`` / ``ai_model_list`` row into ``ModelData``.

        ``cls_index``, ``cls`` and ``cls_en`` are parallel arrays: position ``i``
        describes the class whose label index is ``cls_index[i]``.

        Args:
            row: dict-like row produced by the ``get_config`` query.
            filter_cls: explicit class indices to keep; when falsy they are
                parsed from ``row['py_func']``.

        Returns:
            A fully populated ``ModelData``.
        """
        cls_index = row.get('cls_index') or []
        cls_names = row.get('cls') or []
        cls_en_names = row.get('cls_en') or []

        # Label index -> first array position (same result as list.index, but
        # computed once instead of once per lookup).
        position_of: Dict[int, int] = {}
        for pos, idx in enumerate(cls_index):
            position_of.setdefault(idx, pos)

        classes = self._build_class_infos(row, cls_index, cls_names, cls_en_names)

        filter_indices = filter_cls if filter_cls else self._parse_py_func_indices(row.get('py_func'))
        # Only indices the model actually defines can be mapped to names.
        valid_filter_cls = [idx for idx in filter_indices if idx in position_of]

        filtered_cls_en_dict = {idx: cls_en_names[position_of[idx]] for idx in valid_filter_cls}
        cls_en_dict = {idx: cls_en_names[pos] for idx, pos in position_of.items()}
        filtered_cls_dict = {idx: cls_names[position_of[idx]] for idx in valid_filter_cls}
        cls_dict = {idx: cls_names[pos] for idx, pos in position_of.items()}
        # Keyed by class name (cls entry), valued by the matching cls_en entry.
        cls_str_dict = dict(zip(cls_names, cls_en_names))

        # cls entry -> cls_en entry for the filtered classes; an entry is only
        # added while its cls_en value is not already present among the values.
        cls_zn_to_eh_dict = {}
        for idx in valid_filter_cls:
            pos = position_of[idx]
            name, en_name = cls_names[pos], cls_en_names[pos]
            if en_name not in cls_zn_to_eh_dict.values():
                cls_zn_to_eh_dict[name] = en_name

        # Map through position_of so unknown filter indices cannot raise
        # IndexError and non-contiguous cls_index values map to the right name.
        allowed_classes = [cls_names[position_of[idx]] for idx in valid_filter_cls]

        return ModelData(
            id=row['model_id'],
            yolo_version=row['yolo_version'],
            model_path=row['path'],
            engine_path=row['engine_path'],
            so_path=row['so_path'],
            func_description=row.get('scope'),
            filter_indices=filter_indices,
            repeat_dis=row['repeat_dis'],
            repeat_time=row.get('repeat_time'),
            class_indices=cls_index,
            conf=row['confidence'],
            classes=classes,
            total_classes=len(classes),
            created_at=row.get('created_at'),
            updated_at=row.get('updated_at'),
            cls_names=cls_names,
            filtered_cls_en_dict=filtered_cls_en_dict,
            cls_en_dict=cls_en_dict,
            filtered_cls_dict=filtered_cls_dict,
            cls_dict=cls_dict,
            cls_str_dict=cls_str_dict,
            cls_zn_to_eh_dict=cls_zn_to_eh_dict,
            allowed_classes=allowed_classes,
        )
# Example configuration usable with ModelConfigDAO.insert_config / update_config.
# NOTE(review): here `name` holds the English label and `english_name` the
# Chinese one, while _config_to_db_format writes `name` into ai_model.cls and
# `english_name` into ai_model.cls_en -- confirm which language each column is
# meant to hold before inserting. Index 9 ("LoaderWheel") also shares its
# english_name with the crane classes 7-8 (CraneTower, CraneFixed) and repeats
# the name of index 14; possibly a copy-paste slip, verify.
sample_data = ModelConfiguration(
    model_info=ModelInfo(
        id=100002,
        yolo_version="11",
        model_path="pt/GDCL.pt",
        func_description="挖掘机,自卸卡车,压路机,移动式起重机,固定式起重机,轮式装载机,混凝土搅拌车,反铲式装载机,推土机,平地机",
    ),
    class_config=ClassConfig(
        filter_indices=list(range(17)),
        class_indices=list(range(17)),
        # (name, english_name, description); the list position is the class index.
        classes=[
            ModelClassInfo(index=position, name=label, english_name=alias, description=note)
            for position, (label, alias, note) in enumerate((
                ("Excavator", "挖掘机", "挖掘机"),
                ("Flatbed", "卡车", None),
                ("Truck", "卡车", None),
                ("DumpTruck", "卡车", None),
                ("Roller", "压路机", None),
                ("CraneMobile", "起重机", None),
                ("CraneTruck", "起重机", None),
                ("CraneTower", "塔吊", None),
                ("CraneFixed", "塔吊", None),
                ("LoaderWheel", "塔吊", None),
                ("CementTanker", "水泥罐车", None),
                ("ConcreteMixer", "水泥罐车", None),
                ("MixerTruck", "水泥罐车", None),
                ("BackhoeLoader", "挖土机", None),
                ("LoaderWheel", "推土机", None),
                ("Grader", "地坪机", None),
                ("PileDriver", "打桩机", None),
            ))
        ],
    ),
    metadata=ModelMetadata(total_classes=17),
)
# Example usage: query one model-function configuration and print it.
if __name__ == "__main__":
    import os

    # Connection settings; each value can be overridden via the environment.
    # NOTE(review): the fallbacks are real-looking credentials committed to
    # source -- prefer supplying them via the environment and dropping them.
    DB_CONFIG = {
        "dbname": os.environ.get("AI_DB_NAME", "smart_dev_123"),
        "user": os.environ.get("AI_DB_USER", "postgres"),
        "password": os.environ.get("AI_DB_PASSWORD", "root"),
        "host": os.environ.get("AI_DB_HOST", "8.137.54.85"),
        "port": os.environ.get("AI_DB_PORT", "5060"),
    }
    dao = ModelConfigDAO(DB_CONFIG)
    # To insert the example configuration: dao.insert_config(sample_data)

    print("\n查询配置...")
    model_id = 100091
    # Pass None explicitly: no filter override, so indices come from py_func.
    config = dao.get_config(model_id, None)
    if config:
        # asdict() keeps datetime objects; DateTimeEncoder serialises them.
        print(json.dumps(asdict(config), indent=2, ensure_ascii=False, cls=DateTimeEncoder))
        print("\n模型路径:", config.model_path)
        print("过滤类别:", config.filter_indices)
        if config.classes:  # a model row may define no classes
            print("第一个类别:", asdict(config.classes[0]))
        print("创建时间:", config.created_at)
        print("更新时间:", config.updated_at)
    else:
        print(f"未找到ID为 {model_id} 的模型配置")