388 lines
11 KiB
Python
388 lines
11 KiB
Python
|
|
import psycopg2
|
|||
|
|
from psycopg2.extras import RealDictCursor
|
|||
|
|
import json
|
|||
|
|
from typing import Dict, List, Union, Optional
|
|||
|
|
from dataclasses import dataclass, asdict
|
|||
|
|
from datetime import datetime
|
|||
|
|
import re
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ModelClassInfo:
    """Descriptor of a single class known to a detection model.

    ``ModelConfigDAO._config_to_db_format`` stores ``name`` in the ``cls``
    column, ``english_name`` in ``cls_en`` and joins the non-empty
    ``description`` values into ``cls_description``.
    """

    # Integer index of the class (presumably the model's output index -- confirm).
    index: int
    # Primary class name.
    name: str
    # Optional English class name.
    english_name: Union[str, None] = None
    # Optional free-text description.
    description: Union[str, None] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ClassConfig:
    """Class-selection settings that belong to one model configuration."""

    # Written to the ``filter_cls`` column by ModelConfigDAO._config_to_db_format.
    filter_indices: List[int]
    # Written to the ``cls_index`` column by ModelConfigDAO._config_to_db_format.
    class_indices: List[int]
    # Per-class metadata; names/descriptions feed the cls/cls_en/cls_description columns.
    classes: List[ModelClassInfo]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ModelInfo:
    """Identity and file location of a model."""

    # Stored as ``model_id`` (also the key used by ModelConfigDAO.update_config).
    id: int
    # Stored as ``yolo_version``.
    yolo_version: str
    # Stored in the ``path`` column.
    model_path: str
    # Optional description of what the model is used for (``func_description``).
    func_description: Union[str, None] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ModelMetadata:
    """Bookkeeping information attached to a model configuration."""

    # Number of classes of the model.
    total_classes: int
    # Optional creation timestamp.
    created_at: Union[datetime, None] = None
    # Optional last-modification timestamp.
    updated_at: Union[datetime, None] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ModelData:
    """Flattened model record: model info, class settings and class lookup tables.

    All fields except ``created_at`` / ``updated_at`` are required
    constructor arguments, in the order declared below.
    NOTE(review): the meaning of several fields (``repeat_dis``, ``conf`` and
    the lookup tables) is inferred from their names only -- confirm against
    the code that builds ModelData.
    """

    id: int
    yolo_version: str
    model_path: str
    repeat_dis: float  # NOTE(review): meaning and unit not visible here -- confirm
    func_description: Optional[str]
    filter_indices: List[int]
    class_indices: List[int]
    conf: float  # presumably a detection confidence threshold -- confirm
    classes: List[ModelClassInfo]
    total_classes: int
    # Class-name lookup tables. Their key/value types are not visible from
    # here, so generic ``Dict`` / ``List`` hints are used (a literal ``{}`` or
    # ``[]`` is a value, not a type, and is rejected by type checkers).
    cls_names: Dict
    filtered_cls_en_dict: Dict
    cls_en_dict: Dict
    filtered_cls_dict: Dict
    cls_dict: Dict
    cls_str_dict: Dict
    cls_zn_to_eh_dict: Dict  # presumably Chinese -> English class names -- confirm
    allowed_classes: List
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class MqttData:
    """Connection settings of one MQTT endpoint.

    All nine fields are required, positional in the order below.
    """

    mqtt_id: int
    # Broker address and port.
    mqtt_ip: str
    mqtt_port: int
    # Topic name.
    mqtt_topic: str
    # Broker credentials.
    mqtt_username: str
    mqtt_pass: str
    # Free-text description.
    mqtt_description: str
    # Organisation code (string here; note Device.orgcode is an int).
    org_code: str
    # NOTE(review): allowed values of mqtt_type are not visible here.
    mqtt_type: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Device:
    """A device record: name, serial number, organisation and position."""

    # Device name.
    dname: str
    # Serial number.
    sn: str
    # Organisation code (an int here, unlike MqttData.org_code).
    orgcode: int
    # Presumably latitude/longitude in degrees -- confirm coordinate system.
    lat: float
    lng: float
    # NOTE(review): unit and reference of height are not visible here.
    height: float
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ModelConfiguration:
    """Aggregate passed to ModelConfigDAO.insert_config / update_config."""

    # Identity and file location of the model.
    model_info: ModelInfo
    # Class indices and per-class metadata.
    class_config: ClassConfig
    # Class count and timestamps (not written by ModelConfigDAO._config_to_db_format).
    metadata: ModelMetadata
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Dataset:
    """A resource (image) of a training task plus local file bookkeeping fields.

    ``ModelConfigDAO._db_row_to_dataset`` fills only ``id`` and
    ``resource_original_path`` from the database and passes ``None`` for
    ``pic_name``, ``local_path``, ``label_name`` and ``label_txt_path``
    (and ``""`` for ``label_content``), so those fields are typed Optional.
    The local fields are presumably filled in later (e.g. after download) --
    confirm with the caller.
    """

    id: int
    resource_original_path: str
    pic_name: Optional[str]
    local_path: Optional[str]
    label_name: Optional[str]
    label_content: str
    label_txt_path: Optional[str]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Labels:
    """One resource of a training task joined with one annotation record.

    Built by ModelConfigDAO._db_row_to_labels from a row of the
    ``bz_annotation_record`` LEFT JOIN, so the annotation fields may be None
    when a resource has no annotation record.
    """

    # bz_project_resource_assignments.id
    id: int
    resource_original_path: str
    resource_id: int
    label_set_id: int
    # Holds the single ``label_id`` column value despite the plural name.
    label_ids: int
    # Raw annotation payload; its format is not visible here.
    annotation_data: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Label_Yaml:
    """A label (``bz_labels`` row) used by a training task."""

    id: int
    # ModelConfigDAO._db_row_to_label_yaml sets this to -1 as a placeholder.
    id_order: int
    # Label name and its ``e_name`` counterpart (presumably English -- confirm).
    name: str
    e_name: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``datetime`` objects as ISO-8601 strings.

    Any other type that the base encoder cannot handle still raises
    ``TypeError`` through ``json.JSONEncoder.default``.
    """

    def default(self, obj):
        # Defer everything that is not a datetime to the base implementation.
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelConfigDAO:
    """Data-access object for model configurations and training-task data.

    Every public method opens its own connection with
    ``psycopg2.connect(**self.db_params)`` and closes it before returning.
    ``psycopg2.Error`` is not propagated: it is printed and reported through
    the return value (``False`` for writes, ``None`` for reads).
    """

    def __init__(self, db_params: Dict[str, str]):
        """Store the connection parameters; no connection is opened here.

        Args:
            db_params: Keyword arguments forwarded to ``psycopg2.connect``,
                e.g. ``dbname``, ``user``, ``password``, ``host``, ``port``.
        """
        self.db_params = db_params

    def _execute_write(self, query: str, params) -> None:
        """Execute one write statement in its own transaction, then close the connection.

        Raises:
            psycopg2.Error: Propagated so the caller can report it.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            # `with conn` commits on success and rolls back on error, but it
            # does NOT close the connection -- the `finally` below does.
            with conn:
                with conn.cursor() as cur:
                    cur.execute(query, params)
        finally:
            conn.close()

    def _fetch_all(self, query: str, params) -> List[Dict]:
        """Execute one read query and return all rows as dicts, then close the connection.

        Raises:
            psycopg2.Error: Propagated so the caller can report it.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, params)
                    return cur.fetchall()
        finally:
            # Closed explicitly: the connection context manager never closes it.
            conn.close()

    def insert_config(self, config: ModelConfiguration) -> bool:
        """Insert a new row into ``ai_model`` for the given configuration.

        Args:
            config: The model configuration to insert.

        Returns:
            True on success; False if a database error occurred (the error is printed).

        Raises:
            ValueError: If ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        # Convert the object into the column/value mapping used by the query.
        data = self._config_to_db_format(config)

        query = """
            INSERT INTO ai_model (
                model_id, filter_cls, func_description,
                yolo_version, path, cls_index, cls, cls_en, cls_description
            ) VALUES (
                %(model_id)s, %(filter_cls)s, %(func_description)s,
                %(yolo_version)s, %(path)s, %(cls_index)s, %(cls)s, %(cls_en)s, %(cls_description)s
            )
        """

        try:
            self._execute_write(query, data)
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False
        return True

    def update_config(self, config: ModelConfiguration) -> bool:
        """Update the ``ai_model`` row whose ``model_id`` equals ``config.model_info.id``.

        Args:
            config: The model configuration holding the new values.

        Returns:
            True if the statement ran without a database error (also when no
            row matched ``model_id`` -- the affected row count is not checked);
            False if a database error occurred (the error is printed).

        Raises:
            ValueError: If ``config`` is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        data = self._config_to_db_format(config)

        query = """
            UPDATE ai_model SET
                filter_cls = %(filter_cls)s,
                func_description = %(func_description)s,
                yolo_version = %(yolo_version)s,
                path = %(path)s,
                cls_index = %(cls_index)s,
                cls = %(cls)s,
                cls_en = %(cls_en)s,
                cls_description = %(cls_description)s
            WHERE model_id = %(model_id)s
        """

        try:
            self._execute_write(query, data)
        except psycopg2.Error as e:
            print(f"Database update error: {e}")
            return False
        return True

    def get_datasets(self, bz_training_task_id: int) -> Optional[List[Dataset]]:
        """Fetch the project resources that belong to a training task.

        The query walks training task -> training datasets -> datasets ->
        dataset/project relations -> project resource assignments. Because
        these are LEFT JOINs, ``id`` / ``resource_original_path`` may be None
        in rows whose chain is incomplete.

        Args:
            bz_training_task_id: ``bz_training_task.id`` to look up.

        Returns:
            One ``Dataset`` per result row (only ``id`` and
            ``resource_original_path`` come from the database), an empty list
            if nothing matches, or None if a database error occurred (the
            error is printed).
        """
        query = """
            select bpra.id,bpra.resource_original_path from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
            left join bz_datasets c on c.id =a.datasetid
            left join bz_dataset_project_relations d on d.data_set_id =c.id
            left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
            where b.id=%s
        """

        try:
            rows = self._fetch_all(query, (bz_training_task_id,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return [self._db_row_to_dataset(row) for row in rows]

    def get_labels(self, bz_training_task_id: int) -> Optional[List[Labels]]:
        """Fetch the training task's resources together with their annotation records.

        Uses the same join chain as ``get_datasets`` plus a LEFT JOIN on
        ``bz_annotation_record`` (via ``task_assignment_id``), so annotation
        columns may be None for resources without annotations.

        Args:
            bz_training_task_id: ``bz_training_task.id`` to look up.

        Returns:
            One ``Labels`` per result row, an empty list if nothing matches,
            or None if a database error occurred (the error is printed).
        """
        query = """
            select bpra.id,bpra.resource_original_path ,bar.resource_id , bar.label_set_id ,bar.label_id ,bar.annotation_data from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
            left join bz_datasets c on c.id =a.datasetid
            left join bz_dataset_project_relations d on d.data_set_id =c.id
            left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
            left join bz_annotation_record bar on bar.task_assignment_id =bpra.id
            where b.id=%s
        """

        try:
            rows = self._fetch_all(query, (bz_training_task_id,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return [self._db_row_to_labels(row) for row in rows]

    def get_label_yaml(self, bz_training_task_id: int) -> Optional[List[Label_Yaml]]:
        """Fetch the distinct ``bz_labels`` rows referenced by the task's annotation records.

        Args:
            bz_training_task_id: ``bz_training_task.id`` to look up.

        Returns:
            One ``Label_Yaml`` per label (``id_order`` set to the placeholder
            -1), an empty list if nothing matches, or None if a database
            error occurred (the error is printed).
        """
        query = """
            select id,name,e_name from bz_labels where id in (select distinct(bar.label_id) AS id from bz_training_dataset a left join bz_training_task b on b.id=a.trainingtaskid
            left join bz_datasets c on c.id =a.datasetid
            left join bz_dataset_project_relations d on d.data_set_id =c.id
            left join bz_project_resource_assignments bpra on bpra.project_id =d.project_id
            left join bz_annotation_record bar on bar.task_assignment_id =bpra.id
            where b.id=%s
            )
        """

        try:
            rows = self._fetch_all(query, (bz_training_task_id,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return [self._db_row_to_label_yaml(row) for row in rows]

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict:
        """Convert a configuration into the named parameters of the ``ai_model`` queries.

        ``cls`` / ``cls_en`` are lists parallel to ``config.class_config.classes``
        (``cls_en`` may contain None). ``cls_description`` joins only the
        non-empty descriptions with ", ".
        """
        classes = config.class_config.classes
        return {
            "model_id": config.model_info.id,
            "filter_cls": config.class_config.filter_indices,
            "func_description": config.model_info.func_description,
            "yolo_version": config.model_info.yolo_version,
            "path": config.model_info.model_path,
            "cls_index": config.class_config.class_indices,
            "cls": [cls_info.name for cls_info in classes],
            "cls_en": [cls_info.english_name for cls_info in classes],
            # filter(None, ...) drops None and "" descriptions.
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in classes])
            ),
        }

    def insert_train_pid(self, task_id, train_pid) -> bool:
        """Insert a ``bz_train_record`` row linking a task to a training process id.

        ``create_time`` is set by the database via ``now()``.

        Args:
            task_id: Value for the ``task_id`` column.
            train_pid: Value for the ``train_pid`` column.

        Returns:
            True on success; False if a database error occurred (the error is printed).
        """
        insert_sql = """
            INSERT INTO bz_train_record (
                task_id, train_pid,create_time
            ) VALUES (
                %s, %s,now()
            )
        """

        try:
            self._execute_write(insert_sql, (task_id, train_pid))
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False
        return True

    def _db_row_to_dataset(self, row: Dict) -> Dataset:
        """Build a ``Dataset`` from a ``get_datasets`` row; local-file fields start empty."""
        return Dataset(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            pic_name=None,
            local_path=None,
            label_name=None,
            label_content="",
            label_txt_path=None
        )

    def _db_row_to_labels(self, row: Dict) -> Labels:
        """Build a ``Labels`` from a ``get_labels`` row (``label_id`` -> ``label_ids``)."""
        return Labels(
            id=row["id"],
            resource_original_path=row["resource_original_path"],
            resource_id=row["resource_id"],
            label_set_id=row["label_set_id"],
            label_ids=row["label_id"],
            annotation_data=row["annotation_data"]
        )

    def _db_row_to_label_yaml(self, row: Dict) -> Label_Yaml:
        """Build a ``Label_Yaml`` from a ``get_label_yaml`` row."""
        return Label_Yaml(
            id=row["id"],
            id_order=-1,  # placeholder; not read from the database
            name=row["name"],
            e_name=row["e_name"]
        )
|
|||
|
|
|