diff --git a/Ai_tottle/aboutdataset/.codemap/main-panel.json b/Ai_tottle/aboutdataset/.codemap/main-panel.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/Ai_tottle/aboutdataset/.codemap/main-panel.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Ai_tottle/aboutdataset/.vscode/settings.json b/Ai_tottle/aboutdataset/.vscode/settings.json new file mode 100644 index 0000000..a8c2003 --- /dev/null +++ b/Ai_tottle/aboutdataset/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda", + "python-envs.pythonProjects": [] +} \ No newline at end of file diff --git a/Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc b/Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc new file mode 100644 index 0000000..1cb1f22 Binary files /dev/null and b/Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc differ diff --git a/Ai_tottle/aboutdataset/__pycache__/pgadmin_helper.cpython-312.pyc b/Ai_tottle/aboutdataset/__pycache__/pgadmin_helper.cpython-312.pyc new file mode 100644 index 0000000..e63a250 Binary files /dev/null and b/Ai_tottle/aboutdataset/__pycache__/pgadmin_helper.cpython-312.pyc differ diff --git a/Ai_tottle/aboutdataset/config-test.yaml b/Ai_tottle/aboutdataset/config-test.yaml new file mode 100644 index 0000000..36756db --- /dev/null +++ b/Ai_tottle/aboutdataset/config-test.yaml @@ -0,0 +1,13 @@ +minio: + endpoint: "222.212.85.86:9000" + access_key: "adminjdskfj" + secret_key: "123456ksldjfal@Y" + secure: false + web: "http://222.212.85.86" + +sql: + host: '8.137.54.85' + port: 5060 + dbname: 'postgres' + user: 'postgres' + password: 'root' \ No newline at end of file diff --git a/Ai_tottle/aboutdataset/config.yaml b/Ai_tottle/aboutdataset/config.yaml new file mode 100644 index 0000000..2fa2eda --- /dev/null +++ b/Ai_tottle/aboutdataset/config.yaml @@ -0,0 +1,13 @@ +minio: + endpoint: "222.212.85.86:9000" + access_key: "adminjdskfj" + secret_key: "123456ksldjfal@Y" + secure: false + web: "http://222.212.85.86" + +sql: + host: '222.212.85.86' + port: 5432 + dbname: 'smart_dev' + user: 'postgres' + password: 'root' \ No newline at end of file diff --git a/Ai_tottle/aboutdataset/config_test_dev.yaml b/Ai_tottle/aboutdataset/config_test_dev.yaml new file mode 100644 index 0000000..650cdd6 --- /dev/null +++ b/Ai_tottle/aboutdataset/config_test_dev.yaml @@ -0,0 +1,13 @@ +minio: + endpoint: "222.212.85.86:9000" + access_key: "adminjdskfj" + secret_key: "123456ksldjfal@Y" + secure: false + web: "http://222.212.85.86" + +sql: + host: '8.137.54.85' + port: 5060 + dbname: 'smart_dev_123' + user: 'postgres' + password: 'root' \ No newline at end of file diff --git a/Ai_tottle/aboutdataset/download_oss.py b/Ai_tottle/aboutdataset/download_oss.py new file mode 100644 index 0000000..e1bf287 --- /dev/null +++ b/Ai_tottle/aboutdataset/download_oss.py @@ -0,0 +1,103 @@ +import psycopg2 +import os +import requests +from miniohelp import read_sql_config +import shutil + +def fetch_object_and_labels(sql_config, where_clause=None, table_name='aidataset'): + conn = psycopg2.connect( + host=sql_config['host'], + port=sql_config['port'], + user=sql_config['user'], + password=sql_config['password'], + dbname=sql_config.get('database') or sql_config.get('dbname') + ) + try: + cursor = conn.cursor() + sql = f"SELECT * FROM public.{table_name}" + if where_clause: + sql += f" WHERE {where_clause}" + + cursor.execute(sql) + rows = cursor.fetchall() + return rows + finally: + conn.close() + + +def download_file_from_url(url, save_path): + """ + 从 URL 下载文件到指定路径 + """ + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + with open(save_path, 'wb') as f: + for chunk in response.iter_content(8192): + if chunk: + f.write(chunk) + print(f"✅ 下载成功:{save_path}") + return True + except Exception as e: + print(f"❌ 下载失败: {url} 错误: {e}") + return False + +def save_label(label_text, file_path): + """ + 保存 label 内容到 .txt 文件 + """ + try: + with open(file_path, 'w', encoding='utf-8') as f: + f.write(label_text.strip()) + print(f"✅ 标签已保存:{file_path}") + except Exception as e: + print(f"❌ 保存标签失败:{file_path} 错误: {e}") + +def move_txt_files(src_dir='image', dst_dir='label'): + os.makedirs(dst_dir, exist_ok=True) + for filename in os.listdir(src_dir): + if filename.endswith('.txt'): + src_path = os.path.join(src_dir, filename) + dst_path = os.path.join(dst_dir, filename) + shutil.move(src_path, dst_path) + print(f"已移动: {src_path} -> {dst_path}") + + +def download_and_save_images_from_oss(yaml_name, where_clause, image_dir, label_dir, table_name): + sql_config = read_sql_config(yaml_name) + + rows = fetch_object_and_labels(sql_config, where_clause, table_name) + print(f"共查询到 {len(rows)} 条记录") + + os.makedirs(image_dir, exist_ok=True) + os.makedirs(label_dir, exist_ok=True) + + for id, orgcode, model, state, objectname, label in rows: + if label is None or label.strip() == "": + # 跳过无标签的 + print(f"跳过无标签记录 id={id}") + continue + + full_url = f"{objectname}" + filename = os.path.basename(objectname) + name_no_ext = os.path.splitext(filename)[0] + + image_path = os.path.join(image_dir, filename) + label_path = os.path.join(label_dir, name_no_ext + ".txt") + + success = download_file_from_url(full_url, image_path) + if success: + save_label(label, label_path) + print("下载完成/n 下载完成") + + +if __name__ == '__main__': + yaml_name='config' + where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'" + image_dir='images' + label_dir='labels' + table_name = 'aidataset' + download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name) diff --git a/Ai_tottle/aboutdataset/miniohelp.py b/Ai_tottle/aboutdataset/miniohelp.py new file mode 100644 index 0000000..71d9ba4 --- /dev/null +++ b/Ai_tottle/aboutdataset/miniohelp.py @@ -0,0 +1,168 @@ +from minio import Minio +from minio.error import S3Error +from urllib.parse import urlparse +import os +import requests +import yaml + +# 读取配置并初始化 MinIO 客户端 +def load_config(yaml_name): + yaml_path = f"{yaml_name}.yaml" + # 读取 YAML 配置文件 + with open(yaml_path, 'r', encoding='utf-8') as file: + config = yaml.safe_load(file) + + # 获取 MinIO 配置 + minio_config = config["minio"] + + # 初始化 MinIO 客户端 + minio_client = Minio( + endpoint=minio_config["endpoint"], + access_key=minio_config["access_key"], + secret_key=minio_config["secret_key"], + secure=minio_config.get("secure", False) + ) + + return minio_client + +def read_sql_config(yaml_name): + yaml_path = f"{yaml_name}.yaml" + with open(yaml_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + sql_config = config.get('sql') + if not sql_config: + raise ValueError("未找到 'sql' 配置块") + + # 优先使用 database,没有则用 dbname + if 'database' not in sql_config and 'dbname' not in sql_config: + raise KeyError("配置文件中缺少 'database' 或 'dbname' 字段") + + return sql_config + + +def parse_minio_url(minio_url): + """解析 MinIO 完整 URL,返回 bucket_name 和 object_name""" + path = urlparse(minio_url).path.lstrip("/") + parts = path.split("/", 1) + if len(parts) != 2: + raise ValueError("URL 格式错误,无法提取 bucket 和 object") + return parts[0], parts[1] + +def create_bucket(client): + """访问 MinIO 服务器,打印存储桶""" + try: + buckets = client.list_buckets() + for bucket in buckets: + print(f"Bucket: {bucket.name}, Created: {bucket.creation_date}") + except S3Error as e: + print(f"Error: {e}") + +def downFile(client, object_name, bucket_name, local_path=None): + """下载文件,可指定 bucket 和本地保存路径""" + current_directory = os.path.dirname(os.path.abspath(__file__)) + download_path = local_path or os.path.join(current_directory, os.path.basename(object_name)) + os.makedirs(os.path.dirname(download_path), exist_ok=True) + + try: + client.fget_object(bucket_name, object_name, download_path) + print(f"✅ 文件已成功下载到 {download_path}") + return download_path + except S3Error as e: + print(f"❌ 下载文件时出错: {e}") + return None + +def upload_file(client, file_path, bucket_name, bucket_directory): + """上传文件到指定 MinIO 存储桶和目录""" + try: + if not client.bucket_exists(bucket_name): + print(f"❌ 存储桶 {bucket_name} 不存在") + return None + + file_name = os.path.basename(file_path) + object_name = f"{bucket_directory.rstrip('/')}/{file_name}" if bucket_directory else file_name + + client.fput_object(bucket_name, object_name, file_path) + print(f"✅ 文件已上传至 {bucket_name}/{object_name}") + return f"{object_name}" + except S3Error as e: + print(f"❌ 上传文件时出错: {e}") + return None + +def upload_file_t(client, file_path, time_str, bucket_name, bucket_directory): + """上传文件到指定 MinIO 存储桶和目录(带时间路径)""" + try: + if not client.bucket_exists(bucket_name): + print(f"❌ 存储桶 {bucket_name} 不存在") + return None + + file_name = os.path.basename(file_path) + object_name = f"{bucket_directory.rstrip('/')}/{time_str}/{file_name}" if bucket_directory else file_name + + client.fput_object(bucket_name, object_name, file_path) + print(f"✅ 文件已上传至 {bucket_name}/{object_name}") + return f"{bucket_name}/{object_name}" + except S3Error as e: + print(f"❌ 上传文件时出错: {e}") + return None + +def upload_folder(client, folder_path, bucket_name, bucket_directory): + """上传整个文件夹到指定 MinIO 存储桶目录""" + try: + if not client.bucket_exists(bucket_name): + print(f"❌ 存储桶 {bucket_name} 不存在") + return + + for root, _, files in os.walk(folder_path): + for file in files: + file_path = os.path.join(root, file) + object_name = os.path.relpath(file_path, folder_path).replace("\\", "/") + + if bucket_directory: + object_name = f"{bucket_directory.rstrip('/')}/{object_name}" + + client.fput_object(bucket_name, object_name, file_path) + print(f"✅ 文件 {file_path} 已上传至 {bucket_name}/{object_name}") + except S3Error as e: + print(f"❌ 上传文件夹时出错: {e}") + +def download_file_url(url: str, save_path: str): + """ + 从指定 URL 下载文件并保存到本地路径。 + + :param url: 文件的 URL。 + :param save_path: 本地保存路径(包括文件名)。 + """ + try: + response = requests.get(url, stream=True) + response.raise_for_status() # 如果请求失败会抛出异常 + + with open(save_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + print(f"文件已成功下载到: {save_path}") + except requests.exceptions.RequestException as e: + print(f"下载失败: {e}") + + +if __name__ == '__main__': + # 初始化 MinIO 客户端 + client = Minio( + endpoint="222.212.85.86", + access_key="adminjdskfj", + secret_key="123456ksldjfal@Y", + secure=False + ) + + # 示例:通过完整 URL 下载文件 + full_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/media/86961d3a-8790-4bc0-bfc6-dc45a8e4b9bd/DJI_202504291707_001_86961d3a-8790-4bc0-bfc6-dc45a8e4b9bd/DJI_20250429170933_0001_V.jpeg" + bucket_name, object_name = parse_minio_url(full_url) + downFile(client, object_name, bucket_name) + + # 示例:上传文件 + file_path = r"D:\work\develop\sample.jpg" # 修改为实际文件 + upload_file(client, file_path, "300bdf2b-a150-406e-be63-d28bd29b409f", "media/output_video") + + # 示例:上传文件夹 + folder_path = r"D:\work\develop" + upload_folder(client, folder_path, "300bdf2b-a150-406e-be63-d28bd29b409f", "media/output_video") diff --git a/Ai_tottle/aboutdataset/pgadmin_helper.py b/Ai_tottle/aboutdataset/pgadmin_helper.py new file mode 100644 index 0000000..d643206 --- /dev/null +++ b/Ai_tottle/aboutdataset/pgadmin_helper.py @@ -0,0 +1,75 @@ +import psycopg2 +from datetime import datetime +from miniohelp import * + +def insert_data(table_name, data, table = 'danger'): + """ + 向 PostgreSQL 数据库的指定表中插入数据 + :param table_name: 表名 + :param data: 字典格式的数据(列名: 值) + """ + # 验证表名是否合法 + if table_name not in [table]: # 在这里加入你的表名 + print(f"错误: 无效的表名 {table_name}") + return + + try: + sql_yaml = data.pop('sql_yaml') + # 数据库配置 + sql_config = read_sql_config(sql_yaml) + db_host = sql_config['host'] + db_port = sql_config['port'] + db_database = sql_config['database'] + db_user = sql_config['user'] + db_password = sql_config['password'] + # 连接数据库 + with psycopg2.connect( + host=db_host, + port=db_port, + database=db_database, + user=db_user, + password=db_password + ) as conn: + with conn.cursor() as cursor: + # 生成 SQL 语句 + columns = ', '.join(data.keys()) + values = ', '.join(['%s'] * len(data)) + sql = f"INSERT INTO {table_name} ({columns}) VALUES ({values})" + + # 执行 SQL 语句 + cursor.execute(sql, tuple(data.values())) + + # 提交事务 + conn.commit() + print("数据插入成功") + + except psycopg2.Error as e: + print("数据库错误:", e.pgcode, e.pgerror) # 打印详细的错误码和错误信息 + +# 示例数据 +data = { + 'orgcode': 'bdzl', + 'sn': '123456', + 'snname': '测试机', + 'spaceid': '97123d28-40a3-4dd1-9021-616a7d60dce8', + 'spacename': '目标检测', + 'guid': '97123d28-40a3-4dd1-9021-616a7d60dce8', + 'level': '一般', + 'title': 'pedestrian', + 'dangertype': '目标', + 'content': '目标检测', + 'gps1': "", + 'dealresult': "", + 'imgobjectname': 'AIResults/20250320/97123d28-40a3-4dd1-9021-616a7d60dce8.jpg', + 'state': 0, + 'createtm': '2025-03-20 16:12:04', + 'dealtm': '2025-03-20 16:12:04', + 'createuser': "", + 'dealuser': "", + 'checkuser': "", + 'remarks': "", + 'sql_yaml': "", # 数据库配置文件路径 +} + +# 调用插入函数 +insert_data('danger', data) diff --git a/Ai_tottle/aboutdataset/upload.py b/Ai_tottle/aboutdataset/upload.py new file mode 100644 index 0000000..10790d2 --- /dev/null +++ b/Ai_tottle/aboutdataset/upload.py @@ -0,0 +1,138 @@ +import os +import psycopg2 +from miniohelp import load_config, upload_file + +# ------------------ 加载数据库配置 ------------------ +def load_sql_config(yaml_name): + import yaml + # 从 yaml 文件中读取 sql 配置 + with open(f"{yaml_name}.yaml", "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + return config["sql"] + +# ------------------ 插入单条数据到数据库 ------------------ +def insert_to_database(conn, table, data): + # 定义插入 SQL + sql = f""" + INSERT INTO {table} ( + id, orgcode, model, state, objectname, label + ) VALUES (%s, %s, %s, %s, %s, %s) + """ + try: + # 使用游标执行 SQL + with conn.cursor() as cur: + cur.execute(sql, data) + # 提交事务 + conn.commit() + print("✅ 数据插入成功") + return True + except psycopg2.Error as e: + # 出现错误时回滚事务 + conn.rollback() + print("❌ 插入失败:", e) + return False + +# ------------------ 主流程:上传图片 + 插入数据库 ------------------ +def upload_and_insert_images_with_labels( + yaml_name, # 配置文件名(不带 .yaml 后缀) + image_dir, # 图片目录 + label_dir, # 标签目录 + bucket_name, # MinIO 存储桶名 + bucket_path, # MinIO 上传路径(目录) + table_name, # 数据库表名 + model_name, # 模型名,用于区分不同模型的数据 + orgcode="bdzl" # 机构代码,默认 bdzl +): + # 1. 初始化 MinIO 和 SQL 配置 + minio_client = load_config(yaml_name) # 加载 MinIO 配置 + sql_conf = load_sql_config(yaml_name) # 加载 SQL 配置 + + # 2. 创建数据库连接 + conn = psycopg2.connect( + host=sql_conf["host"], + port=sql_conf["port"], + user=sql_conf["user"], + password=sql_conf["password"], + database=sql_conf["dbname"] + ) + + # 3. 查询数据库中当前最大 id 和已存在的文件 + with conn.cursor() as cur: + # 查询表里当前最大 id + cur.execute(f"SELECT COALESCE(MAX(id), 0) FROM {table_name}") + max_id = cur.fetchone()[0] or 0 + next_id = max_id + 1 + + # 查询该模型已有的文件名(仅文件名部分) + cur.execute(f"SELECT objectname FROM {table_name} WHERE model = %s", (model_name,)) + existing_files = set([os.path.basename(row[0]) for row in cur.fetchall()]) + + # 4. 遍历图片目录 + for filename in os.listdir(image_dir): + # 只处理图片类型文件 + if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")): + continue + + # 如果数据库已存在该文件,跳过 + if filename in existing_files: + print(f"⏭️ 已存在于数据库中,跳过:{filename}") + continue + + image_path = os.path.join(image_dir, filename) + name_no_ext = os.path.splitext(filename)[0] + label_path = os.path.join(label_dir, name_no_ext + ".txt") + + # 标签文件不存在,跳过 + if not os.path.exists(label_path): + print(f"⚠️ 未找到标签:{label_path},跳过") + continue + + # 读取标签文件内容 + with open(label_path, "r", encoding="utf-8") as f: + label_content = f.read().strip() + + # 5. 上传图片到 MinIO + objectname = upload_file(minio_client, image_path, bucket_name, bucket_path) + if not objectname: + print(f"❌ 上传失败:{filename}") + continue + + # 6. 插入数据库 + try: + data = ( + next_id, # id + orgcode, # 机构代码 + model_name, # 模型 + 1, # state 固定 1 + objectname, # MinIO 返回的文件路径 + label_content # 标签内容 + ) + info = insert_to_database(conn, table_name, data) + print(info) + print(f"✅ 成功插入数据库 (id={next_id}):{filename}") + next_id += 1 + except Exception as e: + print(f"❌ 数据库插入失败:{filename}, error: {e}") + + # 7. 提交事务并关闭连接 + conn.commit() + conn.close() + print("🎉 所有图片和标签处理完成!") + +# ------------------ 示例调用 ------------------ +if __name__ == "__main__": + yaml_name = "config_test_dev" # 配置文件名(不含 .yaml 后缀) + image_dir = r"D:\dataset\images\train" # 图片路径 + label_dir = r"D:\dataset\labels\train" # 标签路径 + bucket_name = "300bdf2b-a150-406e-be63-d28bd29b409f" # MinIO 桶名 + bucket_directory = "new_datasets/fence" # MinIO 内的目录 + table_name = "public.aidataset" # 数据库表名 + model_name = "08ff91fd-60d2-470f-9675-b18800229654" # 模型uuid + orgcode = "bdzl" # 机构代码 + + # 调用主函数 + upload_and_insert_images_with_labels( + yaml_name, image_dir, label_dir, + bucket_name, bucket_directory, + table_name, model_name, orgcode + ) diff --git a/Ai_tottle/aboutdataset/删除无对应图片的txt.py b/Ai_tottle/aboutdataset/删除无对应图片的txt.py new file mode 100644 index 0000000..5dd0690 --- /dev/null +++ b/Ai_tottle/aboutdataset/删除无对应图片的txt.py @@ -0,0 +1,18 @@ +import os + +# 设置图片和标签文件夹路径 +image_dir = r"g:\smoke\smoke_old\images" # 图片文件夹路径 +label_dir = r"g:\smoke\smoke_old\labels" # 标签文件夹路径 + +# 获取图片文件名(不包括扩展名) +image_files = {os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))} + +# 遍历标签文件夹 +for label_file in os.listdir(label_dir): + label_name, ext = os.path.splitext(label_file) + + if ext.lower() == '.txt' and label_name not in image_files: + # 如果标签文件没有对应的图片,删除该标签文件 + label_path = os.path.join(label_dir, label_file) + os.remove(label_path) + print(f"Deleted: {label_file}") diff --git a/Ai_tottle/ai_tottle_api.py b/Ai_tottle/ai_tottle_api.py index eaf1d88..583f158 100644 --- a/Ai_tottle/ai_tottle_api.py +++ b/Ai_tottle/ai_tottle_api.py @@ -316,3 +316,4 @@ async def yolo_train_progress(request, project_name): if __name__ == '__main__': app.run(host="0.0.0.0", port=12366, debug=True,workers=1) + \ No newline at end of file diff --git a/Ai_tottle/yolo_train.py b/Ai_tottle/yolo_train.py index 1e93c61..e654740 100644 --- a/Ai_tottle/yolo_train.py +++ b/Ai_tottle/yolo_train.py @@ -61,6 +61,7 @@ import pandas as pd import logging from tqdm import tqdm import miniohelp as miniohelp +from aboutdataset.download_oss import download_and_save_images_from_oss ######################################## Logging ######################################## def setup_logger(project: str): @@ -82,8 +83,12 @@ def setup_logger(project: str): return logger -def get_last_model_from_log(project: str, default_model: str): - """从日志中解析上一次训练的 last.pt 路径""" +def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"): + """ + 从日志解析上一次训练的 last.pt 路径 + 如果找不到则返回 default_model + 支持 default_model 为 .pt 或 .yaml + """ log_file = os.path.join("logs", f"{project}.log") if not os.path.exists(log_file): return default_model @@ -96,8 +101,23 @@ def get_last_model_from_log(project: str, default_model: str): path = line.strip().split("Saved last model path:")[-1].strip() if os.path.exists(path): return path + return default_model +def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, table_name): + """ + yaml_name='config' + where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'" + image_dir='images' + label_dir='labels' + table_name = 'aidataset' + + Returns: + (image_dir, label_dir) + """ + download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name) + return image_dir, label_dir + ####################################### 工具函数 ####################################### def count_labels_by_class(label_dir): class_counter = Counter() @@ -338,19 +358,37 @@ def auto_train(db_host, db_database, db_user, db_password, db_port, model_id, logger.info("训练流程执行完成") -####################################### 主入口 ####################################### -if __name__ == '__main__': +def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id, image_dir, label_dir, yaml_name, where_clause, table_name): + + imag_path, label_path = get_img_and_label_paths(yaml_name, where_clause, image_dir, label_dir, table_name) auto_train( - db_host='222.212.85.86', - db_database='your_database_name', - db_user='postgres', - db_password='postgres', - db_port='5432', - model_id='best.pt', - img_path='./dataset/images', - label_path='./dataset/labels', + db_host=db_host, + db_database=db_database, + db_user=db_user, + db_password=db_password, + db_port=db_port, + model_id=model_id, + imag_path=imag_path, # 修正了这里,确保 imag_path 作为关键字参数传递 + label_path=label_path, # 修正了这里,确保 label_path 作为关键字参数传递 new_path='./datasets', split_list=[0.7, 0.2, 0.1], class_names={'0': 'human', '1': 'car'}, project_name='my_project' ) + +####################################### 主入口 ####################################### +if __name__ == '__main__': + down_and_train( + db_host='222.212.85.86', + db_database='your_database_name', + db_user='postgres', + db_password='postgres', + db_port='5432', + model_id='best.pt', + img_path='./dataset/images', #before broken img path + label_path='./dataset/labels',#before broken labels path + new_path='./datasets', #after broken path + split_list=[0.7, 0.2, 0.1], + class_names={'0': 'human', '1': 'car'}, + project_name='my_project' +) \ No newline at end of file