00000
This commit is contained in:
parent
23b990c811
commit
49f00429b9
1
Ai_tottle/aboutdataset/.codemap/main-panel.json
Normal file
1
Ai_tottle/aboutdataset/.codemap/main-panel.json
Normal file
@ -0,0 +1 @@
|
||||
[]
|
||||
5
Ai_tottle/aboutdataset/.vscode/settings.json
vendored
Normal file
5
Ai_tottle/aboutdataset/.vscode/settings.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
BIN
Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc
Normal file
BIN
Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
13
Ai_tottle/aboutdataset/config-test.yaml
Normal file
13
Ai_tottle/aboutdataset/config-test.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
minio:
|
||||
endpoint: "222.212.85.86:9000"
|
||||
access_key: "adminjdskfj"
|
||||
secret_key: "123456ksldjfal@Y"
|
||||
secure: false
|
||||
web: "http://222.212.85.86"
|
||||
|
||||
sql:
|
||||
host: '8.137.54.85'
|
||||
port: 5060
|
||||
dbname: 'postgres'
|
||||
user: 'postgres'
|
||||
password: 'root'
|
||||
13
Ai_tottle/aboutdataset/config.yaml
Normal file
13
Ai_tottle/aboutdataset/config.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
minio:
|
||||
endpoint: "222.212.85.86:9000"
|
||||
access_key: "adminjdskfj"
|
||||
secret_key: "123456ksldjfal@Y"
|
||||
secure: false
|
||||
web: "http://222.212.85.86"
|
||||
|
||||
sql:
|
||||
host: '222.212.85.86'
|
||||
port: 5432
|
||||
dbname: 'smart_dev'
|
||||
user: 'postgres'
|
||||
password: 'root'
|
||||
13
Ai_tottle/aboutdataset/config_test_dev.yaml
Normal file
13
Ai_tottle/aboutdataset/config_test_dev.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
minio:
|
||||
endpoint: "222.212.85.86:9000"
|
||||
access_key: "adminjdskfj"
|
||||
secret_key: "123456ksldjfal@Y"
|
||||
secure: false
|
||||
web: "http://222.212.85.86"
|
||||
|
||||
sql:
|
||||
host: '8.137.54.85'
|
||||
port: 5060
|
||||
dbname: 'smart_dev_123'
|
||||
user: 'postgres'
|
||||
password: 'root'
|
||||
103
Ai_tottle/aboutdataset/download_oss.py
Normal file
103
Ai_tottle/aboutdataset/download_oss.py
Normal file
@ -0,0 +1,103 @@
|
||||
import psycopg2
|
||||
import os
|
||||
import requests
|
||||
from miniohelp import read_sql_config
|
||||
import shutil
|
||||
|
||||
def fetch_object_and_labels(sql_config, where_clause=None, table_name='aidataset', params=None):
    """Return every row of ``public.<table_name>``, optionally filtered.

    :param sql_config: mapping with ``host``, ``port``, ``user``, ``password``
        and either ``database`` or ``dbname``.
    :param where_clause: raw SQL condition appended after ``WHERE``. It is
        inserted verbatim, so it must come from trusted code; put variable
        values in ``params`` as ``%s`` placeholders instead of formatting
        them into this string.
    :param table_name: plain identifier of a table in the ``public`` schema.
    :param params: optional sequence/mapping bound to the placeholders in
        ``where_clause``. When given, literal ``%`` characters in
        ``where_clause`` must be written as ``%%``.
    :return: list of row tuples from ``cursor.fetchall()``.
    :raises ValueError: if ``table_name`` is not a plain identifier.
    """
    import re  # local import so the block does not depend on file-level changes

    # The table name cannot be bound as a query parameter, so refuse anything
    # that is not a simple identifier before it reaches the SQL text.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", str(table_name)):
        raise ValueError(f"invalid table name: {table_name!r}")

    query = f"SELECT * FROM public.{table_name}"
    if where_clause:
        query += f" WHERE {where_clause}"

    conn = psycopg2.connect(
        host=sql_config['host'],
        port=sql_config['port'],
        user=sql_config['user'],
        password=sql_config['password'],
        dbname=sql_config.get('database') or sql_config.get('dbname')
    )
    try:
        # The cursor context manager closes the cursor even if execute fails.
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchall()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def download_file_from_url(url, save_path, timeout=30):
    """Download ``url`` to ``save_path`` in 8 KiB chunks.

    :param url: address to fetch with an HTTP GET.
    :param save_path: destination file; missing parent directories are created.
    :param timeout: seconds passed to ``requests.get`` so a stalled server
        cannot block the caller forever.
    :return: ``True`` on success, ``False`` if any step failed (the error is
        printed, not raised).
    """
    try:
        # Using the response as a context manager releases the connection
        # even when the body is only partially consumed.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # dirname is '' for a bare filename; makedirs('') would raise.
            parent = os.path.dirname(save_path)
            if parent:
                os.makedirs(parent, exist_ok=True)

            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(8192):
                    if chunk:
                        f.write(chunk)
        print(f"✅ 下载成功:{save_path}")
        return True
    except Exception as e:
        # Best-effort helper: report the failure and let the caller continue.
        print(f"❌ 下载失败: {url} 错误: {e}")
        return False
|
||||
|
||||
def save_label(label_text, file_path):
    """Write the stripped label text to ``file_path`` as UTF-8.

    :param label_text: label content; leading/trailing whitespace is removed.
    :param file_path: destination ``.txt`` file; missing parent directories
        are created.
    :return: ``True`` when the file was written, ``False`` otherwise (the
        error is printed, not raised).
    """
    try:
        # Strip first so a non-string label fails before the file is truncated.
        content = label_text.strip()

        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f"✅ 标签已保存:{file_path}")
        return True
    except Exception as e:
        print(f"❌ 保存标签失败:{file_path} 错误: {e}")
        return False
|
||||
|
||||
def move_txt_files(src_dir='image', dst_dir='label'):
    """Move every regular ``*.txt`` file from ``src_dir`` into ``dst_dir``.

    :param src_dir: directory to scan (not recursive).
    :param dst_dir: destination directory; created if missing.
    :return: list of destination paths, in sorted filename order.
    """
    os.makedirs(dst_dir, exist_ok=True)

    moved = []
    # Sorting makes the order (and the printed log) deterministic.
    for filename in sorted(os.listdir(src_dir)):
        src_path = os.path.join(src_dir, filename)
        # Skip non-label files and directories that merely end in ".txt".
        if not filename.endswith('.txt') or not os.path.isfile(src_path):
            continue
        dst_path = os.path.join(dst_dir, filename)
        shutil.move(src_path, dst_path)
        print(f"已移动: {src_path} -> {dst_path}")
        moved.append(dst_path)
    return moved
|
||||
|
||||
|
||||
def download_and_save_images_from_oss(yaml_name, where_clause, image_dir, label_dir, table_name):
    """Download the images of matching dataset rows and write their labels.

    Rows are read with ``fetch_object_and_labels``; each row is unpacked as
    ``(id, orgcode, model, state, objectname, label)``, so the table is
    expected to expose exactly these six columns in this order.

    :param yaml_name: config file name without the ``.yaml`` suffix.
    :param where_clause: SQL condition used to select rows.
    :param image_dir: directory that receives the downloaded images.
    :param label_dir: directory that receives one ``.txt`` label per image.
    :param table_name: table to query.
    """
    sql_config = read_sql_config(yaml_name)

    rows = fetch_object_and_labels(sql_config, where_clause, table_name)
    print(f"共查询到 {len(rows)} 条记录")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    saved = skipped = failed = 0
    # `record_id` rather than `id`, to avoid shadowing the builtin.
    for record_id, _orgcode, _model, _state, objectname, label in rows:
        if label is None or label.strip() == "":
            # Rows without a label are useless for training; skip them.
            print(f"跳过无标签记录 id={record_id}")
            skipped += 1
            continue

        # `objectname` is used directly as the download URL.
        filename = os.path.basename(objectname)
        name_no_ext = os.path.splitext(filename)[0]

        image_path = os.path.join(image_dir, filename)
        label_path = os.path.join(label_dir, name_no_ext + ".txt")

        # Only write the label when its image actually arrived.
        if download_file_from_url(objectname, image_path):
            save_label(label, label_path)
            saved += 1
        else:
            failed += 1

    # The original message contained a literal "/n" instead of a newline.
    print(f"下载完成:成功 {saved} 条,跳过 {skipped} 条,失败 {failed} 条")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Example run: fetch one model's labelled images into ./images and ./labels.
    download_and_save_images_from_oss(
        'config',
        "model = '0845315a-0b3c-439d-9e42-264a9411207f'",
        'images',
        'labels',
        'aidataset',
    )
|
||||
168
Ai_tottle/aboutdataset/miniohelp.py
Normal file
168
Ai_tottle/aboutdataset/miniohelp.py
Normal file
@ -0,0 +1,168 @@
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
# 读取配置并初始化 MinIO 客户端
|
||||
def load_config(yaml_name):
    """Build a MinIO client from the ``minio`` section of ``<yaml_name>.yaml``.

    :param yaml_name: config file name without the ``.yaml`` suffix.
    :return: a ``Minio`` client; ``secure`` defaults to ``False`` when the
        key is absent.
    """
    with open(f"{yaml_name}.yaml", 'r', encoding='utf-8') as handle:
        settings = yaml.safe_load(handle)

    minio_settings = settings["minio"]
    return Minio(
        endpoint=minio_settings["endpoint"],
        access_key=minio_settings["access_key"],
        secret_key=minio_settings["secret_key"],
        secure=minio_settings.get("secure", False),
    )
|
||||
|
||||
def read_sql_config(yaml_name):
    """Return the ``sql`` section of ``<yaml_name>.yaml``.

    :param yaml_name: config file name without the ``.yaml`` suffix.
    :return: the ``sql`` mapping exactly as written in the file. The database
        name may be stored under ``database`` or ``dbname``; callers must
        accept either key.
    :raises ValueError: if the file is empty or has no ``sql`` section.
    :raises KeyError: if the section has neither ``database`` nor ``dbname``.
    """
    yaml_path = f"{yaml_name}.yaml"
    with open(yaml_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file; treat that like a
        # missing section instead of failing with AttributeError below.
        config = yaml.safe_load(f) or {}

    sql_config = config.get('sql')
    if not sql_config:
        raise ValueError("未找到 'sql' 配置块")

    # Either key is accepted for the database name.
    if 'database' not in sql_config and 'dbname' not in sql_config:
        raise KeyError("配置文件中缺少 'database' 或 'dbname' 字段")

    return sql_config
|
||||
|
||||
|
||||
def parse_minio_url(minio_url):
    """Split a full MinIO object URL into ``(bucket_name, object_name)``.

    The first path segment is the bucket and the remainder is the object key.
    Both parts are percent-decoded, because a URL carries spaces and
    non-ASCII characters in escaped form while the object key is plain text.

    :param minio_url: e.g. ``http://host:9000/bucket/dir/file.jpg``.
    :return: ``(bucket_name, object_name)``.
    :raises ValueError: if the URL has no bucket or no object part.
    """
    from urllib.parse import unquote, urlparse  # local: unquote is new to this block

    path = urlparse(minio_url).path.lstrip("/")
    bucket_name, _, object_name = path.partition("/")
    # An empty object key (e.g. ".../bucket/") cannot be downloaded.
    if not bucket_name or not object_name:
        raise ValueError("URL 格式错误,无法提取 bucket 和 object")
    return unquote(bucket_name), unquote(object_name)
|
||||
|
||||
def create_bucket(client):
    """Print the name and creation date of every bucket visible to ``client``.

    Despite its name, this function only lists buckets; it creates nothing.
    An ``S3Error`` is printed instead of being raised.
    """
    try:
        listing = client.list_buckets()
        for entry in listing:
            print(f"Bucket: {entry.name}, Created: {entry.creation_date}")
    except S3Error as err:
        print(f"Error: {err}")
|
||||
|
||||
def downFile(client, object_name, bucket_name, local_path=None):
    """Download one object from MinIO to a local file.

    :param client: MinIO client exposing ``fget_object``.
    :param object_name: key of the object inside the bucket.
    :param bucket_name: bucket that holds the object.
    :param local_path: destination file. When omitted, the object's base name
        is saved next to this module.
    :return: the local path on success, ``None`` if MinIO reported an error.
    """
    if local_path:
        download_path = local_path
    else:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        download_path = os.path.join(module_dir, os.path.basename(object_name))

    # A bare filename has an empty dirname, and os.makedirs('') raises.
    parent = os.path.dirname(download_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    try:
        client.fget_object(bucket_name, object_name, download_path)
        print(f"✅ 文件已成功下载到 {download_path}")
        return download_path
    except S3Error as e:
        print(f"❌ 下载文件时出错: {e}")
        return None
|
||||
|
||||
def upload_file(client, file_path, bucket_name, bucket_directory):
    """Upload one local file into a directory of an existing MinIO bucket.

    :param client: MinIO client exposing ``bucket_exists`` and ``fput_object``.
    :param file_path: local file to upload.
    :param bucket_name: target bucket; it must already exist.
    :param bucket_directory: key prefix inside the bucket; an empty value
        uploads to the bucket root.
    :return: the object key (without the bucket name), or ``None`` when the
        bucket is missing or MinIO reported an error.
    """
    try:
        if not client.bucket_exists(bucket_name):
            print(f"❌ 存储桶 {bucket_name} 不存在")
            return None

        base_name = os.path.basename(file_path)
        if bucket_directory:
            target_key = f"{bucket_directory.rstrip('/')}/{base_name}"
        else:
            target_key = base_name

        client.fput_object(bucket_name, target_key, file_path)
        print(f"✅ 文件已上传至 {bucket_name}/{target_key}")
        return f"{target_key}"
    except S3Error as err:
        print(f"❌ 上传文件时出错: {err}")
        return None
|
||||
|
||||
def upload_file_t(client, file_path, time_str, bucket_name, bucket_directory):
    """Upload one local file under ``<bucket_directory>/<time_str>/``.

    :param client: MinIO client exposing ``bucket_exists`` and ``fput_object``.
    :param file_path: local file to upload.
    :param time_str: path segment (e.g. a date) inserted before the file name.
    :param bucket_name: target bucket; it must already exist.
    :param bucket_directory: key prefix inside the bucket; may be empty.
    :return: ``"<bucket_name>/<object key>"``, or ``None`` when the bucket is
        missing or MinIO reported an error.
    """
    try:
        if not client.bucket_exists(bucket_name):
            print(f"❌ 存储桶 {bucket_name} 不存在")
            return None

        file_name = os.path.basename(file_path)
        # Join only the non-empty segments. Previously an empty
        # bucket_directory silently dropped time_str as well, so time-stamped
        # uploads to the bucket root overwrote each other.
        segments = [
            bucket_directory.rstrip('/') if bucket_directory else "",
            str(time_str) if time_str else "",
            file_name,
        ]
        object_name = "/".join(segment for segment in segments if segment)

        client.fput_object(bucket_name, object_name, file_path)
        print(f"✅ 文件已上传至 {bucket_name}/{object_name}")
        return f"{bucket_name}/{object_name}"
    except S3Error as e:
        print(f"❌ 上传文件时出错: {e}")
        return None
|
||||
|
||||
def upload_folder(client, folder_path, bucket_name, bucket_directory):
    """Upload every file under ``folder_path`` to an existing MinIO bucket.

    The folder is walked recursively. Each object key is the file's path
    relative to ``folder_path`` with forward slashes, prefixed by
    ``bucket_directory`` when one is given. The first ``S3Error`` stops the
    upload and is printed; nothing is returned.
    """
    try:
        if not client.bucket_exists(bucket_name):
            print(f"❌ 存储桶 {bucket_name} 不存在")
            return

        prefix = f"{bucket_directory.rstrip('/')}/" if bucket_directory else ""
        for current_dir, _, names in os.walk(folder_path):
            for name in names:
                local_file = os.path.join(current_dir, name)
                # Object keys always use "/", also on Windows.
                relative_key = os.path.relpath(local_file, folder_path).replace("\\", "/")
                target_key = prefix + relative_key

                client.fput_object(bucket_name, target_key, local_file)
                print(f"✅ 文件 {local_file} 已上传至 {bucket_name}/{target_key}")
    except S3Error as err:
        print(f"❌ 上传文件夹时出错: {err}")
|
||||
|
||||
def download_file_url(url: str, save_path: str, timeout: float = 30) -> None:
    """Download a file from ``url`` and save it to ``save_path``.

    :param url: URL of the file.
    :param save_path: local destination path, including the file name.
    :param timeout: seconds passed to ``requests.get`` so a stalled server
        cannot block the caller forever.

    Request failures are printed, not raised. Errors while writing the local
    file are not caught here.
    """
    try:
        # The context manager releases the streamed connection in all cases.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # raises on a 4xx/5xx status

            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        print(f"文件已成功下载到: {save_path}")
    except requests.exceptions.RequestException as e:
        print(f"下载失败: {e}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Build the client from config.yaml instead of repeating the credentials
    # here. The inline copy also pointed at "222.212.85.86" without the
    # ":9000" port that the config files and the sample URL below use.
    client = load_config("config")

    # Example: download a file from its full URL.
    full_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/media/86961d3a-8790-4bc0-bfc6-dc45a8e4b9bd/DJI_202504291707_001_86961d3a-8790-4bc0-bfc6-dc45a8e4b9bd/DJI_20250429170933_0001_V.jpeg"
    bucket_name, object_name = parse_minio_url(full_url)
    downFile(client, object_name, bucket_name)

    # Example: upload a single file (change the path to a real file).
    file_path = r"D:\work\develop\sample.jpg"
    upload_file(client, file_path, "300bdf2b-a150-406e-be63-d28bd29b409f", "media/output_video")

    # Example: upload a whole folder.
    folder_path = r"D:\work\develop"
    upload_folder(client, folder_path, "300bdf2b-a150-406e-be63-d28bd29b409f", "media/output_video")
|
||||
75
Ai_tottle/aboutdataset/pgadmin_helper.py
Normal file
75
Ai_tottle/aboutdataset/pgadmin_helper.py
Normal file
@ -0,0 +1,75 @@
|
||||
import psycopg2
|
||||
from datetime import datetime
|
||||
from miniohelp import *
|
||||
|
||||
def insert_data(table_name, data, table = 'danger'):
    """Insert one row into a PostgreSQL table.

    :param table_name: table to insert into; must equal ``table``.
    :param data: mapping of column name to value. It must also contain
        ``sql_yaml``, the config name passed to ``read_sql_config``; that
        entry is not inserted. The mapping itself is left unmodified.
    :param table: the only table name accepted for ``table_name``.
    :raises KeyError: if ``data`` has no ``sql_yaml`` entry.

    Invalid table or column names and database errors are printed, not raised.
    """
    import re  # local import so the block does not depend on file-level changes

    # Only the whitelisted table may be written to.
    if table_name != table:
        print(f"错误: 无效的表名 {table_name}")
        return

    # Work on a copy so the caller's dict keeps its 'sql_yaml' entry.
    row = dict(data)
    sql_yaml = row.pop('sql_yaml')

    # Column names are placed into the SQL text, so accept plain identifiers only.
    invalid = [name for name in row if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", str(name))]
    if invalid or not row:
        print(f"错误: 无效的列名 {invalid}")
        return

    columns = ', '.join(row.keys())
    values = ', '.join(['%s'] * len(row))
    sql = f"INSERT INTO {table_name} ({columns}) VALUES ({values})"

    try:
        sql_config = read_sql_config(sql_yaml)
        conn = psycopg2.connect(
            host=sql_config['host'],
            port=sql_config['port'],
            # The config files use 'dbname'; accept either key.
            database=sql_config.get('database') or sql_config.get('dbname'),
            user=sql_config['user'],
            password=sql_config['password']
        )
        try:
            # `with conn` commits on success and rolls back on error, but it
            # does not close the connection; that happens in `finally`.
            with conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, tuple(row.values()))
            print("数据插入成功")
        finally:
            conn.close()
    except psycopg2.Error as e:
        # Print the detailed error code and message.
        print("数据库错误:", e.pgcode, e.pgerror)
|
||||
|
||||
# Sample row for the `danger` table.
data = {
    'orgcode': 'bdzl',
    'sn': '123456',
    'snname': '测试机',
    'spaceid': '97123d28-40a3-4dd1-9021-616a7d60dce8',
    'spacename': '目标检测',
    'guid': '97123d28-40a3-4dd1-9021-616a7d60dce8',
    'level': '一般',
    'title': 'pedestrian',
    'dangertype': '目标',
    'content': '目标检测',
    'gps1': "",
    'dealresult': "",
    'imgobjectname': 'AIResults/20250320/97123d28-40a3-4dd1-9021-616a7d60dce8.jpg',
    'state': 0,
    'createtm': '2025-03-20 16:12:04',
    'dealtm': '2025-03-20 16:12:04',
    'createuser': "",
    'dealuser': "",
    'checkuser': "",
    'remarks': "",
    # Config name without ".yaml". The previous empty string made
    # read_sql_config try to open a file literally named ".yaml".
    'sql_yaml': "config",
}

# Run the sample insert only when executed as a script; importing this module
# must not write to the database.
if __name__ == '__main__':
    insert_data('danger', data)
|
||||
138
Ai_tottle/aboutdataset/upload.py
Normal file
138
Ai_tottle/aboutdataset/upload.py
Normal file
@ -0,0 +1,138 @@
|
||||
import os
|
||||
import psycopg2
|
||||
from miniohelp import load_config, upload_file
|
||||
|
||||
# ------------------ 加载数据库配置 ------------------
|
||||
def load_sql_config(yaml_name):
    """Return the ``sql`` section of ``<yaml_name>.yaml``.

    :param yaml_name: config file name without the ``.yaml`` suffix.
    :raises KeyError: if the file has no ``sql`` section.
    """
    import yaml

    config_file = f"{yaml_name}.yaml"
    with open(config_file, "r", encoding="utf-8") as handle:
        settings = yaml.safe_load(handle)
    return settings["sql"]
|
||||
|
||||
# ------------------ 插入单条数据到数据库 ------------------
|
||||
def insert_to_database(conn, table, data):
    """Insert one dataset row and commit it.

    :param conn: open psycopg2 connection.
    :param table: target table name, placed into the SQL text as-is.
    :param data: six values in the order
        ``(id, orgcode, model, state, objectname, label)``.
    :return: ``True`` after a successful commit; ``False`` after a
        ``psycopg2.Error``, in which case the transaction is rolled back.
    """
    sql = f"""
        INSERT INTO {table} (
            id, orgcode, model, state, objectname, label
        ) VALUES (%s, %s, %s, %s, %s, %s)
    """
    try:
        with conn.cursor() as cur:
            cur.execute(sql, data)
        conn.commit()
    except psycopg2.Error as err:
        # Undo the failed statement so the connection stays usable.
        conn.rollback()
        print("❌ 插入失败:", err)
        return False

    print("✅ 数据插入成功")
    return True
|
||||
|
||||
# ------------------ 主流程:上传图片 + 插入数据库 ------------------
|
||||
def upload_and_insert_images_with_labels(
    yaml_name,
    image_dir,
    label_dir,
    bucket_name,
    bucket_path,
    table_name,
    model_name,
    orgcode="bdzl"
):
    """Upload labelled images to MinIO and record each one in the database.

    For every image in ``image_dir`` that has a same-named ``.txt`` file in
    ``label_dir`` and is not yet recorded for ``model_name``, the image is
    uploaded and a row ``(id, orgcode, model, state=1, objectname, label)``
    is inserted.

    :param yaml_name: config file name without the ``.yaml`` suffix.
    :param image_dir: directory containing the images.
    :param label_dir: directory containing the label ``.txt`` files.
    :param bucket_name: MinIO bucket name.
    :param bucket_path: directory (key prefix) inside the bucket.
    :param table_name: database table, placed into the SQL text as-is.
    :param model_name: model identifier stored with each row.
    :param orgcode: organisation code stored with each row.
    """
    # 1. Load the MinIO client and the SQL settings.
    minio_client = load_config(yaml_name)
    sql_conf = load_sql_config(yaml_name)

    # 2. Open the database connection; accept either key for the database name.
    conn = psycopg2.connect(
        host=sql_conf["host"],
        port=sql_conf["port"],
        user=sql_conf["user"],
        password=sql_conf["password"],
        database=sql_conf.get("dbname") or sql_conf.get("database")
    )

    # try/finally so the connection is closed even if a later step raises.
    try:
        # 3. Find the next free id and the files already stored for this model.
        with conn.cursor() as cur:
            cur.execute(f"SELECT COALESCE(MAX(id), 0) FROM {table_name}")
            max_id = cur.fetchone()[0] or 0
            next_id = max_id + 1

            cur.execute(f"SELECT objectname FROM {table_name} WHERE model = %s", (model_name,))
            # Compare by base name only; ignore rows without an object name.
            existing_files = {os.path.basename(row[0]) for row in cur.fetchall() if row[0]}

        # 4. Walk the image directory.
        for filename in os.listdir(image_dir):
            if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
                continue

            if filename in existing_files:
                print(f"⏭️ 已存在于数据库中,跳过:{filename}")
                continue

            image_path = os.path.join(image_dir, filename)
            name_no_ext = os.path.splitext(filename)[0]
            label_path = os.path.join(label_dir, name_no_ext + ".txt")

            if not os.path.exists(label_path):
                print(f"⚠️ 未找到标签:{label_path},跳过")
                continue

            with open(label_path, "r", encoding="utf-8") as f:
                label_content = f.read().strip()

            # 5. Upload the image to MinIO.
            objectname = upload_file(minio_client, image_path, bucket_name, bucket_path)
            if not objectname:
                print(f"❌ 上传失败:{filename}")
                continue

            # 6. Insert the row. insert_to_database reports failure through
            # its return value, so success must be checked explicitly;
            # otherwise a failed insert is logged as a success and uses up an id.
            try:
                data = (next_id, orgcode, model_name, 1, objectname, label_content)
                if insert_to_database(conn, table_name, data):
                    print(f"✅ 成功插入数据库 (id={next_id}):{filename}")
                    next_id += 1
                else:
                    print(f"❌ 数据库插入失败:{filename}")
            except Exception as e:
                print(f"❌ 数据库插入失败:{filename}, error: {e}")

        # 7. Commit anything still pending.
        conn.commit()
    finally:
        conn.close()

    print("🎉 所有图片和标签处理完成!")
|
||||
|
||||
# ------------------ 示例调用 ------------------
|
||||
if __name__ == "__main__":
    # Example run: upload a local training set and register it for one model.
    upload_and_insert_images_with_labels(
        "config_test_dev",                         # config name, no .yaml suffix
        r"D:\dataset\images\train",                # image directory
        r"D:\dataset\labels\train",                # label directory
        "300bdf2b-a150-406e-be63-d28bd29b409f",    # MinIO bucket
        "new_datasets/fence",                      # directory inside the bucket
        "public.aidataset",                        # database table
        "08ff91fd-60d2-470f-9675-b18800229654",    # model uuid
        "bdzl",                                    # organisation code
    )
|
||||
18
Ai_tottle/aboutdataset/删除无对应图片的txt.py
Normal file
18
Ai_tottle/aboutdataset/删除无对应图片的txt.py
Normal file
@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
# Image and label folder paths used when this file is run as a script.
image_dir = r"g:\smoke\smoke_old\images"
label_dir = r"g:\smoke\smoke_old\labels"

# Extensions treated as images. '.bmp' is included because the upload script
# accepts it; without it, labels of .bmp images would be deleted as orphans.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def remove_orphan_labels(image_dir, label_dir, image_exts=IMAGE_EXTENSIONS):
    """Delete every ``.txt`` label that has no image with the same base name.

    :param image_dir: directory containing the images.
    :param label_dir: directory containing the label files.
    :param image_exts: lowercase extensions that count as images.
    :return: sorted list of the label file names that were deleted.
    """
    # Base names (without extension) of all images.
    image_names = {
        os.path.splitext(name)[0]
        for name in os.listdir(image_dir)
        if name.lower().endswith(tuple(image_exts))
    }

    deleted = []
    for label_file in sorted(os.listdir(label_dir)):
        label_name, ext = os.path.splitext(label_file)
        if ext.lower() == '.txt' and label_name not in image_names:
            os.remove(os.path.join(label_dir, label_file))
            print(f"Deleted: {label_file}")
            deleted.append(label_file)
    return deleted


# Deleting files is destructive, so do it only when run as a script.
if __name__ == '__main__':
    remove_orphan_labels(image_dir, label_dir)
|
||||
@ -316,3 +316,4 @@ async def yolo_train_progress(request, project_name):
|
||||
|
||||
if __name__ == '__main__':
    # Start the app on all interfaces, port 12366, one worker, debug enabled.
    app.run(
        host="0.0.0.0",
        port=12366,
        debug=True,
        workers=1,
    )
|
||||
|
||||
@ -61,6 +61,7 @@ import pandas as pd
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import miniohelp as miniohelp
|
||||
from aboutdataset.download_oss import download_and_save_images_from_oss
|
||||
|
||||
######################################## Logging ########################################
|
||||
def setup_logger(project: str):
|
||||
@ -82,8 +83,12 @@ def setup_logger(project: str):
|
||||
|
||||
return logger
|
||||
|
||||
def get_last_model_from_log(project: str, default_model: str):
|
||||
"""从日志中解析上一次训练的 last.pt 路径"""
|
||||
def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"):
|
||||
"""
|
||||
从日志解析上一次训练的 last.pt 路径
|
||||
如果找不到则返回 default_model
|
||||
支持 default_model 为 .pt 或 .yaml
|
||||
"""
|
||||
log_file = os.path.join("logs", f"{project}.log")
|
||||
if not os.path.exists(log_file):
|
||||
return default_model
|
||||
@ -96,8 +101,23 @@ def get_last_model_from_log(project: str, default_model: str):
|
||||
path = line.strip().split("Saved last model path:")[-1].strip()
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
|
||||
return default_model
|
||||
|
||||
def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, table_name):
    """Download the selected dataset and return where it was stored.

    All arguments are forwarded to ``download_and_save_images_from_oss``.

    Example values:
        yaml_name='config'
        where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'"
        image_dir='images'
        label_dir='labels'
        table_name='aidataset'

    Returns:
        (image_dir, label_dir), unchanged from the arguments.
    """
    download_and_save_images_from_oss(
        yaml_name,
        where_clause,
        image_dir,
        label_dir,
        table_name,
    )
    locations = (image_dir, label_dir)
    return locations
|
||||
|
||||
####################################### 工具函数 #######################################
|
||||
def count_labels_by_class(label_dir):
|
||||
class_counter = Counter()
|
||||
@ -338,18 +358,36 @@ def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
|
||||
|
||||
logger.info("训练流程执行完成")
|
||||
|
||||
def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id,
                   image_dir, label_dir, yaml_name, where_clause, table_name,
                   new_path='./datasets', split_list=None, class_names=None,
                   project_name='my_project'):
    """Download a dataset, then run ``auto_train`` on it.

    :param db_host, db_database, db_user, db_password, db_port, model_id:
        forwarded to ``auto_train`` unchanged.
    :param image_dir: directory that receives the downloaded images.
    :param label_dir: directory that receives the downloaded labels.
    :param yaml_name: config name (no ``.yaml`` suffix) for the download step.
    :param where_clause: SQL condition selecting the rows to download.
    :param table_name: table holding the dataset rows.
    :param new_path: output directory passed to ``auto_train``.
    :param split_list: split ratios passed to ``auto_train``;
        defaults to ``[0.7, 0.2, 0.1]``.
    :param class_names: class-id to name mapping passed to ``auto_train``;
        defaults to ``{'0': 'human', '1': 'car'}``.
    :param project_name: project name passed to ``auto_train``.
    """
    img_path, label_path = get_img_and_label_paths(
        yaml_name, where_clause, image_dir, label_dir, table_name
    )

    # None sentinels avoid sharing mutable default objects between calls.
    if split_list is None:
        split_list = [0.7, 0.2, 0.1]
    if class_names is None:
        class_names = {'0': 'human', '1': 'car'}

    auto_train(
        db_host=db_host,
        db_database=db_database,
        db_user=db_user,
        db_password=db_password,
        db_port=db_port,
        model_id=model_id,
        # NOTE(review): the other auto_train call site in this file passes
        # `img_path=`, so the earlier `imag_path=` keyword would not match;
        # confirm against auto_train's signature.
        img_path=img_path,
        label_path=label_path,
        new_path=new_path,
        split_list=split_list,
        class_names=class_names,
        project_name=project_name
    )
|
||||
|
||||
####################################### 主入口 #######################################
|
||||
if __name__ == '__main__':
|
||||
auto_train(
|
||||
down_and_train(
|
||||
db_host='222.212.85.86',
|
||||
db_database='your_database_name',
|
||||
db_user='postgres',
|
||||
db_password='postgres',
|
||||
db_port='5432',
|
||||
model_id='best.pt',
|
||||
img_path='./dataset/images',
|
||||
label_path='./dataset/labels',
|
||||
new_path='./datasets',
|
||||
img_path='./dataset/images', #before broken img path
|
||||
label_path='./dataset/labels',#before broken labels path
|
||||
new_path='./datasets', #after broken path
|
||||
split_list=[0.7, 0.2, 0.1],
|
||||
class_names={'0': 'human', '1': 'car'},
|
||||
project_name='my_project'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user