ai-train_platform/middleware/recognition_task.py

109 lines
3.5 KiB
Python

import psycopg2
from psycopg2.extras import RealDictCursor
import json
from typing import Dict, List, Union, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import re
@dataclass
class RecognitionTask:
    """Recognition task record.

    Every field is optional and defaults to ``None``. The field names match
    the dictionary keys produced by ``RecognitionTaskDAO._to_db_format`` and
    read by ``RecognitionTaskDAO._from_db_format``.

    NOTE: field order defines the positional constructor order, so fields
    must not be reordered.
    """

    task_name: Optional[str] = None
    model_id: Optional[int] = None
    model_name: Optional[str] = None
    model_version_id: Optional[str] = None
    created_by: Optional[str] = None
    # The only column written by RecognitionTaskDAO.update_recognition_task.
    status: Optional[int] = None
    id: Optional[int] = None
    exec_msg: Optional[str] = None
    created_at: Optional[datetime] = None
    result_url: Optional[str] = None
    source_url: Optional[str] = None
    # Used as the WHERE key by RecognitionTaskDAO.update_recognition_task.
    task_id: Optional[str] = None
    resource_record_id: Optional[int] = None
class RecognitionTaskDAO:
    """Data-access object for the ``bz_recognition_tasks`` table."""

    def __init__(self, db_params: Dict[str, str]):
        """
        Store the database connection parameters.

        No connection is opened here; ``update_recognition_task`` opens (and
        closes) one per call.

        Args:
            db_params: Keyword arguments forwarded to ``psycopg2.connect``:
                - dbname: database name
                - user: user name
                - password: password
                - host: host address
                - port: port number
        """
        self.db_params = db_params

    def update_recognition_task(self, model: "RecognitionTask") -> bool:
        """
        Update an existing recognition task.

        Only the ``status`` column is written; the row is selected by
        ``task_id``.

        Args:
            model: Task object carrying the ``task_id`` to match and the new
                ``status`` value.

        Returns:
            True if the UPDATE executed and was committed without a database
            error (whether any row actually matched ``task_id`` is not
            checked); False if psycopg2 raised an error, which is printed.

        Raises:
            ValueError: If ``model`` is not a ``RecognitionTask``.
        """
        if not isinstance(model, RecognitionTask):
            raise ValueError("Invalid configuration type")

        data = self._to_db_format(model)
        query = (
            "\n"
            "UPDATE bz_recognition_tasks SET\n"
            "status = %(status)s\n"
            "WHERE task_id = %(task_id)s\n"
        )

        try:
            conn = psycopg2.connect(**self.db_params)
            try:
                # `with conn` commits on success and rolls back on error, but
                # it does NOT close the connection, hence the explicit close().
                with conn:
                    with conn.cursor() as cur:
                        cur.execute(query, data)
            finally:
                conn.close()
            return True
        except psycopg2.Error as e:
            print(f"Database update error: {e}")
            return False

    def _to_db_format(self, task: "RecognitionTask") -> Dict:
        """Convert a RecognitionTask object into a dict of query parameters."""
        return {
            "id": task.id,
            "task_name": task.task_name,
            "model_id": task.model_id,
            "model_name": task.model_name,
            "model_version_id": task.model_version_id,
            "status": task.status,
            "exec_msg": task.exec_msg,
            "created_at": task.created_at,
            "created_by": task.created_by,
            "result_url": task.result_url,
            "source_url": task.source_url,
            "task_id": task.task_id,
            "resource_record_id": task.resource_record_id,
        }

    def _from_db_format(self, db_data: Dict) -> "RecognitionTask":
        """
        Build a RecognitionTask object from a database row mapping.

        Missing keys fall back to ``""`` for the name/creator/version fields,
        ``0`` for ``model_id`` and ``status``, and ``None`` for the rest.
        """
        return RecognitionTask(
            id=db_data.get("id"),
            task_name=db_data.get("task_name", ""),
            model_id=db_data.get("model_id", 0),
            model_name=db_data.get("model_name", ""),
            model_version_id=db_data.get("model_version_id", ""),
            status=db_data.get("status", 0),
            exec_msg=db_data.get("exec_msg"),
            created_at=db_data.get("created_at"),
            created_by=db_data.get("created_by", ""),
            result_url=db_data.get("result_url"),
            source_url=db_data.get("source_url"),
            task_id=db_data.get("task_id"),
            resource_record_id=db_data.get("resource_record_id"),
        )