00000
This commit is contained in:
parent
23b990c811
commit
49f00429b9
1
Ai_tottle/aboutdataset/.codemap/main-panel.json
Normal file
1
Ai_tottle/aboutdataset/.codemap/main-panel.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
[]
|
||||||
5
Ai_tottle/aboutdataset/.vscode/settings.json
vendored
Normal file
5
Ai_tottle/aboutdataset/.vscode/settings.json
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||||
|
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||||
|
"python-envs.pythonProjects": []
|
||||||
|
}
|
||||||
BIN
Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc
Normal file
BIN
Ai_tottle/aboutdataset/__pycache__/miniohelp.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
13
Ai_tottle/aboutdataset/config-test.yaml
Normal file
13
Ai_tottle/aboutdataset/config-test.yaml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
minio:
|
||||||
|
endpoint: "222.212.85.86:9000"
|
||||||
|
access_key: "adminjdskfj"
|
||||||
|
secret_key: "123456ksldjfal@Y"
|
||||||
|
secure: false
|
||||||
|
web: "http://222.212.85.86"
|
||||||
|
|
||||||
|
sql:
|
||||||
|
host: '8.137.54.85'
|
||||||
|
port: 5060
|
||||||
|
dbname: 'postgres'
|
||||||
|
user: 'postgres'
|
||||||
|
password: 'root'
|
||||||
13
Ai_tottle/aboutdataset/config.yaml
Normal file
13
Ai_tottle/aboutdataset/config.yaml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
minio:
|
||||||
|
endpoint: "222.212.85.86:9000"
|
||||||
|
access_key: "adminjdskfj"
|
||||||
|
secret_key: "123456ksldjfal@Y"
|
||||||
|
secure: false
|
||||||
|
web: "http://222.212.85.86"
|
||||||
|
|
||||||
|
sql:
|
||||||
|
host: '222.212.85.86'
|
||||||
|
port: 5432
|
||||||
|
dbname: 'smart_dev'
|
||||||
|
user: 'postgres'
|
||||||
|
password: 'root'
|
||||||
13
Ai_tottle/aboutdataset/config_test_dev.yaml
Normal file
13
Ai_tottle/aboutdataset/config_test_dev.yaml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
minio:
|
||||||
|
endpoint: "222.212.85.86:9000"
|
||||||
|
access_key: "adminjdskfj"
|
||||||
|
secret_key: "123456ksldjfal@Y"
|
||||||
|
secure: false
|
||||||
|
web: "http://222.212.85.86"
|
||||||
|
|
||||||
|
sql:
|
||||||
|
host: '8.137.54.85'
|
||||||
|
port: 5060
|
||||||
|
dbname: 'smart_dev_123'
|
||||||
|
user: 'postgres'
|
||||||
|
password: 'root'
|
||||||
103
Ai_tottle/aboutdataset/download_oss.py
Normal file
103
Ai_tottle/aboutdataset/download_oss.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import psycopg2
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
from miniohelp import read_sql_config
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def fetch_object_and_labels(sql_config, where_clause=None, table_name='aidataset'):
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=sql_config['host'],
|
||||||
|
port=sql_config['port'],
|
||||||
|
user=sql_config['user'],
|
||||||
|
password=sql_config['password'],
|
||||||
|
dbname=sql_config.get('database') or sql_config.get('dbname')
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
sql = f"SELECT * FROM public.{table_name}"
|
||||||
|
if where_clause:
|
||||||
|
sql += f" WHERE {where_clause}"
|
||||||
|
|
||||||
|
cursor.execute(sql)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
return rows
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def download_file_from_url(url, save_path):
|
||||||
|
"""
|
||||||
|
从 URL 下载文件到指定路径
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
|
||||||
|
with open(save_path, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
print(f"✅ 下载成功:{save_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 下载失败: {url} 错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def save_label(label_text, file_path):
|
||||||
|
"""
|
||||||
|
保存 label 内容到 .txt 文件
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(label_text.strip())
|
||||||
|
print(f"✅ 标签已保存:{file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 保存标签失败:{file_path} 错误: {e}")
|
||||||
|
|
||||||
|
def move_txt_files(src_dir='image', dst_dir='label'):
|
||||||
|
os.makedirs(dst_dir, exist_ok=True)
|
||||||
|
for filename in os.listdir(src_dir):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
src_path = os.path.join(src_dir, filename)
|
||||||
|
dst_path = os.path.join(dst_dir, filename)
|
||||||
|
shutil.move(src_path, dst_path)
|
||||||
|
print(f"已移动: {src_path} -> {dst_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_save_images_from_oss(yaml_name, where_clause, image_dir, label_dir, table_name):
|
||||||
|
sql_config = read_sql_config(yaml_name)
|
||||||
|
|
||||||
|
rows = fetch_object_and_labels(sql_config, where_clause, table_name)
|
||||||
|
print(f"共查询到 {len(rows)} 条记录")
|
||||||
|
|
||||||
|
os.makedirs(image_dir, exist_ok=True)
|
||||||
|
os.makedirs(label_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for id, orgcode, model, state, objectname, label in rows:
|
||||||
|
if label is None or label.strip() == "":
|
||||||
|
# 跳过无标签的
|
||||||
|
print(f"跳过无标签记录 id={id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_url = f"{objectname}"
|
||||||
|
filename = os.path.basename(objectname)
|
||||||
|
name_no_ext = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
|
image_path = os.path.join(image_dir, filename)
|
||||||
|
label_path = os.path.join(label_dir, name_no_ext + ".txt")
|
||||||
|
|
||||||
|
success = download_file_from_url(full_url, image_path)
|
||||||
|
if success:
|
||||||
|
save_label(label, label_path)
|
||||||
|
print("下载完成/n 下载完成")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
yaml_name='config'
|
||||||
|
where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'"
|
||||||
|
image_dir='images'
|
||||||
|
label_dir='labels'
|
||||||
|
table_name = 'aidataset'
|
||||||
|
download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name)
|
||||||
168
Ai_tottle/aboutdataset/miniohelp.py
Normal file
168
Ai_tottle/aboutdataset/miniohelp.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
from minio import Minio
|
||||||
|
from minio.error import S3Error
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# 读取配置并初始化 MinIO 客户端
|
||||||
|
def load_config(yaml_name):
|
||||||
|
yaml_path = f"{yaml_name}.yaml"
|
||||||
|
# 读取 YAML 配置文件
|
||||||
|
with open(yaml_path, 'r', encoding='utf-8') as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
# 获取 MinIO 配置
|
||||||
|
minio_config = config["minio"]
|
||||||
|
|
||||||
|
# 初始化 MinIO 客户端
|
||||||
|
minio_client = Minio(
|
||||||
|
endpoint=minio_config["endpoint"],
|
||||||
|
access_key=minio_config["access_key"],
|
||||||
|
secret_key=minio_config["secret_key"],
|
||||||
|
secure=minio_config.get("secure", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
return minio_client
|
||||||
|
|
||||||
|
def read_sql_config(yaml_name):
|
||||||
|
yaml_path = f"{yaml_name}.yaml"
|
||||||
|
with open(yaml_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
sql_config = config.get('sql')
|
||||||
|
if not sql_config:
|
||||||
|
raise ValueError("未找到 'sql' 配置块")
|
||||||
|
|
||||||
|
# 优先使用 database,没有则用 dbname
|
||||||
|
if 'database' not in sql_config and 'dbname' not in sql_config:
|
||||||
|
raise KeyError("配置文件中缺少 'database' 或 'dbname' 字段")
|
||||||
|
|
||||||
|
return sql_config
|
||||||
|
|
||||||
|
|
||||||
|
def parse_minio_url(minio_url):
|
||||||
|
"""解析 MinIO 完整 URL,返回 bucket_name 和 object_name"""
|
||||||
|
path = urlparse(minio_url).path.lstrip("/")
|
||||||
|
parts = path.split("/", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError("URL 格式错误,无法提取 bucket 和 object")
|
||||||
|
return parts[0], parts[1]
|
||||||
|
|
||||||
|
def create_bucket(client):
|
||||||
|
"""访问 MinIO 服务器,打印存储桶"""
|
||||||
|
try:
|
||||||
|
buckets = client.list_buckets()
|
||||||
|
for bucket in buckets:
|
||||||
|
print(f"Bucket: {bucket.name}, Created: {bucket.creation_date}")
|
||||||
|
except S3Error as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
def downFile(client, object_name, bucket_name, local_path=None):
|
||||||
|
"""下载文件,可指定 bucket 和本地保存路径"""
|
||||||
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
download_path = local_path or os.path.join(current_directory, os.path.basename(object_name))
|
||||||
|
os.makedirs(os.path.dirname(download_path), exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.fget_object(bucket_name, object_name, download_path)
|
||||||
|
print(f"✅ 文件已成功下载到 {download_path}")
|
||||||
|
return download_path
|
||||||
|
except S3Error as e:
|
||||||
|
print(f"❌ 下载文件时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_file(client, file_path, bucket_name, bucket_directory):
|
||||||
|
"""上传文件到指定 MinIO 存储桶和目录"""
|
||||||
|
try:
|
||||||
|
if not client.bucket_exists(bucket_name):
|
||||||
|
print(f"❌ 存储桶 {bucket_name} 不存在")
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
object_name = f"{bucket_directory.rstrip('/')}/{file_name}" if bucket_directory else file_name
|
||||||
|
|
||||||
|
client.fput_object(bucket_name, object_name, file_path)
|
||||||
|
print(f"✅ 文件已上传至 {bucket_name}/{object_name}")
|
||||||
|
return f"{object_name}"
|
||||||
|
except S3Error as e:
|
||||||
|
print(f"❌ 上传文件时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_file_t(client, file_path, time_str, bucket_name, bucket_directory):
|
||||||
|
"""上传文件到指定 MinIO 存储桶和目录(带时间路径)"""
|
||||||
|
try:
|
||||||
|
if not client.bucket_exists(bucket_name):
|
||||||
|
print(f"❌ 存储桶 {bucket_name} 不存在")
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
object_name = f"{bucket_directory.rstrip('/')}/{time_str}/{file_name}" if bucket_directory else file_name
|
||||||
|
|
||||||
|
client.fput_object(bucket_name, object_name, file_path)
|
||||||
|
print(f"✅ 文件已上传至 {bucket_name}/{object_name}")
|
||||||
|
return f"{bucket_name}/{object_name}"
|
||||||
|
except S3Error as e:
|
||||||
|
print(f"❌ 上传文件时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_folder(client, folder_path, bucket_name, bucket_directory):
|
||||||
|
"""上传整个文件夹到指定 MinIO 存储桶目录"""
|
||||||
|
try:
|
||||||
|
if not client.bucket_exists(bucket_name):
|
||||||
|
print(f"❌ 存储桶 {bucket_name} 不存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
for root, _, files in os.walk(folder_path):
|
||||||
|
for file in files:
|
||||||
|
file_path = os.path.join(root, file)
|
||||||
|
object_name = os.path.relpath(file_path, folder_path).replace("\\", "/")
|
||||||
|
|
||||||
|
if bucket_directory:
|
||||||
|
object_name = f"{bucket_directory.rstrip('/')}/{object_name}"
|
||||||
|
|
||||||
|
client.fput_object(bucket_name, object_name, file_path)
|
||||||
|
print(f"✅ 文件 {file_path} 已上传至 {bucket_name}/{object_name}")
|
||||||
|
except S3Error as e:
|
||||||
|
print(f"❌ 上传文件夹时出错: {e}")
|
||||||
|
|
||||||
|
def download_file_url(url: str, save_path: str):
|
||||||
|
"""
|
||||||
|
从指定 URL 下载文件并保存到本地路径。
|
||||||
|
|
||||||
|
:param url: 文件的 URL。
|
||||||
|
:param save_path: 本地保存路径(包括文件名)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status() # 如果请求失败会抛出异常
|
||||||
|
|
||||||
|
with open(save_path, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
print(f"文件已成功下载到: {save_path}")
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(f"下载失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 初始化 MinIO 客户端
|
||||||
|
client = Minio(
|
||||||
|
endpoint="222.212.85.86",
|
||||||
|
access_key="adminjdskfj",
|
||||||
|
secret_key="123456ksldjfal@Y",
|
||||||
|
secure=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 示例:通过完整 URL 下载文件
|
||||||
|
full_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/media/86961d3a-8790-4bc0-bfc6-dc45a8e4b9bd/DJI_202504291707_001_86961d3a-8790-4bc0-bfc6-dc45a8e4b9bd/DJI_20250429170933_0001_V.jpeg"
|
||||||
|
bucket_name, object_name = parse_minio_url(full_url)
|
||||||
|
downFile(client, object_name, bucket_name)
|
||||||
|
|
||||||
|
# 示例:上传文件
|
||||||
|
file_path = r"D:\work\develop\sample.jpg" # 修改为实际文件
|
||||||
|
upload_file(client, file_path, "300bdf2b-a150-406e-be63-d28bd29b409f", "media/output_video")
|
||||||
|
|
||||||
|
# 示例:上传文件夹
|
||||||
|
folder_path = r"D:\work\develop"
|
||||||
|
upload_folder(client, folder_path, "300bdf2b-a150-406e-be63-d28bd29b409f", "media/output_video")
|
||||||
75
Ai_tottle/aboutdataset/pgadmin_helper.py
Normal file
75
Ai_tottle/aboutdataset/pgadmin_helper.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import psycopg2
|
||||||
|
from datetime import datetime
|
||||||
|
from miniohelp import *
|
||||||
|
|
||||||
|
def insert_data(table_name, data, table = 'danger'):
|
||||||
|
"""
|
||||||
|
向 PostgreSQL 数据库的指定表中插入数据
|
||||||
|
:param table_name: 表名
|
||||||
|
:param data: 字典格式的数据(列名: 值)
|
||||||
|
"""
|
||||||
|
# 验证表名是否合法
|
||||||
|
if table_name not in [table]: # 在这里加入你的表名
|
||||||
|
print(f"错误: 无效的表名 {table_name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
sql_yaml = data.pop('sql_yaml')
|
||||||
|
# 数据库配置
|
||||||
|
sql_config = read_sql_config(sql_yaml)
|
||||||
|
db_host = sql_config['host']
|
||||||
|
db_port = sql_config['port']
|
||||||
|
db_database = sql_config['database']
|
||||||
|
db_user = sql_config['user']
|
||||||
|
db_password = sql_config['password']
|
||||||
|
# 连接数据库
|
||||||
|
with psycopg2.connect(
|
||||||
|
host=db_host,
|
||||||
|
port=db_port,
|
||||||
|
database=db_database,
|
||||||
|
user=db_user,
|
||||||
|
password=db_password
|
||||||
|
) as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
# 生成 SQL 语句
|
||||||
|
columns = ', '.join(data.keys())
|
||||||
|
values = ', '.join(['%s'] * len(data))
|
||||||
|
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({values})"
|
||||||
|
|
||||||
|
# 执行 SQL 语句
|
||||||
|
cursor.execute(sql, tuple(data.values()))
|
||||||
|
|
||||||
|
# 提交事务
|
||||||
|
conn.commit()
|
||||||
|
print("数据插入成功")
|
||||||
|
|
||||||
|
except psycopg2.Error as e:
|
||||||
|
print("数据库错误:", e.pgcode, e.pgerror) # 打印详细的错误码和错误信息
|
||||||
|
|
||||||
|
# 示例数据
|
||||||
|
data = {
|
||||||
|
'orgcode': 'bdzl',
|
||||||
|
'sn': '123456',
|
||||||
|
'snname': '测试机',
|
||||||
|
'spaceid': '97123d28-40a3-4dd1-9021-616a7d60dce8',
|
||||||
|
'spacename': '目标检测',
|
||||||
|
'guid': '97123d28-40a3-4dd1-9021-616a7d60dce8',
|
||||||
|
'level': '一般',
|
||||||
|
'title': 'pedestrian',
|
||||||
|
'dangertype': '目标',
|
||||||
|
'content': '目标检测',
|
||||||
|
'gps1': "",
|
||||||
|
'dealresult': "",
|
||||||
|
'imgobjectname': 'AIResults/20250320/97123d28-40a3-4dd1-9021-616a7d60dce8.jpg',
|
||||||
|
'state': 0,
|
||||||
|
'createtm': '2025-03-20 16:12:04',
|
||||||
|
'dealtm': '2025-03-20 16:12:04',
|
||||||
|
'createuser': "",
|
||||||
|
'dealuser': "",
|
||||||
|
'checkuser': "",
|
||||||
|
'remarks': "",
|
||||||
|
'sql_yaml': "", # 数据库配置文件路径
|
||||||
|
}
|
||||||
|
|
||||||
|
# 调用插入函数
|
||||||
|
insert_data('danger', data)
|
||||||
138
Ai_tottle/aboutdataset/upload.py
Normal file
138
Ai_tottle/aboutdataset/upload.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import os
|
||||||
|
import psycopg2
|
||||||
|
from miniohelp import load_config, upload_file
|
||||||
|
|
||||||
|
# ------------------ 加载数据库配置 ------------------
|
||||||
|
def load_sql_config(yaml_name):
|
||||||
|
import yaml
|
||||||
|
# 从 yaml 文件中读取 sql 配置
|
||||||
|
with open(f"{yaml_name}.yaml", "r", encoding="utf-8") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
return config["sql"]
|
||||||
|
|
||||||
|
# ------------------ 插入单条数据到数据库 ------------------
|
||||||
|
def insert_to_database(conn, table, data):
|
||||||
|
# 定义插入 SQL
|
||||||
|
sql = f"""
|
||||||
|
INSERT INTO {table} (
|
||||||
|
id, orgcode, model, state, objectname, label
|
||||||
|
) VALUES (%s, %s, %s, %s, %s, %s)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用游标执行 SQL
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(sql, data)
|
||||||
|
# 提交事务
|
||||||
|
conn.commit()
|
||||||
|
print("✅ 数据插入成功")
|
||||||
|
return True
|
||||||
|
except psycopg2.Error as e:
|
||||||
|
# 出现错误时回滚事务
|
||||||
|
conn.rollback()
|
||||||
|
print("❌ 插入失败:", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ------------------ 主流程:上传图片 + 插入数据库 ------------------
|
||||||
|
def upload_and_insert_images_with_labels(
|
||||||
|
yaml_name, # 配置文件名(不带 .yaml 后缀)
|
||||||
|
image_dir, # 图片目录
|
||||||
|
label_dir, # 标签目录
|
||||||
|
bucket_name, # MinIO 存储桶名
|
||||||
|
bucket_path, # MinIO 上传路径(目录)
|
||||||
|
table_name, # 数据库表名
|
||||||
|
model_name, # 模型名,用于区分不同模型的数据
|
||||||
|
orgcode="bdzl" # 机构代码,默认 bdzl
|
||||||
|
):
|
||||||
|
# 1. 初始化 MinIO 和 SQL 配置
|
||||||
|
minio_client = load_config(yaml_name) # 加载 MinIO 配置
|
||||||
|
sql_conf = load_sql_config(yaml_name) # 加载 SQL 配置
|
||||||
|
|
||||||
|
# 2. 创建数据库连接
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=sql_conf["host"],
|
||||||
|
port=sql_conf["port"],
|
||||||
|
user=sql_conf["user"],
|
||||||
|
password=sql_conf["password"],
|
||||||
|
database=sql_conf["dbname"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 查询数据库中当前最大 id 和已存在的文件
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
# 查询表里当前最大 id
|
||||||
|
cur.execute(f"SELECT COALESCE(MAX(id), 0) FROM {table_name}")
|
||||||
|
max_id = cur.fetchone()[0] or 0
|
||||||
|
next_id = max_id + 1
|
||||||
|
|
||||||
|
# 查询该模型已有的文件名(仅文件名部分)
|
||||||
|
cur.execute(f"SELECT objectname FROM {table_name} WHERE model = %s", (model_name,))
|
||||||
|
existing_files = set([os.path.basename(row[0]) for row in cur.fetchall()])
|
||||||
|
|
||||||
|
# 4. 遍历图片目录
|
||||||
|
for filename in os.listdir(image_dir):
|
||||||
|
# 只处理图片类型文件
|
||||||
|
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果数据库已存在该文件,跳过
|
||||||
|
if filename in existing_files:
|
||||||
|
print(f"⏭️ 已存在于数据库中,跳过:{filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_path = os.path.join(image_dir, filename)
|
||||||
|
name_no_ext = os.path.splitext(filename)[0]
|
||||||
|
label_path = os.path.join(label_dir, name_no_ext + ".txt")
|
||||||
|
|
||||||
|
# 标签文件不存在,跳过
|
||||||
|
if not os.path.exists(label_path):
|
||||||
|
print(f"⚠️ 未找到标签:{label_path},跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 读取标签文件内容
|
||||||
|
with open(label_path, "r", encoding="utf-8") as f:
|
||||||
|
label_content = f.read().strip()
|
||||||
|
|
||||||
|
# 5. 上传图片到 MinIO
|
||||||
|
objectname = upload_file(minio_client, image_path, bucket_name, bucket_path)
|
||||||
|
if not objectname:
|
||||||
|
print(f"❌ 上传失败:{filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 6. 插入数据库
|
||||||
|
try:
|
||||||
|
data = (
|
||||||
|
next_id, # id
|
||||||
|
orgcode, # 机构代码
|
||||||
|
model_name, # 模型
|
||||||
|
1, # state 固定 1
|
||||||
|
objectname, # MinIO 返回的文件路径
|
||||||
|
label_content # 标签内容
|
||||||
|
)
|
||||||
|
info = insert_to_database(conn, table_name, data)
|
||||||
|
print(info)
|
||||||
|
print(f"✅ 成功插入数据库 (id={next_id}):{filename}")
|
||||||
|
next_id += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 数据库插入失败:{filename}, error: {e}")
|
||||||
|
|
||||||
|
# 7. 提交事务并关闭连接
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
print("🎉 所有图片和标签处理完成!")
|
||||||
|
|
||||||
|
# ------------------ 示例调用 ------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
yaml_name = "config_test_dev" # 配置文件名(不含 .yaml 后缀)
|
||||||
|
image_dir = r"D:\dataset\images\train" # 图片路径
|
||||||
|
label_dir = r"D:\dataset\labels\train" # 标签路径
|
||||||
|
bucket_name = "300bdf2b-a150-406e-be63-d28bd29b409f" # MinIO 桶名
|
||||||
|
bucket_directory = "new_datasets/fence" # MinIO 内的目录
|
||||||
|
table_name = "public.aidataset" # 数据库表名
|
||||||
|
model_name = "08ff91fd-60d2-470f-9675-b18800229654" # 模型uuid
|
||||||
|
orgcode = "bdzl" # 机构代码
|
||||||
|
|
||||||
|
# 调用主函数
|
||||||
|
upload_and_insert_images_with_labels(
|
||||||
|
yaml_name, image_dir, label_dir,
|
||||||
|
bucket_name, bucket_directory,
|
||||||
|
table_name, model_name, orgcode
|
||||||
|
)
|
||||||
18
Ai_tottle/aboutdataset/删除无对应图片的txt.py
Normal file
18
Ai_tottle/aboutdataset/删除无对应图片的txt.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
# 设置图片和标签文件夹路径
|
||||||
|
image_dir = r"g:\smoke\smoke_old\images" # 图片文件夹路径
|
||||||
|
label_dir = r"g:\smoke\smoke_old\labels" # 标签文件夹路径
|
||||||
|
|
||||||
|
# 获取图片文件名(不包括扩展名)
|
||||||
|
image_files = {os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))}
|
||||||
|
|
||||||
|
# 遍历标签文件夹
|
||||||
|
for label_file in os.listdir(label_dir):
|
||||||
|
label_name, ext = os.path.splitext(label_file)
|
||||||
|
|
||||||
|
if ext.lower() == '.txt' and label_name not in image_files:
|
||||||
|
# 如果标签文件没有对应的图片,删除该标签文件
|
||||||
|
label_path = os.path.join(label_dir, label_file)
|
||||||
|
os.remove(label_path)
|
||||||
|
print(f"Deleted: {label_file}")
|
||||||
@ -316,3 +316,4 @@ async def yolo_train_progress(request, project_name):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
|
app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
|
||||||
|
|
||||||
@ -61,6 +61,7 @@ import pandas as pd
|
|||||||
import logging
|
import logging
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import miniohelp as miniohelp
|
import miniohelp as miniohelp
|
||||||
|
from aboutdataset.download_oss import download_and_save_images_from_oss
|
||||||
|
|
||||||
######################################## Logging ########################################
|
######################################## Logging ########################################
|
||||||
def setup_logger(project: str):
|
def setup_logger(project: str):
|
||||||
@ -82,8 +83,12 @@ def setup_logger(project: str):
|
|||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
def get_last_model_from_log(project: str, default_model: str):
|
def get_last_model_from_log(project: str, default_model: str = "yolo11n.pt"):
|
||||||
"""从日志中解析上一次训练的 last.pt 路径"""
|
"""
|
||||||
|
从日志解析上一次训练的 last.pt 路径
|
||||||
|
如果找不到则返回 default_model
|
||||||
|
支持 default_model 为 .pt 或 .yaml
|
||||||
|
"""
|
||||||
log_file = os.path.join("logs", f"{project}.log")
|
log_file = os.path.join("logs", f"{project}.log")
|
||||||
if not os.path.exists(log_file):
|
if not os.path.exists(log_file):
|
||||||
return default_model
|
return default_model
|
||||||
@ -96,8 +101,23 @@ def get_last_model_from_log(project: str, default_model: str):
|
|||||||
path = line.strip().split("Saved last model path:")[-1].strip()
|
path = line.strip().split("Saved last model path:")[-1].strip()
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
return default_model
|
return default_model
|
||||||
|
|
||||||
|
def get_img_and_label_paths(yaml_name, where_clause, image_dir,label_dir, table_name):
|
||||||
|
"""
|
||||||
|
yaml_name='config'
|
||||||
|
where_clause="model = '0845315a-0b3c-439d-9e42-264a9411207f'"
|
||||||
|
image_dir='images'
|
||||||
|
label_dir='labels'
|
||||||
|
table_name = 'aidataset'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(image_dir, label_dir)
|
||||||
|
"""
|
||||||
|
download_and_save_images_from_oss(yaml_name, where_clause, image_dir,label_dir, table_name)
|
||||||
|
return image_dir, label_dir
|
||||||
|
|
||||||
####################################### 工具函数 #######################################
|
####################################### 工具函数 #######################################
|
||||||
def count_labels_by_class(label_dir):
|
def count_labels_by_class(label_dir):
|
||||||
class_counter = Counter()
|
class_counter = Counter()
|
||||||
@ -338,18 +358,36 @@ def auto_train(db_host, db_database, db_user, db_password, db_port, model_id,
|
|||||||
|
|
||||||
logger.info("训练流程执行完成")
|
logger.info("训练流程执行完成")
|
||||||
|
|
||||||
|
def down_and_train(db_host, db_database, db_user, db_password, db_port, model_id, image_dir, label_dir, yaml_name, where_clause, table_name):
|
||||||
|
|
||||||
|
imag_path, label_path = get_img_and_label_paths(yaml_name, where_clause, image_dir, label_dir, table_name)
|
||||||
|
auto_train(
|
||||||
|
db_host=db_host,
|
||||||
|
db_database=db_database,
|
||||||
|
db_user=db_user,
|
||||||
|
db_password=db_password,
|
||||||
|
db_port=db_port,
|
||||||
|
model_id=model_id,
|
||||||
|
imag_path=imag_path, # 修正了这里,确保 imag_path 作为关键字参数传递
|
||||||
|
label_path=label_path, # 修正了这里,确保 label_path 作为关键字参数传递
|
||||||
|
new_path='./datasets',
|
||||||
|
split_list=[0.7, 0.2, 0.1],
|
||||||
|
class_names={'0': 'human', '1': 'car'},
|
||||||
|
project_name='my_project'
|
||||||
|
)
|
||||||
|
|
||||||
####################################### 主入口 #######################################
|
####################################### 主入口 #######################################
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
auto_train(
|
down_and_train(
|
||||||
db_host='222.212.85.86',
|
db_host='222.212.85.86',
|
||||||
db_database='your_database_name',
|
db_database='your_database_name',
|
||||||
db_user='postgres',
|
db_user='postgres',
|
||||||
db_password='postgres',
|
db_password='postgres',
|
||||||
db_port='5432',
|
db_port='5432',
|
||||||
model_id='best.pt',
|
model_id='best.pt',
|
||||||
img_path='./dataset/images',
|
img_path='./dataset/images', #before broken img path
|
||||||
label_path='./dataset/labels',
|
label_path='./dataset/labels',#before broken labels path
|
||||||
new_path='./datasets',
|
new_path='./datasets', #after broken path
|
||||||
split_list=[0.7, 0.2, 0.1],
|
split_list=[0.7, 0.2, 0.1],
|
||||||
class_names={'0': 'human', '1': 'car'},
|
class_names={'0': 'human', '1': 'car'},
|
||||||
project_name='my_project'
|
project_name='my_project'
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user