109 lines
3.5 KiB
Python
109 lines
3.5 KiB
Python
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
import json
|
|
from typing import Dict, List, Union, Optional
|
|
from dataclasses import dataclass, asdict
|
|
from datetime import datetime
|
|
import re
|
|
|
|
|
|
@dataclass
class RecognitionTask:
    """Recognition task record.

    Every attribute is optional and defaults to ``None``. The field order below
    is also the positional order accepted by the generated ``__init__``.
    """

    # Human-readable name of the task.
    task_name: Union[str, None] = None
    # Numeric identifier of the model used for recognition.
    model_id: Union[int, None] = None
    # Display name of that model.
    model_name: Union[str, None] = None
    # Model version identifier (a string, unlike model_id).
    model_version_id: Union[str, None] = None
    # Who created the task.
    created_by: Union[str, None] = None
    # Integer status code; the meaning of each value is not defined here — TODO confirm.
    status: Union[int, None] = None

    # Primary key of the stored row (presumably DB-assigned; verify).
    id: Union[int, None] = None
    # Execution message, e.g. error details (assumption — confirm with writers).
    exec_msg: Union[str, None] = None
    # Creation timestamp.
    created_at: Union[datetime, None] = None
    # Location of the recognition result.
    result_url: Union[str, None] = None
    # Location of the input being recognized.
    source_url: Union[str, None] = None
    # String task identifier; the DAO's status update matches rows on this value.
    task_id: Union[str, None] = None
    # Identifier of the related resource record.
    resource_record_id: Union[int, None] = None
|
|
|
|
class RecognitionTaskDAO:
    """Data-access object for the ``bz_recognition_tasks`` table.

    Each operation opens its own PostgreSQL connection from ``db_params`` and
    closes it again before returning.
    """

    def __init__(self, db_params: Dict[str, Union[str, int]]):
        """Store the database connection parameters.

        Args:
            db_params: keyword arguments passed verbatim to ``psycopg2.connect``,
                typically:
                - dbname: database name
                - user: user name
                - password: password
                - host: host address
                - port: port number
        """
        self.db_params = db_params

    def update_recognition_task(self, model: RecognitionTask) -> bool:
        """Update an existing recognition task.

        Only the ``status`` column is written. The row is located by
        ``model.task_id``, and all other attributes of ``model`` are ignored
        by the query.

        Args:
            model: task carrying the ``task_id`` to match and the new ``status``.

        Returns:
            True if at least one row was updated. False if no row matched
            ``task_id`` (including ``task_id is None``), or if a database error
            occurred; in that case the error is printed.

        Raises:
            ValueError: if ``model`` is not a ``RecognitionTask``.
        """
        if not isinstance(model, RecognitionTask):
            raise ValueError(
                f"Expected RecognitionTask, got {type(model).__name__}"
            )

        data = self._to_db_format(model)

        # Named placeholders are bound by psycopg2 from ``data``; extra keys
        # in the dict are ignored.
        query = """
            UPDATE bz_recognition_tasks SET
                status = %(status)s
            WHERE task_id = %(task_id)s
        """

        conn = None
        try:
            conn = psycopg2.connect(**self.db_params)
            # ``with conn`` scopes the transaction: it commits on success and
            # rolls back on error. It does NOT close the connection, so the
            # explicit close() in ``finally`` is what prevents a leak.
            with conn, conn.cursor() as cur:
                cur.execute(query, data)
                # rowcount is the number of rows the UPDATE affected; 0 means
                # no task with this task_id exists.
                updated = cur.rowcount > 0
            # Returned only after the transaction has been committed.
            return updated
        except psycopg2.Error as e:
            print(f"Database update error: {e}")
            return False
        finally:
            if conn is not None:
                conn.close()

    def _to_db_format(self, task: RecognitionTask) -> Dict:
        """Convert a RecognitionTask into a dict keyed by column name.

        The dict is used as the parameter mapping for named SQL placeholders.
        """
        return {
            "id": task.id,
            "task_name": task.task_name,
            "model_id": task.model_id,
            "model_name": task.model_name,
            "model_version_id": task.model_version_id,
            "status": task.status,
            "exec_msg": task.exec_msg,
            "created_at": task.created_at,
            "created_by": task.created_by,
            "result_url": task.result_url,
            "source_url": task.source_url,
            "task_id": task.task_id,
            "resource_record_id": task.resource_record_id,
        }

    def _from_db_format(self, db_data: Dict) -> RecognitionTask:
        """Build a RecognitionTask from a row dict (e.g. from RealDictCursor).

        Fallback values ("" / 0) are used only when a key is missing entirely.
        A key that is present with a ``None`` (SQL NULL) value stays ``None``.
        """
        return RecognitionTask(
            id=db_data.get("id"),
            task_name=db_data.get("task_name", ""),
            model_id=db_data.get("model_id", 0),
            model_name=db_data.get("model_name", ""),
            model_version_id=db_data.get("model_version_id", ""),
            status=db_data.get("status", 0),
            exec_msg=db_data.get("exec_msg"),
            created_at=db_data.get("created_at"),
            created_by=db_data.get("created_by", ""),
            result_url=db_data.get("result_url"),
            source_url=db_data.get("source_url"),
            task_id=db_data.get("task_id"),
            resource_record_id=db_data.get("resource_record_id"),
        )