104 lines
3.2 KiB
Python
104 lines
3.2 KiB
Python
import os
import shutil
from urllib.parse import urlparse

import psycopg2
import requests

from miniohelp import read_sql_config
|
|
|
|
def fetch_object_and_labels(sql_config, where_clause=None, table_name='aidataset', params=None):
    """Return all rows of ``public.<table_name>``, optionally filtered by a WHERE clause.

    Args:
        sql_config: Mapping with ``host``, ``port``, ``user``, ``password`` and
            either a ``database`` or a ``dbname`` key, passed to ``psycopg2.connect``.
        where_clause: Optional SQL condition appended after ``WHERE``. It is
            inserted into the query text verbatim, so it must come from a
            trusted source. Put untrusted values in ``params`` and reference
            them with ``%s`` placeholders instead.
        table_name: Table name inside the ``public`` schema. It must be a bare
            identifier (letters, digits, underscores, not starting with a digit),
            because identifiers cannot be bound as query parameters.
        params: Optional values bound by the driver to placeholders in
            ``where_clause``. With ``None`` (the default), the query text is sent
            unchanged. When given, literal ``%`` characters in ``where_clause``
            must be written as ``%%``.

    Returns:
        The list of row tuples from ``cursor.fetchall()`` (``SELECT *``).

    Raises:
        ValueError: If ``table_name`` is not a bare identifier.
        Connection and query errors from psycopg2 are not caught and propagate
        to the caller.
    """
    # Validate before touching the database: this name is interpolated into SQL.
    if not isinstance(table_name, str) or not table_name.isidentifier():
        raise ValueError(f"invalid table name: {table_name!r}")

    query = f"SELECT * FROM public.{table_name}"
    if where_clause:
        query += f" WHERE {where_clause}"

    conn = psycopg2.connect(
        host=sql_config['host'],
        port=sql_config['port'],
        user=sql_config['user'],
        password=sql_config['password'],
        dbname=sql_config.get('database') or sql_config.get('dbname'),
    )
    try:
        # The cursor's context manager closes the cursor; the connection is
        # closed separately in ``finally``.
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchall()
    finally:
        conn.close()
|
|
|
|
|
|
def download_file_from_url(url, save_path, timeout=30, chunk_size=8192):
    """Download ``url`` to ``save_path``, creating the parent directory if needed.

    The body is streamed into ``save_path + '.part'``. That file is renamed onto
    ``save_path`` only after the whole response has been written, so a failed or
    interrupted download never leaves a truncated file under the final name.

    Args:
        url: URL to fetch with ``requests.get``.
        save_path: Destination file path (str or str-based path-like).
        timeout: Timeout in seconds passed to ``requests.get``. ``None`` waits
            indefinitely, which matches the previous behaviour.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.

    Returns:
        True on success. False if the request or any file operation failed;
        the error is printed, not raised.
    """
    tmp_path = os.fspath(save_path) + ".part"
    try:
        parent = os.path.dirname(save_path)
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

        # The context manager releases the connection even with stream=True.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, save_path)
        print(f"✅ 下载成功:{save_path}")
        return True
    except (requests.RequestException, OSError) as e:
        print(f"❌ 下载失败: {url} 错误: {e}")
        # Best effort: remove a partially written temp file, if any.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False
|
|
|
|
def save_label(label_text, file_path):
    """Write ``label_text``, stripped of surrounding whitespace, to ``file_path``.

    The file is opened in text mode with UTF-8 encoding and overwritten if it
    already exists. Write failures are printed rather than raised, so one bad
    label does not abort a batch.

    Args:
        label_text: Label content as a string.
        file_path: Destination path, usually a ``.txt`` file.

    Returns:
        True if the file was written, False if writing failed.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(label_text.strip())
    except (OSError, ValueError) as e:
        # OSError covers a missing directory or permissions; ValueError covers
        # text that cannot be encoded as UTF-8.
        print(f"❌ 保存标签失败:{file_path} 错误: {e}")
        return False
    print(f"✅ 标签已保存:{file_path}")
    return True
|
|
|
|
def move_txt_files(src_dir='image', dst_dir='label'):
    """Move every entry directly inside ``src_dir`` whose name ends in ``.txt`` into ``dst_dir``.

    ``dst_dir`` is created first if it does not exist. Only the top level of
    ``src_dir`` is scanned, and the suffix check is case-sensitive. Each move
    is reported on stdout.
    """
    os.makedirs(dst_dir, exist_ok=True)

    txt_names = [name for name in os.listdir(src_dir) if name.endswith('.txt')]
    for name in txt_names:
        source = os.path.join(src_dir, name)
        target = os.path.join(dst_dir, name)
        shutil.move(source, target)
        print(f"已移动: {source} -> {target}")
|
|
|
|
|
|
def _filename_from_object_url(objectname):
    """Return the last path segment of ``objectname`` (URL or key), or '' if there is none."""
    if not objectname:
        return ''
    # Parse first so a query string or fragment does not end up in the filename.
    return os.path.basename(urlparse(objectname).path)


def download_and_save_images_from_oss(yaml_name, where_clause, image_dir, label_dir, table_name):
    """Download the images listed in a database table and save their labels next to them.

    Connection settings come from ``read_sql_config(yaml_name)``, and rows come
    from ``fetch_object_and_labels`` (``SELECT *``). Each row is unpacked as
    ``(id, orgcode, model, state, objectname, label)``.
    NOTE(review): this depends on the table's column order and count; confirm
    it whenever the schema changes.

    For every row with a non-blank label and a usable ``objectname``:

    - the image at URL ``objectname`` is saved as ``image_dir/<basename>``;
    - the label is saved as ``label_dir/<basename without extension>.txt``,
      but only if the download succeeded.

    Rows whose URLs share a basename overwrite each other's files.

    Args:
        yaml_name: Config name passed to ``read_sql_config``.
        where_clause: SQL filter forwarded verbatim to ``fetch_object_and_labels``.
        image_dir: Output directory for images (created if missing).
        label_dir: Output directory for label files (created if missing).
        table_name: Table in the ``public`` schema to read from.
    """
    sql_config = read_sql_config(yaml_name)

    rows = fetch_object_and_labels(sql_config, where_clause, table_name)
    print(f"共查询到 {len(rows)} 条记录")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    downloaded = 0
    for row_id, _orgcode, _model, _state, objectname, label in rows:
        # Skip rows without a label.
        if label is None or not label.strip():
            print(f"跳过无标签记录 id={row_id}")
            continue

        filename = _filename_from_object_url(objectname)
        if not filename:
            # A NULL objectname, or a URL ending in '/', gives no usable filename.
            print(f"跳过无效对象名记录 id={row_id}")
            continue
        name_no_ext = os.path.splitext(filename)[0]

        image_path = os.path.join(image_dir, filename)
        label_path = os.path.join(label_dir, name_no_ext + ".txt")

        if download_file_from_url(objectname, image_path):
            save_label(label, label_path)
            downloaded += 1

    print(f"下载完成: {downloaded}/{len(rows)}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Export one model's labelled images from public.aidataset into local
    # folders: images under ./images, label text files under ./labels.
    yaml_name = 'config'
    where_clause = "model = '0845315a-0b3c-439d-9e42-264a9411207f'"
    image_dir = 'images'
    label_dir = 'labels'
    table_name = 'aidataset'

    download_and_save_images_from_oss(
        yaml_name=yaml_name,
        where_clause=where_clause,
        image_dir=image_dir,
        label_dir=label_dir,
        table_name=table_name,
    )
|