388 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Dict, List, Union, Optional

import psycopg2
from psycopg2.extras import RealDictCursor
@dataclass
class ModelClassInfo:
    """A single class (category) that a model can output.

    Only ``index`` and ``name`` are required; the English name and the
    free-text description default to None.
    """

    index: int                          # position of the class in the model's class list
    name: str                           # class name (written to the ``cls`` column on save)
    english_name: Optional[str] = None  # English name (``cls_en`` column), if known
    description: Optional[str] = None   # optional free-text description
@dataclass
class ClassConfig:
    """Class-related part of a model configuration.

    All three fields are required and have no defaults.
    """

    filter_indices: List[int]      # stored in the ``filter_cls`` column
    class_indices: List[int]       # stored in the ``cls_index`` column
    classes: List[ModelClassInfo]  # per-class names / English names / descriptions
@dataclass
class ModelInfo:
    """Identity and location of a model.

    ``func_description`` is optional and defaults to None.
    """

    id: int                                 # stored in the ``model_id`` column
    yolo_version: str                       # stored in the ``yolo_version`` column
    model_path: str                         # stored in the ``path`` column
    func_description: Optional[str] = None  # stored in the ``func_description`` column
@dataclass
class ModelMetadata:
    """Bookkeeping information about a model configuration.

    Only ``total_classes`` is required; both timestamps default to None.
    """

    total_classes: int
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
@dataclass
class ModelData:
    """Runtime view of one model.

    Holds the model's identity, its thresholds, its class list and the
    class-name lookup tables derived from that list. Every field except
    ``created_at`` and ``updated_at`` is required.
    """

    id: int
    yolo_version: str
    model_path: str
    repeat_dis: float
    func_description: Optional[str]
    filter_indices: List[int]
    class_indices: List[int]
    conf: float  # presumably the detection confidence threshold; confirm
    classes: List[ModelClassInfo]
    total_classes: int
    # Lookup tables. They were annotated with the literals `{}` / `[]`, which
    # are values, not type hints. Their key/value types are not visible in
    # this file, so they are left unparameterised.
    cls_names: Dict
    filtered_cls_en_dict: Dict
    cls_en_dict: Dict
    filtered_cls_dict: Dict
    cls_dict: Dict
    cls_str_dict: Dict
    cls_zn_to_eh_dict: Dict  # name suggests a Chinese -> English name map; confirm
    allowed_classes: List
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
@dataclass
class MqttData:
    """Connection and routing settings for one MQTT endpoint.

    Every field is required; the positional order below is the
    constructor's argument order.
    """

    mqtt_id: int
    mqtt_ip: str           # broker address
    mqtt_port: int         # broker port
    mqtt_topic: str
    mqtt_username: str
    mqtt_pass: str         # presumably the broker password; confirm
    mqtt_description: str
    org_code: str
    mqtt_type: str
@dataclass
class Device:
    """A device record with an identifier, an organisation and a position.

    All fields are required.
    """

    dname: str     # presumably the device name; confirm
    sn: str        # presumably the serial number; confirm
    orgcode: int   # NOTE(review): typed int here but MqttData.org_code is str; confirm
    lat: float     # presumably latitude
    lng: float     # presumably longitude
    height: float
@dataclass
class ModelConfiguration:
    """Full model configuration as accepted by ``ModelConfigDAO`` writes.

    Groups the model identity, its class configuration and its metadata.
    """

    model_info: ModelInfo      # id / version / path / description
    class_config: ClassConfig  # class indices and per-class info
    metadata: ModelMetadata    # class count and timestamps
@dataclass
class Dataset:
    """One training resource (image) and its local/label file bookkeeping.

    When loaded from the database only ``id`` and ``resource_original_path``
    are filled in. The path/name fields start as None and ``label_content``
    as an empty string; they are presumably filled in later by the caller.
    """

    id: int                       # NOTE(review): comes from a LEFT JOIN, so may be None
    resource_original_path: str   # NOTE(review): same LEFT JOIN caveat as ``id``
    # The following are None until populated, hence Optional
    # (they were annotated plain ``str``).
    pic_name: Optional[str]
    local_path: Optional[str]
    label_name: Optional[str]
    label_content: str
    label_txt_path: Optional[str]
@dataclass
class Labels:
    """One (resource, annotation record) row for a training task.

    The annotation columns come from a LEFT JOIN in
    ``ModelConfigDAO.get_labels``. A resource with no annotation record
    therefore yields None for ``resource_id``, ``label_set_id``,
    ``label_ids`` and ``annotation_data``.
    """

    id: int                      # project resource assignment id
    resource_original_path: str
    resource_id: int
    label_set_id: int
    label_ids: int               # holds the single ``label_id`` column value
    annotation_data: str
@dataclass
class Label_Yaml:
    """A label entry from ``bz_labels`` (id, name, English name).

    ``id_order`` is set to -1 when loaded from the database. It presumably
    receives the label's output order later; confirm with the caller.
    """

    id: int
    id_order: int
    name: str
    e_name: str  # English name of the label
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that writes ``datetime`` values as ISO-8601 strings.

    Every other unsupported object is handed back to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, datetime):
            # Not ours to handle; let the base class raise TypeError.
            return super().default(obj)
        return obj.isoformat()
class ModelConfigDAO:
    """Data-access object for ``ai_model`` and training-data queries.

    Each query method opens a fresh PostgreSQL connection from ``db_params``,
    runs one statement and closes the connection again. Database errors are
    printed and reported through the return value rather than raised:
    ``False`` for writes and ``None`` for reads.
    """

    def __init__(self, db_params: Dict[str, str]):
        """
        Store the connection parameters used by every query.

        Args:
            db_params: keyword arguments passed to ``psycopg2.connect``,
                containing:
                - dbname: database name
                - user: user name
                - password: password
                - host: host address
                - port: port number
        """
        self.db_params = db_params

    @contextmanager
    def _connection(self):
        """Yield a connection inside a transaction and always close it.

        ``with conn`` commits on normal exit and rolls back if the body
        raises. psycopg2 does NOT close the connection when that block ends,
        so the ``finally`` closes it; otherwise every call would leak one
        connection.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def insert_config(self, config: ModelConfiguration) -> bool:
        """
        Insert a new ``ai_model`` row for the given configuration.

        Args:
            config: the model configuration to insert.

        Returns:
            True if the insert was committed, False on a database error.

        Raises:
            ValueError: if ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")
        # Flatten the nested config into the named query parameters.
        data = self._config_to_db_format(config)
        query = """
INSERT INTO ai_model (
model_id, filter_cls, func_description,
yolo_version, path, cls_index, cls, cls_en, cls_description
) VALUES (
%(model_id)s, %(filter_cls)s, %(func_description)s,
%(yolo_version)s, %(path)s, %(cls_index)s, %(cls)s, %(cls_en)s, %(cls_description)s
)
"""
        try:
            # The connection helper commits on success and always closes.
            with self._connection() as conn, conn.cursor() as cur:
                cur.execute(query, data)
            return True
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False

    def update_config(self, config: ModelConfiguration) -> bool:
        """
        Update the ``ai_model`` row whose ``model_id`` matches the config.

        Args:
            config: the model configuration to write.

        Returns:
            True if the statement was committed, False on a database error.
            True is also returned when no row matched the id.

        Raises:
            ValueError: if ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")
        data = self._config_to_db_format(config)
        query = """
UPDATE ai_model SET
filter_cls = %(filter_cls)s,
func_description = %(func_description)s,
yolo_version = %(yolo_version)s,
path = %(path)s,
cls_index = %(cls_index)s,
cls = %(cls)s,
cls_en = %(cls_en)s,
cls_description = %(cls_description)s
WHERE model_id = %(model_id)s
"""
        try:
            with self._connection() as conn, conn.cursor() as cur:
                cur.execute(query, data)
            return True
        except psycopg2.Error as e:
            print(f"Database update error: {e}")
            return False

    def get_datasets(self, bz_training_task_id: int) -> Optional[List[Dataset]]:
        """
        Fetch the resources attached to a training task's datasets.

        Follows training task -> datasets -> projects -> project resource
        assignments. Because the later joins are LEFT JOINs, a row may carry
        None values when a link is missing.

        Args:
            bz_training_task_id: id of the ``bz_training_task`` row.

        Returns:
            One ``Dataset`` per row (possibly an empty list), with only ``id``
            and ``resource_original_path`` populated. Returns None if the
            query failed.
        """
        query = """
select bpra.id,bpra.resource_original_path from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
left join bz_datasets c on c.id =a.datasetid
left join bz_dataset_project_relations d on d.data_set_id =c.id
left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
where b.id=%s
"""
        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, (bz_training_task_id,))
                    rows = cur.fetchall()
            return [self._db_row_to_dataset(row) for row in rows]
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None

    def get_labels(self, bz_training_task_id: int) -> Optional[List[Labels]]:
        """
        Fetch the annotation records of every resource in a training task.

        Uses the same join path as ``get_datasets``, plus a LEFT JOIN on
        ``bz_annotation_record``. A resource with no annotation record still
        yields one row, with None in the annotation columns.

        Args:
            bz_training_task_id: id of the ``bz_training_task`` row.

        Returns:
            One ``Labels`` per row (possibly an empty list), or None if the
            query failed.
        """
        query = """
select bpra.id,bpra.resource_original_path ,bar.resource_id , bar.label_set_id ,bar.label_id ,bar.annotation_data from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
left join bz_datasets c on c.id =a.datasetid
left join bz_dataset_project_relations d on d.data_set_id =c.id
left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
left join bz_annotation_record bar on bar.task_assignment_id =bpra.id
where b.id=%s
"""
        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, (bz_training_task_id,))
                    rows = cur.fetchall()
            return [self._db_row_to_labels(row) for row in rows]
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None

    def get_label_yaml(self, bz_training_task_id: int) -> Optional[List[Label_Yaml]]:
        """
        Fetch the distinct ``bz_labels`` entries used by a training task.

        A label is included when at least one annotation record of the
        task's resources references it.

        Args:
            bz_training_task_id: id of the ``bz_training_task`` row.

        Returns:
            One ``Label_Yaml`` per label (possibly an empty list), each with
            ``id_order`` set to -1. Returns None if the query failed.
        """
        query = """
select id,name,e_name from bz_labels where id in (select distinct(bar.label_id) AS id from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
left join bz_datasets c on c.id =a.datasetid
left join bz_dataset_project_relations d on d.data_set_id =c.id
left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
left join bz_annotation_record bar on bar.task_assignment_id =bpra.id
where b.id=%s
)
"""
        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, (bz_training_task_id,))
                    rows = cur.fetchall()
            return [self._db_row_to_label_yaml(row) for row in rows]
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict:
        """
        Flatten a ``ModelConfiguration`` into the named parameters used by
        the ``ai_model`` INSERT and UPDATE statements.

        The index and name lists are passed as Python lists, which psycopg2
        adapts to SQL arrays. ``cls_description`` is a single string that
        joins the non-empty class descriptions with ", ".
        """
        classes = config.class_config.classes
        return {
            "model_id": config.model_info.id,
            "filter_cls": config.class_config.filter_indices,
            "func_description": config.model_info.func_description,
            "yolo_version": config.model_info.yolo_version,
            "path": config.model_info.model_path,
            "cls_index": config.class_config.class_indices,
            "cls": [cls_info.name for cls_info in classes],
            "cls_en": [cls_info.english_name for cls_info in classes],
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in classes])
            ),
        }

    def insert_train_pid(self, task_id, train_pid) -> bool:
        """
        Record a training process id for a task in ``bz_train_record``.

        ``create_time`` is filled in by the database with ``now()``.

        Args:
            task_id: id of the training task.
            train_pid: id of the training process (presumably an OS pid;
                confirm against the caller).

        Returns:
            True if the insert was committed, False on a database error.
        """
        insert_sql = """
INSERT INTO bz_train_record (
task_id, train_pid,create_time
) VALUES (
%s, %s,now()
)
"""
        try:
            with self._connection() as conn, conn.cursor() as cur:
                cur.execute(insert_sql, (task_id, train_pid))
            return True
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False

    def _db_row_to_dataset(self, row: Dict) -> Dataset:
        """Build a ``Dataset`` from a ``get_datasets`` row.

        Only ``id`` and the original path come from the row; the file fields
        are left for later population.
        """
        return Dataset(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            pic_name=None,
            local_path=None,
            label_name=None,
            label_content="",
            label_txt_path=None,
        )

    def _db_row_to_labels(self, row: Dict) -> Labels:
        """Build a ``Labels`` from a ``get_labels`` row.

        The row's ``label_id`` column is stored in the ``label_ids`` field.
        """
        return Labels(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            resource_id=row["resource_id"],
            label_set_id=row["label_set_id"],
            label_ids=row["label_id"],
            annotation_data=row["annotation_data"],
        )

    def _db_row_to_label_yaml(self, row: Dict) -> Label_Yaml:
        """Build a ``Label_Yaml`` from a ``get_label_yaml`` row.

        ``id_order`` is initialised to -1.
        """
        return Label_Yaml(
            id=row["id"],
            id_order=-1,
            name=row["name"],
            e_name=row["e_name"],
        )