# Source listing metadata (file-viewer header): 676 lines, 22 KiB, Python
import json
import re
from contextlib import closing
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Dict, List, Union, Optional

import psycopg2
from psycopg2.extras import RealDictCursor
|
||
|
||
|
||
@dataclass
class ModelClassInfo:
    """A single detection class of a model.

    Only ``index`` and ``name`` must be supplied; the alternate-language
    label and the free-text description default to ``None``.
    """

    index: int  # numeric class index
    name: str  # label written to / read from the ``cls`` column by the DAO
    english_name: Optional[str] = None  # label for the ``cls_en`` column
    description: Optional[str] = None  # optional free-text description
|
||
|
||
|
||
@dataclass
class ClassConfig:
    """Class selection for a model: every defined index, the kept ones, and per-class info."""

    filter_indices: List[int]  # persisted as ai_model.filter_cls by the DAO
    class_indices: List[int]  # persisted as ai_model.cls_index by the DAO
    classes: List[ModelClassInfo]  # names / descriptions, one entry per class
|
||
|
||
|
||
@dataclass
class ModelInfo:
    """Identity and location of a model's weights."""

    id: int  # persisted as ai_model.model_id
    yolo_version: str  # e.g. "11"
    model_path: str  # persisted as ai_model.path
    func_description: Optional[str] = None  # free-text description of the model's function
|
||
|
||
|
||
@dataclass
class ModelMetadata:
    """Bookkeeping data attached to a model configuration."""

    total_classes: int  # number of classes in the configuration
    created_at: Optional[datetime] = None  # row creation time, if known
    updated_at: Optional[datetime] = None  # last update time, if known
|
||
|
||
|
||
@dataclass
class ModelData:
    """Flattened model configuration as built by ``ModelConfigDAO._db_row_to_config``.

    Combines the ``ai_model`` row (paths, class arrays) with the matching
    ``ai_model_list`` row (confidence, repeat settings, selected classes) and
    several pre-computed lookup tables derived from them.
    """

    id: int  # ai_model.model_id
    yolo_version: str
    model_path: str  # ai_model.path
    engine_path: str
    so_path: str
    repeat_dis: float  # ai_model_list.repeat_dis (semantics not visible here — confirm)
    repeat_time: float  # ai_model_list.repeat_time; may be None if the column is NULL
    func_description: Optional[str]  # filled from ai_model.scope by the DAO
    filter_indices: List[int]  # class indices selected for this function
    class_indices: List[int]  # every class index the model defines (ai_model.cls_index)
    conf: float  # ai_model_list.confidence
    classes: List[ModelClassInfo]
    total_classes: int  # len(classes)
    cls_names: List[str]  # raw ai_model.cls array
    filtered_cls_en_dict: Dict[int, str]  # selected index -> cls_en label
    cls_en_dict: Dict[int, str]  # every index -> cls_en label
    filtered_cls_dict: Dict[int, str]  # selected index -> cls label
    cls_dict: Dict[int, str]  # every index -> cls label
    cls_str_dict: Dict[str, str]  # cls label -> cls_en label
    cls_zn_to_eh_dict: Dict[str, str]  # selected cls label -> cls_en label, deduplicated by cls_en
    allowed_classes: List[str]  # cls labels of the selected classes
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
||
|
||
|
||
@dataclass
class MqttData:
    """Connection settings for one ``ai_mqtt_config`` row.

    ``org_code`` and ``mqtt_type`` default to ``None``. The lookup by id
    (``_db_row_to_mqtt_config``) builds instances without them, and with the
    previously required fields that call always raised ``TypeError``.
    """

    mqtt_id: int  # ai_mqtt_config.id
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    mqtt_username: str
    mqtt_pass: str
    mqtt_description: str  # ai_mqtt_config.description
    org_code: Optional[str] = None
    mqtt_type: Optional[str] = None
|
||
|
||
|
||
@dataclass
class Device:
    """A row of the ``device`` table: name, serial number, organisation and position."""

    dname: str  # device name
    sn: str  # serial number (the DAO docstring calls it the dock/airport sn — confirm)
    orgcode: int  # organisation code; NOTE(review): get_device takes it as str
    lat: float  # latitude
    lng: float  # longitude
    height: float
|
||
|
||
|
||
@dataclass
class ModelConfiguration:
    """Complete model configuration accepted by ``insert_config`` / ``update_config``."""

    model_info: ModelInfo  # identity and weight path
    class_config: ClassConfig  # class indices, selection and names
    metadata: ModelMetadata  # class count and timestamps
|
||
|
||
|
||
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that writes ``datetime`` values as ISO-8601 strings."""

    def default(self, obj):
        # Anything that is not a datetime goes to the base class, which
        # raises TypeError for unsupported types.
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
|
||
|
||
|
||
class ModelConfigDAO:
    """Data-access object for the AI model, device, MQTT and request-log tables.

    Every public method opens its own connection from ``db_params`` and closes
    it before returning. When a ``psycopg2.Error`` occurs, the error is printed;
    read methods then return ``None`` and write methods return ``False``.
    """

    def __init__(self, db_params: Dict[str, str]):
        """Store connection parameters and make sure the ``ai_model`` table exists.

        Args:
            db_params: keyword arguments passed to ``psycopg2.connect``, e.g.
                ``dbname``, ``user``, ``password``, ``host``, ``port``.

        Raises:
            psycopg2.Error: if the table/trigger DDL fails.
        """
        self.db_params = db_params
        self._create_table_if_not_exists()

    # ------------------------------------------------------------------
    # Connection helpers
    # ------------------------------------------------------------------

    def _connect(self):
        """Open a connection wrapped so that it is closed when the ``with`` exits.

        ``with psycopg2.connect(...) as conn`` only ends the transaction
        (commit on success, rollback on error); it does NOT close the
        connection. The previous pattern therefore leaked one connection per call.
        """
        return closing(psycopg2.connect(**self.db_params))

    def _execute(self, query: str, params=None) -> None:
        """Run a single write statement in its own transaction and commit it.

        Raises:
            psycopg2.Error: on connect, execute or commit failure.
        """
        with self._connect() as conn:
            with conn:  # commits on normal exit, rolls back if an exception escapes
                with conn.cursor() as cur:
                    cur.execute(query, params)

    def _fetch_one(self, query: str, params) -> Optional[Dict]:
        """Run a read query and return its first row as a dict, or None if no row matched.

        Raises:
            psycopg2.Error: on connect or query failure.
        """
        with self._connect() as conn:
            with conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query, params)
                    return cur.fetchone()

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _create_table_if_not_exists(self):
        """Create ``ai_model`` plus a trigger that refreshes ``updated_at`` on UPDATE.

        Every statement is idempotent (IF NOT EXISTS / OR REPLACE / DROP IF
        EXISTS), so calling this repeatedly is safe. ``EXECUTE FUNCTION`` in the
        trigger requires PostgreSQL 11 or newer.

        Raises:
            psycopg2.Error: printed, then re-raised so construction fails loudly.
        """
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS ai_model (
            model_id INTEGER PRIMARY KEY,
            filter_cls INTEGER[],
            func_description TEXT,
            yolo_version TEXT ,
            path TEXT ,
            cls_index INTEGER[] ,
            cls TEXT[] ,
            cls_en TEXT[] ,
            cls_description TEXT,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        );

        CREATE OR REPLACE FUNCTION update_updated_at_column()
        RETURNS TRIGGER AS $$
        BEGIN
            NEW.updated_at = CURRENT_TIMESTAMP;
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
        DROP TRIGGER IF EXISTS update_model_config_modtime ON ai_model;
        CREATE TRIGGER update_model_config_modtime
        BEFORE UPDATE ON ai_model
        FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
        """

        try:
            self._execute(create_table_sql)
        except psycopg2.Error as e:
            print(f"Error creating table: {e}")
            raise

    # ------------------------------------------------------------------
    # ai_model writes
    # ------------------------------------------------------------------

    def insert_config(self, config: ModelConfiguration) -> bool:
        """Insert a new ``ai_model`` row built from *config*.

        Args:
            config: the model configuration to persist.

        Returns:
            True if the row was committed; False if a database error occurred
            (the error is printed).

        Raises:
            ValueError: if *config* is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        data = self._config_to_db_format(config)

        query = """
        INSERT INTO ai_model (
            model_id, filter_cls, func_description,
            yolo_version, path, cls_index, cls, cls_en, cls_description
        ) VALUES (
            %(model_id)s, %(filter_cls)s, %(func_description)s,
            %(yolo_version)s, %(path)s, %(cls_index)s, %(cls)s, %(cls_en)s, %(cls_description)s
        )
        """

        try:
            self._execute(query, data)
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False
        return True

    def update_config(self, config: ModelConfiguration) -> bool:
        """Overwrite the ``ai_model`` row whose ``model_id`` equals ``config.model_info.id``.

        Args:
            config: the model configuration holding the new values.

        Returns:
            True if the statement was committed; False if a database error
            occurred (the error is printed).
            NOTE(review): True is also returned when no row matched the id.

        Raises:
            ValueError: if *config* is not a ``ModelConfiguration``.
        """
        if not isinstance(config, ModelConfiguration):
            raise ValueError("Invalid configuration type")

        data = self._config_to_db_format(config)

        query = """
        UPDATE ai_model SET
            filter_cls = %(filter_cls)s,
            func_description = %(func_description)s,
            yolo_version = %(yolo_version)s,
            path = %(path)s,
            cls_index = %(cls_index)s,
            cls = %(cls)s,
            cls_en = %(cls_en)s,
            cls_description = %(cls_description)s
        WHERE model_id = %(model_id)s
        """

        try:
            self._execute(query, data)
        except psycopg2.Error as e:
            print(f"Database update error: {e}")
            return False
        return True

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get_config(self, model_id: int, filter_indices: Optional[List[int]] = None) -> Optional[ModelData]:
        """Load and flatten the model configuration for one model function.

        Args:
            model_id: matched against ``ai_model_list.func_id`` (the model
                function id), not ``ai_model.model_id``.
            filter_indices: class indices to keep. When empty or None, the
                indices listed in the row's ``py_func`` text are used instead.

        Returns:
            The flattened ``ModelData``, or None if no row matched or a database
            error occurred (the error is printed).
        """
        query = """
        SELECT
            aml.func_id,
            aml.model_func_id,
            aml.model_id,
            aml.confidence,
            aml.py_func,
            aml.repeat_dis,
            aml.repeat_time,
            am.scope,
            am.yolo_version,
            am.PATH,
            am.engine_path,
            am.so_path,
            am.cls_index,
            am.cls,
            am.cls_en,
            am.cls_description,
            am.created_at,
            am.updated_at
        FROM
            ai_model am,
            ai_model_list aml
        WHERE
            aml.func_id = %s
            AND aml.model_id = am.model_id
        """

        try:
            result = self._fetch_one(query, (model_id,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_config(result, filter_indices) if result else None

    def insert_request_log(self, task_id, sn, org_code, requset_json, request) -> bool:
        """Append one row to ``ai_request_log``; ``create_time`` is set to ``now()``.

        Args:
            task_id: stored in ``task_id``.
            sn: device serial number, stored in ``sn``.
            org_code: organisation code, stored in ``org_code``.
            requset_json: stored in the ``requset_json`` column (the spelling
                matches the column name), presumably the serialised request — confirm.
            request: stored in the ``request`` column.

        Returns:
            True if the row was committed; False if a database error occurred
            (the error is printed).
        """
        query = """
        insert into ai_request_log(task_id,sn,org_code,requset_json,request,create_time)values(%s,%s,%s,%s,%s,now())
        """

        try:
            self._execute(query, (task_id, sn, org_code, requset_json, request,))
        except psycopg2.Error as e:
            print(f"Database insert error: {e}")
            return False
        return True

    def get_device(self, sn: str, orgcode: str) -> Optional[Device]:
        """Look up a device by serial number and organisation code.

        Args:
            sn: device (dock/airport) serial number.
            orgcode: organisation id.

        Returns:
            The matching ``Device``, or None if there is no match or a database
            error occurred (the error is printed).
        """
        query = """
        select dname,sn,orgcode,lat,lng,height from device where sn=%s and orgcode=%s
        """

        try:
            result = self._fetch_one(query, (sn, orgcode,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_device_config(result) if result else None

    def get_mqtt_config(self, mqtt_id: int) -> Optional[MqttData]:
        """Look up one MQTT configuration by its ``ai_mqtt_config.id``.

        Args:
            mqtt_id: primary id of the MQTT configuration.

        Returns:
            The ``MqttData``, or None if there is no match or a database error
            occurred (the error is printed).
        """
        # org_code / mqtt_type are selected too (the same table is filtered on
        # them below) so the returned MqttData is fully populated.
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,mqtt_topic,mqtt_username,mqtt_pass,description from ai_mqtt_config where id=%s
        """

        try:
            result = self._fetch_one(query, (mqtt_id,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_mqtt_config(result) if result else None

    def get_pic_mqtt_config_by_orgcode(self, org_code: str, mqtt_type: str) -> Optional[MqttData]:
        """Look up the MQTT configuration for an organisation and MQTT type.

        Args:
            org_code: organisation code.
            mqtt_type: value matched against ``ai_mqtt_config.mqtt_type``.

        Returns:
            The first matching ``MqttData``, or None if there is no match or a
            database error occurred (the error is printed).
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,
        mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where org_code=%s and mqtt_type=%s
        """

        try:
            result = self._fetch_one(query, (org_code, mqtt_type,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_mqtt_config_by_orgcode(result) if result else None

    def get_mqtt_config_by_orgcode(self, org_code: str, sn: str, mqtt_type: str) -> Optional[MqttData]:
        """Look up the MQTT configuration for an organisation, device and MQTT type.

        Args:
            org_code: organisation code.
            sn: device serial number.
            mqtt_type: value matched against ``ai_mqtt_config.mqtt_type``.

        Returns:
            The first matching ``MqttData``, or None if there is no match or a
            database error occurred (the error is printed).
        """
        query = """
        select id,org_code,mqtt_type,mqtt_ip,mqtt_port,
        mqtt_topic,mqtt_username,mqtt_pass,description
        from ai_mqtt_config
        where org_code=%s and sn=%s and mqtt_type=%s
        """

        try:
            result = self._fetch_one(query, (org_code, sn, mqtt_type,))
        except psycopg2.Error as e:
            print(f"Database query error: {e}")
            return None
        return self._db_row_to_mqtt_config_by_orgcode(result) if result else None

    # ------------------------------------------------------------------
    # Row <-> object conversion
    # ------------------------------------------------------------------

    def _config_to_db_format(self, config: ModelConfiguration) -> Dict:
        """Convert a ``ModelConfiguration`` into named parameters for the ai_model SQL."""
        return {
            "model_id": config.model_info.id,
            "filter_cls": config.class_config.filter_indices,
            "func_description": config.model_info.func_description,
            "yolo_version": config.model_info.yolo_version,
            "path": config.model_info.model_path,
            "cls_index": config.class_config.class_indices,
            "cls": [cls_info.name for cls_info in config.class_config.classes],
            "cls_en": [cls_info.english_name for cls_info in config.class_config.classes],
            # NOTE(review): missing descriptions are dropped before joining. On read,
            # _db_row_to_config uses descriptions[idx], so positions shift whenever a
            # class in the middle has no description.
            "cls_description": ", ".join(
                filter(None, [cls_info.description for cls_info in config.class_config.classes])
            ),
        }

    def _db_row_to_device_config(self, row: Dict) -> Device:
        """Build a ``Device`` from a ``device`` row; coordinates are converted to float."""
        return Device(
            dname=row["dname"],
            sn=row["sn"],
            orgcode=row["orgcode"],
            lat=float(row["lat"]),
            lng=float(row["lng"]),
            height=float(row["height"]),
        )

    def _db_row_to_mqtt_config(self, row: Dict) -> MqttData:
        """Build an ``MqttData`` from an ``ai_mqtt_config`` row.

        ``org_code`` / ``mqtt_type`` are read with ``get`` so that rows without
        those keys still convert. Previously they were omitted entirely, which
        made the ``MqttData(...)`` call raise ``TypeError``.
        """
        return MqttData(
            mqtt_id=row["id"],
            mqtt_ip=row["mqtt_ip"],
            mqtt_port=int(row["mqtt_port"]),
            mqtt_topic=row["mqtt_topic"],
            mqtt_username=row["mqtt_username"],
            mqtt_pass=row["mqtt_pass"],
            mqtt_description=row["description"],
            org_code=row.get("org_code"),
            mqtt_type=row.get("mqtt_type"),
        )

    def _db_row_to_mqtt_config_by_orgcode(self, row: Dict) -> MqttData:
        """Build an ``MqttData`` from an ``ai_mqtt_config`` row that includes org_code/mqtt_type."""
        return MqttData(
            mqtt_id=row["id"],
            mqtt_ip=row["mqtt_ip"],
            mqtt_port=int(row["mqtt_port"]),
            mqtt_topic=row["mqtt_topic"],
            mqtt_username=row["mqtt_username"],
            mqtt_pass=row["mqtt_pass"],
            mqtt_description=row["description"],
            org_code=row["org_code"],
            mqtt_type=row["mqtt_type"],
        )

    def _db_row_to_config(self, row: Dict, filter_cls: Optional[List[int]]) -> ModelData:
        """Convert a joined ``ai_model`` / ``ai_model_list`` row into a ``ModelData``.

        Args:
            row: one row of ``get_config``'s query, keyed by column name.
            filter_cls: class indices requested by the caller. When empty or
                None, the indices parsed from ``py_func`` are used instead.

        Returns:
            The populated ``ModelData``.
        """

        def parse_description(desc: Optional[str]) -> List[str]:
            # Split a comma-separated string into trimmed, non-empty parts.
            return [d.strip() for d in desc.split(',') if d.strip()] if desc else []

        descriptions = parse_description(row.get('cls_description', ''))
        # NOTE(review): get_config's query does not select func_description, so
        # this list is always empty on that path.
        func_descriptions = parse_description(row.get('func_description', ''))

        # NULL array columns are treated as empty instead of crashing.
        cls_index = row['cls_index'] or []
        cls = row['cls'] or []
        cls_en = row['cls_en'] or []

        classes = []
        for idx in cls_index:
            # NOTE(review): idx is used directly as a position in `cls` here,
            # whereas the lookup tables below map idx -> position through
            # cls_index. Both agree only when cls_index is 0..n-1 in order.
            if idx >= len(cls):
                continue
            class_info = ModelClassInfo(
                index=idx,
                name=cls[idx],
                english_name=cls_en[idx] if idx < len(cls_en) else None,
            )
            if descriptions and idx < len(descriptions):
                class_info.description = descriptions[idx]
            elif func_descriptions and idx < len(func_descriptions):
                class_info.description = func_descriptions[idx % len(func_descriptions)]
            classes.append(class_info)

        # py_func looks like "识别内容=0:鳄鱼纹裂缝;1:纵向裂缝;2:斜裂缝" — collect
        # the numeric ids of every "N:label" pair. A NULL py_func yields no ids.
        text = row['py_func'] or ''
        filter_indices = [int(num) for num, _label in re.findall(r'(\d+):([^;]+)', text)]
        if filter_cls:
            filter_indices = filter_cls

        conf = row['confidence']
        repeat_dis = row["repeat_dis"]

        # First position of each class index (same result as cls_index.index(idx),
        # but O(1) per lookup instead of O(n)).
        position = {}
        for pos, idx in enumerate(cls_index):
            position.setdefault(idx, pos)

        # Keep only requested indices that the model actually defines.
        valid_filter_cls = [idx for idx in filter_indices if idx in position]

        filtered_cls_en_dict = {idx: cls_en[position[idx]] for idx in valid_filter_cls}
        cls_en_dict = {idx: cls_en[position[idx]] for idx in cls_index}
        filtered_cls_dict = {idx: cls[position[idx]] for idx in valid_filter_cls}
        cls_dict = {idx: cls[position[idx]] for idx in cls_index}

        # cls label -> cls_en label over all classes (later duplicates win).
        cls_str_dict = dict(zip(cls, cls_en))

        # cls label -> cls_en label for the selected classes, keeping only the
        # first cls label seen for each distinct cls_en label.
        cls_zn_to_eh_dict = {}
        for idx in valid_filter_cls:
            name = cls[position[idx]]
            en_name = cls_en[position[idx]]
            if en_name not in cls_zn_to_eh_dict.values():
                cls_zn_to_eh_dict[name] = en_name

        # Labels of the selected classes. This uses the validated ids mapped
        # through cls_index, like filtered_cls_dict. Indexing `cls` with raw
        # filter ids raised IndexError for ids the model does not define.
        allowed_classes = [cls[position[idx]] for idx in valid_filter_cls]

        return ModelData(
            id=row['model_id'],
            yolo_version=row['yolo_version'],
            model_path=row['path'],
            engine_path=row['engine_path'],
            so_path=row['so_path'],
            func_description=row.get('scope'),
            filter_indices=filter_indices,
            repeat_dis=repeat_dis,
            repeat_time=row.get('repeat_time'),
            class_indices=cls_index,
            conf=conf,
            classes=classes,
            total_classes=len(classes),
            created_at=row.get('created_at'),
            updated_at=row.get('updated_at'),
            cls_names=cls,
            filtered_cls_en_dict=filtered_cls_en_dict,
            cls_en_dict=cls_en_dict,
            filtered_cls_dict=filtered_cls_dict,
            cls_dict=cls_dict,
            cls_str_dict=cls_str_dict,
            cls_zn_to_eh_dict=cls_zn_to_eh_dict,
            allowed_classes=allowed_classes,
        )
|
||
|
||
# Sample data: an example ModelConfiguration (17 construction-machinery classes)
# that can be passed to ModelConfigDAO.insert_config.
# NOTE(review): in these entries ``name`` holds the English label and
# ``english_name`` holds the Chinese one — the opposite of what the field names
# suggest. Confirm which convention the ai_model.cls / cls_en columns expect.
# NOTE(review): some labels look wrong (e.g. index 9 "LoaderWheel" -> "塔吊",
# and "LoaderWheel" appears at both 9 and 14). Verify before inserting.
sample_data = ModelConfiguration(
    model_info=ModelInfo(
        id=100002,
        yolo_version="11",
        model_path="pt/GDCL.pt",
        func_description="挖掘机,自卸卡车,压路机,移动式起重机,固定式起重机,轮式装载机,混凝土搅拌车,反铲式装载机,推土机,平地机"
    ),
    class_config=ClassConfig(
        filter_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        class_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        classes=[
            ModelClassInfo(index=0, name="Excavator", english_name="挖掘机", description="挖掘机"),
            ModelClassInfo(index=1, name="Flatbed", english_name="卡车"),
            ModelClassInfo(index=2, name="Truck", english_name="卡车"),
            ModelClassInfo(index=3, name="DumpTruck", english_name="卡车"),
            ModelClassInfo(index=4, name="Roller", english_name="压路机"),
            ModelClassInfo(index=5, name="CraneMobile", english_name="起重机"),
            ModelClassInfo(index=6, name="CraneTruck", english_name="起重机"),
            ModelClassInfo(index=7, name="CraneTower", english_name="塔吊"),
            ModelClassInfo(index=8, name="CraneFixed", english_name="塔吊"),
            ModelClassInfo(index=9, name="LoaderWheel", english_name="塔吊"),
            ModelClassInfo(index=10, name="CementTanker", english_name="水泥罐车"),
            ModelClassInfo(index=11, name="ConcreteMixer", english_name="水泥罐车"),
            ModelClassInfo(index=12, name="MixerTruck", english_name="水泥罐车"),
            ModelClassInfo(index=13, name="BackhoeLoader", english_name="挖土机"),
            ModelClassInfo(index=14, name="LoaderWheel", english_name="推土机"),
            ModelClassInfo(index=15, name="Grader", english_name="地坪机"),
            ModelClassInfo(index=16, name="PileDriver", english_name="打桩机")
        ]
    ),
    # total_classes matches the 17 entries above.
    metadata=ModelMetadata(total_classes=17)
)
|
||
|
||
# Example usage
if __name__ == "__main__":
    import os

    # Connection settings can be overridden through environment variables so
    # credentials need not be edited into source. The fallbacks are the
    # previous hard-coded values.
    # NOTE(review): the fallbacks embed a real-looking host and password;
    # consider removing them.
    DB_CONFIG = {
        "dbname": os.environ.get("DB_NAME", "smart_dev_123"),
        "user": os.environ.get("DB_USER", "postgres"),
        "password": os.environ.get("DB_PASSWORD", "root"),
        "host": os.environ.get("DB_HOST", "8.137.54.85"),
        "port": os.environ.get("DB_PORT", "5060"),
    }

    # Create the DAO (also ensures the ai_model table exists).
    dao = ModelConfigDAO(DB_CONFIG)

    # # Example 1: insert a new configuration
    # print("插入新配置...")
    # if dao.insert_config(sample_data):
    #     print("插入成功")
    # else:
    #     print("插入失败")

    # Example 2: query a configuration
    print("\n查询配置...")
    model_id = 100091
    # filter_indices is passed explicitly. An empty list means "use the class
    # ids listed in the row's py_func text". Omitting it raised TypeError.
    config = dao.get_config(model_id, filter_indices=[])

    if config:
        # Print the structured result; DateTimeEncoder serialises datetimes.
        print(json.dumps(asdict(config), indent=2, ensure_ascii=False, cls=DateTimeEncoder))

        # Access individual fields.
        print("\n模型路径:", config.model_path)
        print("过滤类别:", config.filter_indices)
        if config.classes:  # guard: a model may have no classes
            print("第一个类别:", asdict(config.classes[0]))
        print("创建时间:", config.created_at)
        print("更新时间:", config.updated_at)
    else:
        print(f"未找到ID为 {model_id} 的模型配置")

    # # Example 3: update a configuration
    # NOTE(review): stale — get_config returns ModelData, which has no
    # model_info/metadata, while update_config expects a ModelConfiguration.
    # print("\n更新配置...")
    # if config:
    #     config.model_info.yolo_version = "12"
    #     if dao.update_config(config):
    #         print("更新成功")
    #         # Verify the update
    #         updated_config = dao.get_config(model_id)
    #         print("更新后的YOLO版本:", updated_config.model_info.yolo_version)
    #         print("更新时间:", updated_config.metadata.updated_at)
    #     else:
    #         print("更新失败")
|