# Listing metadata captured with this file: 104 lines, 3.2 KiB, Python.
# Last change: 2025-10-09 09:29:18 +08:00

import os
import re
import shutil
from urllib.parse import urlsplit

import psycopg2
import requests

from miniohelp import read_sql_config
# Plain (unquoted) PostgreSQL identifier: a letter or underscore followed by
# letters, digits, underscores or dollar signs.
_TABLE_NAME_RE = re.compile(r"[^\W\d][\w$]*")


def fetch_object_and_labels(sql_config, where_clause=None, table_name='aidataset'):
    """Return all rows of ``public.<table_name>``, optionally filtered.

    Args:
        sql_config: Mapping with the keys ``host``, ``port``, ``user`` and
            ``password``, plus the database name under ``database`` or
            ``dbname``.
        where_clause: Optional SQL condition appended after ``WHERE``. It is
            inserted into the statement verbatim, so it must come from
            trusted code, never from user input.
        table_name: Name of the table inside the ``public`` schema. Must be a
            plain identifier because it is interpolated into the statement.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``.

    Raises:
        ValueError: If ``table_name`` is not a plain SQL identifier.
    """
    # Identifiers cannot be sent as bound parameters, so reject anything that
    # could smuggle extra SQL in before a connection is even opened.
    if not isinstance(table_name, str) or not _TABLE_NAME_RE.fullmatch(table_name):
        raise ValueError(f"invalid table name: {table_name!r}")

    query = f"SELECT * FROM public.{table_name}"
    if where_clause:
        # NOTE(security): raw SQL fragment; see the docstring.
        query += f" WHERE {where_clause}"

    conn = psycopg2.connect(
        host=sql_config['host'],
        port=sql_config['port'],
        user=sql_config['user'],
        password=sql_config['password'],
        dbname=sql_config.get('database') or sql_config.get('dbname'),
    )
    try:
        # The cursor context manager closes the cursor; the connection itself
        # still has to be closed explicitly.
        with conn.cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()
    finally:
        conn.close()
def download_file_from_url(url, save_path, timeout=30):
    """Download ``url`` to ``save_path``.

    The body is streamed into ``<save_path>.part`` and renamed into place only
    after the whole response has been written, so a failed download never
    leaves a truncated file at ``save_path``.

    Args:
        url: Address to fetch with an HTTP GET.
        save_path: Destination file path; missing parent directories are
            created.
        timeout: Seconds passed to ``requests.get`` so a stalled server cannot
            block the caller forever.

    Returns:
        True when the file was written, False on any failure (the error is
        printed, not raised).
    """
    tmp_path = None
    try:
        parent = os.path.dirname(save_path)
        # dirname is '' for a bare file name, and makedirs('') raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        tmp_path = f"{os.fspath(save_path)}.part"

        # Using the response as a context manager releases the connection even
        # when the body is only partially consumed.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(8192):
                    if chunk:
                        f.write(chunk)

        os.replace(tmp_path, save_path)
        print(f"✅ 下载成功:{save_path}")
        return True
    except Exception as e:
        # Deliberately broad: callers treat this as best-effort and rely on
        # the boolean result rather than on exceptions.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        print(f"❌ 下载失败: {url} 错误: {e}")
        return False
def save_label(label_text, file_path):
    """Write the whitespace-trimmed ``label_text`` to ``file_path`` as UTF-8.

    Args:
        label_text: Label content; leading and trailing whitespace is removed.
        file_path: Destination ``.txt`` path; missing parent directories are
            created.

    Returns:
        True when the file was written, False on any failure (the error is
        printed, not raised).
    """
    try:
        # Strip before opening so an invalid label (e.g. None) does not leave
        # an empty file behind.
        text = label_text.strip()
        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(text)
    except Exception as e:
        print(f"❌ 保存标签失败:{file_path} 错误: {e}")
        return False
    print(f"✅ 标签已保存:{file_path}")
    return True
def move_txt_files(src_dir='image', dst_dir='label'):
    """Move every ``.txt`` file located directly in ``src_dir`` to ``dst_dir``.

    Args:
        src_dir: Directory to scan (not recursive).
        dst_dir: Directory receiving the files; created when missing.

    Raises:
        OSError: If ``src_dir`` cannot be listed or a move fails.
    """
    os.makedirs(dst_dir, exist_ok=True)
    # Sorted so the moves and their log lines come out in a stable order.
    for filename in sorted(os.listdir(src_dir)):
        src_path = os.path.join(src_dir, filename)
        # Skip sub-directories that merely happen to be named "*.txt".
        if not filename.endswith('.txt') or not os.path.isfile(src_path):
            continue
        dst_path = os.path.join(dst_dir, filename)
        shutil.move(src_path, dst_path)
        print(f"已移动: {src_path} -> {dst_path}")
def download_and_save_images_from_oss(yaml_name, where_clause, image_dir, label_dir, table_name):
    """Download every labelled record's file and write its label next to it.

    Rows are read with ``fetch_object_and_labels``. For each row with a
    non-blank label, the row's ``objectname`` is used as the download URL, the
    file is stored in ``image_dir`` and the label in ``label_dir`` under the
    same base name with a ``.txt`` extension.

    Args:
        yaml_name: Passed to ``read_sql_config`` to load the database settings.
        where_clause: Optional SQL condition forwarded to the query.
        image_dir: Directory for downloaded files; created when missing.
        label_dir: Directory for label files; created when missing.
        table_name: Table to read from.
    """
    sql_config = read_sql_config(yaml_name)
    rows = fetch_object_and_labels(sql_config, where_clause, table_name)
    print(f"共查询到 {len(rows)} 条记录")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    # The query is SELECT *, so this unpacking relies on the table having
    # exactly these six columns in this order.
    for record_id, _orgcode, _model, _state, objectname, label in rows:
        if label is None or label.strip() == "":
            # Skip records that carry no label.
            print(f"跳过无标签记录 id={record_id}")
            continue

        # Take the file name from the URL path only, so a query string or
        # fragment does not end up in the local file name.
        filename = os.path.basename(urlsplit(objectname).path) if objectname else ""
        if not filename:
            # No usable object address: skip this record instead of aborting
            # the whole batch.
            print(f"跳过无效对象地址记录 id={record_id}")
            continue

        name_no_ext = os.path.splitext(filename)[0]
        image_path = os.path.join(image_dir, filename)
        label_path = os.path.join(label_dir, name_no_ext + ".txt")

        # Only write a label when its file actually arrived.
        if download_file_from_url(objectname, image_path):
            save_label(label, label_path)

    print("下载完成\n 下载完成")
if __name__ == '__main__':
    # Fixed parameters for this export run: read rows of one model from the
    # 'aidataset' table and store files under ./images and ./labels.
    job = {
        'yaml_name': 'config',
        'where_clause': "model = '0845315a-0b3c-439d-9e42-264a9411207f'",
        'image_dir': 'images',
        'label_dir': 'labels',
        'table_name': 'aidataset',
    }
    download_and_save_images_from_oss(**job)