388 lines
11 KiB
Python
388 lines
11 KiB
Python
import psycopg2
|
||
from psycopg2.extras import RealDictCursor
|
||
import json
|
||
from typing import Dict, List, Union, Optional
|
||
from dataclasses import dataclass, asdict
|
||
from datetime import datetime
|
||
import re
|
||
|
||
|
||
@dataclass
class ModelClassInfo:
    """One detection class of a model: its index, name and optional extras."""

    # Required identification of the class.
    index: int
    name: str

    # Optional descriptive fields; both default to None when not supplied.
    english_name: Optional[str] = None
    description: Optional[str] = None
|
||
|
||
|
||
@dataclass
class ClassConfig:
    """Class-related settings of a model: two index lists plus the class records."""

    filter_indices: List[int]
    class_indices: List[int]
    # One ModelClassInfo entry per class.
    classes: List[ModelClassInfo]
|
||
|
||
|
||
@dataclass
class ModelInfo:
    """Basic identity of a model: id, YOLO version string and weights path."""

    id: int
    yolo_version: str
    model_path: str
    # Free-text description of what the model does; None when not provided.
    func_description: Optional[str] = None
|
||
|
||
|
||
@dataclass
class ModelMetadata:
    """Bookkeeping data for a model: class count and optional timestamps."""

    total_classes: int
    # Timestamps are optional and default to None.
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
||
|
||
|
||
@dataclass
class ModelData:
    """Flattened model record: model fields, class lists and lookup tables.

    The lookup-table fields were previously annotated with literal ``{}`` and
    ``[]`` instances, which are not types. They are now annotated as ``Dict``
    and ``List``.

    NOTE(review): the key/value types of the lookup tables are not visible in
    this file, so the generic aliases are left unparameterised -- tighten them
    once the producer of this record has been checked.
    """

    # Model identity and inference settings.
    id: int
    yolo_version: str
    model_path: str
    repeat_dis: float
    func_description: Optional[str]
    filter_indices: List[int]
    class_indices: List[int]
    conf: float

    # Class records and their count.
    classes: List[ModelClassInfo]
    total_classes: int

    # Lookup tables (contents not specified in this file).
    cls_names: Dict
    filtered_cls_en_dict: Dict
    cls_en_dict: Dict
    filtered_cls_dict: Dict
    cls_dict: Dict
    cls_str_dict: Dict
    cls_zn_to_eh_dict: Dict
    allowed_classes: List

    # Optional timestamps; default to None.
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
||
|
||
|
||
@dataclass
class MqttData:
    """Record describing one MQTT endpoint.

    NOTE(review): field meanings are inferred from their names only; confirm
    against the table or message this record is loaded from.
    """

    mqtt_id: int

    # Connection settings.
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    mqtt_username: str
    mqtt_pass: str

    # Descriptive / classification fields.
    mqtt_description: str
    org_code: str
    mqtt_type: str
|
||
|
||
|
||
@dataclass
class Device:
    """Record describing one device: name, serial, organisation code, position."""

    dname: str
    sn: str
    orgcode: int

    # Position of the device (units are not specified in this file).
    lat: float
    lng: float
    height: float
|
||
|
||
|
||
@dataclass
class ModelConfiguration:
    """Structured model configuration made of three parts.

    Groups the model's identity (``model_info``), its class settings
    (``class_config``) and its bookkeeping data (``metadata``).
    """

    model_info: ModelInfo
    class_config: ClassConfig
    metadata: ModelMetadata
|
||
|
||
|
||
@dataclass
class Dataset:
    """One dataset resource of a training task.

    ``ModelConfigDAO._db_row_to_dataset`` fills only ``id`` and
    ``resource_original_path`` from the database row. It passes ``None`` for
    ``pic_name``, ``local_path``, ``label_name`` and ``label_txt_path`` and an
    empty string for ``label_content``, so those four fields are annotated as
    optional.
    """

    # Populated from the database row.
    id: int
    resource_original_path: str

    # Not populated by the DAO; expected to be filled in later by the caller.
    pic_name: Optional[str]
    local_path: Optional[str]
    label_name: Optional[str]
    label_content: str
    label_txt_path: Optional[str]
|
||
|
||
|
||
@dataclass
class Labels:
    """One annotation record joined with the resource assignment it belongs to."""

    # Resource assignment columns.
    id: int
    resource_original_path: str

    # Annotation record columns.
    resource_id: int
    label_set_id: int
    label_ids: int
    annotation_data: str
|
||
|
||
|
||
@dataclass
class Label_Yaml:
    """One label entry: database id, ordering index and two name fields."""

    id: int
    # Ordering index; the DAO in this file initialises it to -1.
    id_order: int
    name: str
    e_name: str
|
||
|
||
|
||
class DateTimeEncoder(json.JSONEncoder):
    """Custom JSON encoder that knows how to serialise ``datetime`` objects."""

    def default(self, obj):
        """Return the ISO-8601 string for datetimes; defer everything else.

        Values that are not ``datetime`` instances are passed to the base
        implementation, which raises ``TypeError`` for unsupported types.
        """
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
|
||
|
||
|
||
class ModelConfigDAO:
    """Data-access object for model configuration and training-task tables.

    Every public method opens its own PostgreSQL connection from
    ``db_params`` and closes it before returning. Database errors
    (``psycopg2.Error``) are printed and reported through the return value
    (``False`` for writes, ``None`` for reads) instead of being raised.
    """

    def __init__(self, db_params: Dict[str, str]):
        """Store the connection parameters; no connection is opened here.

        Args:
            db_params: Keyword arguments forwarded to ``psycopg2.connect``:
                - dbname: database name
                - user: user name
                - password: password
                - host: host address
                - port: port number
        """
        self.db_params = db_params

    # ------------------------------------------------------------------
    # Connection helpers
    # ------------------------------------------------------------------

    def _execute_write(self, query: str, params, action: str) -> bool:
        """Run one write statement in its own transaction.

        Args:
            query: SQL text with psycopg2 placeholders.
            params: Mapping or sequence bound to the placeholders.
            action: Word used in the error message ("insert" or "update").

        Returns:
            True if the statement was committed, False on ``psycopg2.Error``.
        """
        conn = None
        try:
            conn = psycopg2.connect(**self.db_params)
            # ``with conn`` commits on success and rolls back on error, but it
            # does NOT close the connection; that is done in ``finally``.
            with conn:
                with conn.cursor() as cur:
                    cur.execute(query, params)
            return True
        except psycopg2.Error as e:
            print(f"Database {action} error: {e}")
            return False
        finally:
            if conn is not None:
                conn.close()

    def _fetch_rows(self, query: str, params) -> Optional[List[Dict]]:
        """Run one SELECT and return all rows as dictionaries.

        Args:
            query: SQL text with psycopg2 placeholders.
            params: Sequence bound to the placeholders.

        Returns:
            The rows fetched through a ``RealDictCursor``, or None on
            ``psycopg2.Error``.
        """
        conn = None
        try:
            conn = psycopg2.connect(**self.db_params)
            with conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, params)
                    return cur.fetchall()
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        finally:
            # Runs on the success path as well, so the connection never leaks.
            if conn is not None:
                conn.close()

    # ------------------------------------------------------------------
    # ai_model writes
    # ------------------------------------------------------------------

    def insert_config(self, config: ModelConfiguration) -> bool:
        """Insert a new model configuration row into ``ai_model``.

        Args:
            config: The model configuration to insert.

        Returns:
            True if the insert was committed, False on a database error.

        Raises:
            ValueError: If ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        # Convert the object into the flat mapping the query expects.
        data = self._config_to_db_format(config)

        query = """
            INSERT INTO ai_model (
                model_id, filter_cls, func_description,
                yolo_version, path, cls_index, cls, cls_en, cls_description
            ) VALUES (
                %(model_id)s, %(filter_cls)s, %(func_description)s,
                %(yolo_version)s, %(path)s, %(cls_index)s, %(cls)s, %(cls_en)s, %(cls_description)s
            )
        """
        return self._execute_write(query, data, "insert")

    def update_config(self, config: ModelConfiguration) -> bool:
        """Update the ``ai_model`` row whose ``model_id`` matches ``config``.

        Args:
            config: The model configuration holding the new values.

        Returns:
            True if the statement was committed, False on a database error.
            The affected row count is not checked, so True is also returned
            when no row matched ``model_id``.

        Raises:
            ValueError: If ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        data = self._config_to_db_format(config)

        query = """
            UPDATE ai_model SET
                filter_cls = %(filter_cls)s,
                func_description = %(func_description)s,
                yolo_version = %(yolo_version)s,
                path = %(path)s,
                cls_index = %(cls_index)s,
                cls = %(cls)s,
                cls_en = %(cls_en)s,
                cls_description = %(cls_description)s
            WHERE model_id = %(model_id)s
        """
        return self._execute_write(query, data, "update")

    # ------------------------------------------------------------------
    # Training-task reads
    # ------------------------------------------------------------------

    def get_datasets(self, bz_training_task_id: int) -> Optional[List[Dataset]]:
        """Fetch the resource assignments linked to a training task.

        The query walks from ``bz_training_dataset`` through the task, dataset
        and project relation tables to ``bz_project_resource_assignments``.
        All joins are LEFT JOINs, so a row may carry NULL assignment columns.

        Args:
            bz_training_task_id: Id of the training task (``bz_training_task.id``).

        Returns:
            One ``Dataset`` per result row, or None on a database error.
        """
        query = """
            select bpra.id,bpra.resource_original_path from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
            left join bz_datasets c on c.id =a.datasetid
            left join bz_dataset_project_relations d on d.data_set_id =c.id
            left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
            where b.id=%s
        """
        rows = self._fetch_rows(query, (bz_training_task_id,))
        if rows is None:
            return None
        return [self._db_row_to_dataset(row) for row in rows]

    def get_labels(self, bz_training_task_id: int) -> Optional[List[Labels]]:
        """Fetch the annotation records linked to a training task.

        Uses the same join chain as ``get_datasets`` plus a LEFT JOIN onto
        ``bz_annotation_record``, so assignments without an annotation yield a
        row whose annotation columns are NULL.

        Args:
            bz_training_task_id: Id of the training task (``bz_training_task.id``).

        Returns:
            One ``Labels`` per result row, or None on a database error.
        """
        query = """
            select bpra.id,bpra.resource_original_path ,bar.resource_id , bar.label_set_id ,bar.label_id ,bar.annotation_data from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
            left join bz_datasets c on c.id =a.datasetid
            left join bz_dataset_project_relations d on d.data_set_id =c.id
            left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
            left join bz_annotation_record bar on bar.task_assignment_id =bpra.id
            where b.id=%s
        """
        rows = self._fetch_rows(query, (bz_training_task_id,))
        if rows is None:
            return None
        return [self._db_row_to_labels(row) for row in rows]

    def get_label_yaml(self, bz_training_task_id: int) -> Optional[List[Label_Yaml]]:
        """Fetch the distinct labels used by a training task's annotations.

        Selects ``id``, ``name`` and ``e_name`` from ``bz_labels`` for every
        label id that appears in the task's annotation records.

        Args:
            bz_training_task_id: Id of the training task (``bz_training_task.id``).

        Returns:
            One ``Label_Yaml`` per label (with ``id_order`` set to -1), or
            None on a database error.
        """
        query = """
            select id,name,e_name from bz_labels where id in (select distinct(bar.label_id) AS id from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
            left join bz_datasets c on c.id =a.datasetid
            left join bz_dataset_project_relations d on d.data_set_id =c.id
            left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
            left join bz_annotation_record bar on bar.task_assignment_id =bpra.id
            where b.id=%s
            )
        """
        rows = self._fetch_rows(query, (bz_training_task_id,))
        if rows is None:
            return None
        return [self._db_row_to_label_yaml(row) for row in rows]

    # ------------------------------------------------------------------
    # Conversions
    # ------------------------------------------------------------------

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict:
        """Flatten a configuration object into the mapping bound to the SQL.

        The keys match the named placeholders used by ``insert_config`` and
        ``update_config``. ``cls_description`` is the ", "-joined list of the
        non-empty class descriptions.
        """
        model_info = config.model_info
        class_config = config.class_config
        classes = class_config.classes
        return {
            "model_id": model_info.id,
            "filter_cls": class_config.filter_indices,
            "func_description": model_info.func_description,
            "yolo_version": model_info.yolo_version,
            "path": model_info.model_path,
            "cls_index": class_config.class_indices,
            "cls": [cls_info.name for cls_info in classes],
            "cls_en": [cls_info.english_name for cls_info in classes],
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in classes])
            ),
        }

    def insert_train_pid(self, task_id, train_pid) -> bool:
        """Insert a training record (task id and training process id).

        The ``create_time`` column is set by the database via ``now()``.

        Args:
            task_id: Value stored in ``bz_train_record.task_id``.
            train_pid: Value stored in ``bz_train_record.train_pid``.

        Returns:
            True if the insert was committed, False on a database error.
        """
        insert_sql = """
            INSERT INTO bz_train_record (
                task_id, train_pid,create_time
            ) VALUES (
                %s, %s,now()
            )
        """
        return self._execute_write(insert_sql, (task_id, train_pid), "insert")

    def _db_row_to_dataset(self, row: Dict) -> Dataset:
        """Build a ``Dataset`` from a row; only id and path come from the row."""
        return Dataset(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            pic_name=None,
            local_path=None,
            label_name=None,
            label_content="",
            label_txt_path=None,
        )

    def _db_row_to_labels(self, row: Dict) -> Labels:
        """Build a ``Labels`` record from a ``get_labels`` result row."""
        return Labels(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            resource_id=row["resource_id"],
            label_set_id=row["label_set_id"],
            # The column is ``label_id``; the dataclass field is ``label_ids``.
            label_ids=row["label_id"],
            annotation_data=row["annotation_data"],
        )

    def _db_row_to_label_yaml(self, row: Dict) -> Label_Yaml:
        """Build a ``Label_Yaml`` from a ``bz_labels`` row; ``id_order`` is -1."""
        return Label_Yaml(
            id=row["id"],
            id_order=-1,
            name=row["name"],
            e_name=row["e_name"],
        )
|
||
|