共享目录AI识别
This commit is contained in:
parent
5615d6b182
commit
06bafccb4e
@ -17,7 +17,9 @@ import matplotlib.pyplot as plt
|
||||
from ultralytics import YOLO
|
||||
|
||||
from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder
|
||||
from util.yolo2pix import yoloseg_to_grid_cells_fixed_v5, draw_grid_on_image
|
||||
from util.yolo2pix_new import *
|
||||
from util.smb import *
|
||||
import threading
|
||||
|
||||
# 定义红白蓝颜色 (BGR格式)
|
||||
RED = (0, 0, 255)
|
||||
@ -183,6 +185,71 @@ class YOLOSegmentationInference:
|
||||
print(f"图片预处理失败: {e}")
|
||||
return None, None
|
||||
|
||||
def perform_inference_share_dir(self, image, image_path, conf_threshold: float = 0.25,
|
||||
iou_threshold: float = 0.5) -> InferenceResult:
|
||||
"""
|
||||
执行推理
|
||||
|
||||
Args:
|
||||
image_path: 图片数据
|
||||
image_path: 图片路径
|
||||
conf_threshold: 置信度阈值
|
||||
iou_threshold: IOU阈值
|
||||
|
||||
Returns:
|
||||
推理结果
|
||||
"""
|
||||
result = InferenceResult(image_path)
|
||||
|
||||
try:
|
||||
if self.model is None:
|
||||
raise ValueError("模型未加载,请先调用load_model()")
|
||||
|
||||
# 转换为RGB格式
|
||||
original_image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
result.original_image = original_image_rgb
|
||||
|
||||
# 执行推理
|
||||
print(f"正在处理图片: {os.path.basename(image_path)}")
|
||||
start_time = time.time()
|
||||
|
||||
# 使用YOLO模型进行推理
|
||||
# predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
|
||||
predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold)[0]
|
||||
|
||||
result.inference_time = time.time() - start_time
|
||||
|
||||
# 处理结果
|
||||
if predictions.masks is not None:
|
||||
# 处理掩码
|
||||
masks = predictions.masks.data.cpu().numpy()
|
||||
|
||||
# 处理边界框
|
||||
boxes = predictions.boxes.data.cpu().numpy()
|
||||
|
||||
# 处理类别和置信度
|
||||
class_ids = predictions.boxes.cls.cpu().numpy().astype(int)
|
||||
scores = predictions.boxes.conf.cpu().numpy()
|
||||
|
||||
# 获取类别名称
|
||||
class_names = [self.model.names[i] for i in class_ids]
|
||||
|
||||
# 存储结果
|
||||
result.masks = masks
|
||||
result.boxes = boxes
|
||||
result.classes = class_ids
|
||||
result.scores = scores
|
||||
result.class_names = class_names
|
||||
|
||||
print(f"检测到 {len(masks)} 个对象,推理时间: {result.inference_time:.3f} 秒")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"推理失败: {e}")
|
||||
return result
|
||||
|
||||
def perform_inference(self, image_path: str, conf_threshold: float = 0.25,
|
||||
iou_threshold: float = 0.5) -> InferenceResult:
|
||||
"""
|
||||
@ -323,7 +390,7 @@ class YOLOSegmentationInference:
|
||||
|
||||
try:
|
||||
base_name = os.path.splitext(os.path.basename(result.image_path))[0]
|
||||
output_dir = output_dir + "-" + base_name
|
||||
# output_dir = output_dir + "/" + base_name
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
@ -334,14 +401,15 @@ class YOLOSegmentationInference:
|
||||
"result_dir": output_dir
|
||||
}
|
||||
# 保存结果图片
|
||||
result_path = os.path.join(output_dir, f"{base_name}_result.jpg")
|
||||
result_path = os.path.join(output_dir, f"{base_name}.jpg")
|
||||
result_image_bgr = cv2.cvtColor(result.result_image, cv2.COLOR_RGB2BGR)
|
||||
cv2.imwrite(result_path, result_image_bgr)
|
||||
print(f"结果图片已保存: {result_path}")
|
||||
|
||||
# 保存单独的掩码文件
|
||||
if save_mask and result.masks is not None and len(result.masks) > 0:
|
||||
mask_dir = os.path.join(output_dir, "masks")
|
||||
# mask_dir = os.path.join(output_dir, "masks")
|
||||
mask_dir = output_dir
|
||||
os.makedirs(mask_dir, exist_ok=True)
|
||||
|
||||
for i in range(len(result.masks)):
|
||||
@ -356,7 +424,8 @@ class YOLOSegmentationInference:
|
||||
|
||||
# 保存YOLO格式的标签文件
|
||||
if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0:
|
||||
label_dir = os.path.join(output_dir, "labels")
|
||||
# label_dir = os.path.join(output_dir, "labels")
|
||||
label_dir = output_dir
|
||||
os.makedirs(label_dir, exist_ok=True)
|
||||
|
||||
label_path = os.path.join(label_dir, f"{base_name}.txt")
|
||||
@ -464,6 +533,45 @@ class YOLOSegmentationInference:
|
||||
|
||||
return result
|
||||
|
||||
def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None,
|
||||
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
|
||||
save_mask: bool = False, save_label: bool = False, show: bool = True,
|
||||
result_save: [] = None) -> InferenceResult:
|
||||
"""
|
||||
处理单张图片
|
||||
|
||||
Args:
|
||||
image_path: 图片路径
|
||||
output_dir: 输出目录,如果为None则不保存
|
||||
conf_threshold: 置信度阈值
|
||||
iou_threshold: IOU阈值
|
||||
save_mask: 是否保存单独的掩码文件
|
||||
save_label: 是否保存YOLO格式的标签文件
|
||||
show: 是否显示结果
|
||||
|
||||
Returns:
|
||||
推理结果
|
||||
"""
|
||||
# 执行推理
|
||||
config = get_conf(image_path, user_name, pwd)
|
||||
scanner = get_scanner(image_path, user_name=user_name, pwd=pwd)
|
||||
image = scanner.read_img_file(image_path)
|
||||
result = self.perform_inference_share_dir(image, image_path, conf_threshold, iou_threshold)
|
||||
|
||||
# 绘制结果
|
||||
if result.masks is not None and len(result.masks) > 0:
|
||||
self.draw_results(result, conf_threshold)
|
||||
|
||||
# 保存结果
|
||||
if output_dir is not None:
|
||||
self.save_results(result, output_dir, save_mask, save_label, result_save)
|
||||
|
||||
# # 显示结果
|
||||
# if show:
|
||||
# self.show_results(result)
|
||||
|
||||
return result
|
||||
|
||||
def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None,
|
||||
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
|
||||
save_mask: bool = False, save_label: bool = False, show: bool = False,
|
||||
@ -528,6 +636,83 @@ class YOLOSegmentationInference:
|
||||
print(f"处理目录失败: {e}")
|
||||
return results
|
||||
|
||||
def process_image_directory_share_dir(self, input_dir, user_name, pwd, output_dir: Optional[str] = None,
|
||||
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
|
||||
save_mask: bool = False, save_label: bool = False, show: bool = False,
|
||||
result_save: [] = None) -> List[
|
||||
InferenceResult]:
|
||||
"""
|
||||
处理目录中的所有图片
|
||||
|
||||
Args:
|
||||
input_dir: 输入目录
|
||||
output_dir: 输出目录,如果为None则不保存
|
||||
conf_threshold: 置信度阈值
|
||||
iou_threshold: IOU阈值
|
||||
save_mask: 是否保存单独的掩码文件
|
||||
save_label: 是否保存YOLO格式的标签文件
|
||||
show: 是否显示结果
|
||||
|
||||
Returns:
|
||||
推理结果列表
|
||||
"""
|
||||
results = []
|
||||
|
||||
try:
|
||||
# 检查目录是否存在
|
||||
config = get_conf(zip_url=input_dir, user_name=user_name, pwd=pwd)
|
||||
scanner = get_scanner(zip_url=input_dir, user_name=user_name, pwd=pwd);
|
||||
if not scanner.directory_exists(input_dir):
|
||||
print(f"错误: {input_dir} 不是有效的目录")
|
||||
return results
|
||||
|
||||
# 获取所有图片文件
|
||||
image_files = scanner.get_smb_images(input_dir)
|
||||
|
||||
if not image_files:
|
||||
print(f"在目录 {input_dir} 中未找到图片文件")
|
||||
return results
|
||||
|
||||
print(f"找到 {len(image_files)} 个图片文件")
|
||||
|
||||
# 处理每张图片
|
||||
for image_path in image_files:
|
||||
result = self.process_single_image_share_dir(
|
||||
image_path=image_path,
|
||||
user_name=user_name,
|
||||
pwd=pwd,
|
||||
output_dir=output_dir,
|
||||
conf_threshold=conf_threshold,
|
||||
iou_threshold=iou_threshold,
|
||||
save_mask=save_mask,
|
||||
save_label=save_label,
|
||||
show=show,
|
||||
result_save=result_save
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# 推送识别数据到共享目录
|
||||
pile_dict = get_pile_dict(image_path, user_name, pwd)
|
||||
process_dir(pile_dict, output_dir)
|
||||
# 找到 图像类 文件夹
|
||||
found_paths = scanner.find_folders_by_name(
|
||||
share_path=config['share'],
|
||||
folder_name='图像类'
|
||||
)
|
||||
if len(found_paths) > 0 :
|
||||
tmpConf = get_conf(found_paths[0], user_name, pwd)
|
||||
scanner.upload_directory(output_dir, config['share'], remote_dir=tmpConf['dir']+"_识别")
|
||||
else :
|
||||
print(f"错误: 远程共享目录 找不到【图像类】目录")
|
||||
|
||||
return results
|
||||
|
||||
except PermissionError:
|
||||
print(f"权限错误: 无法访问目录 {input_dir}")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"处理目录失败: {e}")
|
||||
return results
|
||||
|
||||
def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0.25, save_json=False):
|
||||
zip_save_path = "dataset/zip_file"
|
||||
@ -596,8 +781,8 @@ def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0.
|
||||
input_path = zip_local_dir_save
|
||||
result_save = []
|
||||
|
||||
conf_threshold = 0.25,
|
||||
iou_threshold = 0.5,
|
||||
conf_threshold = 0.25
|
||||
iou_threshold = 0.5
|
||||
save_mask = True,
|
||||
save_label = True,
|
||||
show = True
|
||||
@ -673,3 +858,85 @@ def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0.
|
||||
os.remove(zip_dir_path)
|
||||
|
||||
return file_save_dir, "success"
|
||||
|
||||
|
||||
|
||||
|
||||
def predict_images_share_dir(pt_name, zip_url, user_name, pwd, output_dir="predictions", conf_threshold=0.25, save_json=False):
|
||||
# 本地测试模式 - 请根据实际情况修改以下路径
|
||||
# local_model_path = r"D:\project\verification\ultralytics-main\model\script\seg\pt\test.pttest.pt"
|
||||
local_model_path = r"../pt_save/road_crack.pt"
|
||||
local_output_dir = output_dir
|
||||
# zip_url = "meta_data/ai_train_platform/train.zip"
|
||||
|
||||
try:
|
||||
# 加载模型
|
||||
print(f"正在加载模型: {local_model_path}")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"模型已加载到 {device}")
|
||||
except Exception as e:
|
||||
print(f"处理目录失败: {e}")
|
||||
|
||||
inference = YOLOSegmentationInference(
|
||||
model_path=local_model_path,
|
||||
device=device
|
||||
)
|
||||
|
||||
# 加载模型
|
||||
if not inference.load_model():
|
||||
return
|
||||
|
||||
|
||||
|
||||
# zip_url = r"D:\project\verification\ultralytics-main\model\script\seg\test_seg_pic"
|
||||
|
||||
result_save = []
|
||||
conf_threshold = 0.25
|
||||
iou_threshold = 0.5
|
||||
save_mask = False
|
||||
save_label = True
|
||||
show = True
|
||||
|
||||
# 查找指定文件夹 图像类
|
||||
config = get_conf(zip_url, user_name, pwd)
|
||||
scanner = get_scanner(zip_url, user_name, pwd)
|
||||
found_paths = scanner.find_folders_by_name(
|
||||
share_path=config['share'],
|
||||
folder_name='图像类'
|
||||
)
|
||||
|
||||
target_path = ""
|
||||
report_data_path = ""
|
||||
if len(found_paths) > 0:
|
||||
# 处理目录
|
||||
report_data_path = found_paths[0]
|
||||
tmpConfig = get_conf(report_data_path, user_name, pwd)
|
||||
found_paths = scanner.find_folders_by_name(
|
||||
share_path=config['share'],
|
||||
folder_name='Images',
|
||||
start_dir=tmpConfig['dir']
|
||||
)
|
||||
if len(found_paths) > 0:
|
||||
target_path = found_paths[0]
|
||||
# inference.process_image_directory_share_dir(
|
||||
# input_dir=target_path,
|
||||
# user_name=user_name,
|
||||
# pwd=pwd,
|
||||
# output_dir=output_dir,
|
||||
# conf_threshold=conf_threshold,
|
||||
# iou_threshold=iou_threshold,
|
||||
# save_mask=save_mask,
|
||||
# save_label=save_label,
|
||||
# show=show,
|
||||
# result_save=result_save
|
||||
# )
|
||||
|
||||
# 创建并启动线程
|
||||
thread1 = threading.Thread(target=inference.process_image_directory_share_dir, args=(target_path,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save))
|
||||
|
||||
# 启动线程
|
||||
thread1.start()
|
||||
else:
|
||||
print(f"错误: 输入 {zip_url} 不是有效的文件或目录")
|
||||
|
||||
return f"{report_data_path}_识别", "success"
|
||||
859
util/smb.py
Normal file
859
util/smb.py
Normal file
@ -0,0 +1,859 @@
|
||||
import os
|
||||
from smbclient import (
|
||||
register_session,
|
||||
listdir,
|
||||
scandir,
|
||||
stat,
|
||||
makedirs, # 递归创建目录
|
||||
open_file
|
||||
)
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
class SMBScanner:
|
||||
def __init__(self, ip, username, password, domain=''):
|
||||
self.ip = ip
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.domain = domain
|
||||
|
||||
def connect(self):
|
||||
"""连接 SMB 共享"""
|
||||
try:
|
||||
register_session(
|
||||
self.ip,
|
||||
username=self.username,
|
||||
password=self.password
|
||||
)
|
||||
print(f"成功连接到 {self.ip}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"连接失败: {e}")
|
||||
return False
|
||||
|
||||
def directory_exists(self, full_path):
|
||||
"""
|
||||
检查目录是否存在
|
||||
|
||||
Args:
|
||||
full_path: 全路径
|
||||
|
||||
Returns:
|
||||
bool: 目录是否存在
|
||||
"""
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
try:
|
||||
# 尝试获取目录信息
|
||||
dir_stat = stat(full_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"未知错误: {e}")
|
||||
return False
|
||||
|
||||
def read_excel(self, smb_path, sheet_name=0):
|
||||
"""读取Excel文件"""
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
try:
|
||||
with open_file(smb_path, mode='rb') as smb_file:
|
||||
file_content = smb_file.read()
|
||||
|
||||
excel_data = io.BytesIO(file_content)
|
||||
df = pd.read_excel(excel_data, sheet_name=sheet_name)
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
print(f"读取Excel失败: {e}")
|
||||
return None
|
||||
|
||||
def process_all_rows(self, df):
|
||||
"""
|
||||
处理所有行数据
|
||||
"""
|
||||
if df is None or df.empty:
|
||||
print("没有数据可处理")
|
||||
return
|
||||
|
||||
print("开始处理每行数据:")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
for row_number, (index, row) in enumerate(df.iterrows(), 1):
|
||||
print(f"\n处理第 {row_number} 行:")
|
||||
print("-" * 40)
|
||||
|
||||
# 显示行数据
|
||||
for col_name in df.columns:
|
||||
value = row[col_name]
|
||||
print(f" {col_name}: {value}")
|
||||
|
||||
# 处理逻辑(根据实际需求修改)
|
||||
processed_row = {
|
||||
'row_number': row_number,
|
||||
'original_index': index,
|
||||
'data': row.to_dict(),
|
||||
'summary': f"处理了 {len(df.columns)} 个字段"
|
||||
}
|
||||
|
||||
results.append(processed_row)
|
||||
|
||||
# 进度显示
|
||||
if row_number % 10 == 0 or row_number == len(df):
|
||||
print(f"\n 进度: {row_number}/{len(df)} ({row_number/len(df)*100:.1f}%)")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"处理完成!共处理 {len(results)} 行数据")
|
||||
|
||||
return results
|
||||
|
||||
def get_smb_images(self, full_path):
|
||||
"""SMB 图片文件获取"""
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
|
||||
image_files = []
|
||||
|
||||
try:
|
||||
for entry in scandir(full_path):
|
||||
if entry.is_file():
|
||||
_, ext = os.path.splitext(entry.name)
|
||||
if ext.lower() in image_extensions:
|
||||
image_files.append(entry.path)
|
||||
elif entry.is_dir():
|
||||
imgs = self.get_smb_images(entry.path)
|
||||
image_files.extend(imgs)
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
return image_files
|
||||
|
||||
def build_full_path(self, share_path, file_path):
|
||||
"""构建完整的 SMB 路径"""
|
||||
# 清理路径中的多余斜杠
|
||||
share_path = share_path.strip('\\')
|
||||
file_path = file_path.lstrip('\\')
|
||||
return f"\\\\{self.ip}\\{share_path}\\{file_path}"
|
||||
|
||||
def read_txt_by_line(self, full_path):
|
||||
"""逐行读取,适合大文件"""
|
||||
if not self.connect():
|
||||
return None
|
||||
|
||||
print(f"读取 TXT 文件: {full_path}")
|
||||
|
||||
try:
|
||||
with open_file(full_path, mode='rb') as file_obj:
|
||||
content_bytes = file_obj.read()
|
||||
|
||||
# 使用 StringIO 逐行处理
|
||||
text_content = content_bytes.decode('utf-8', errors='ignore')
|
||||
string_io = io.StringIO(text_content)
|
||||
|
||||
lines = []
|
||||
line_number = 0
|
||||
|
||||
while True:
|
||||
line = string_io.readline()
|
||||
if not line: # 读到文件末尾
|
||||
break
|
||||
|
||||
line_number += 1
|
||||
line = line.strip()
|
||||
# print(f"行 {line_number}: {line}")
|
||||
lines.append(line)
|
||||
|
||||
print(f"总共读取 {line_number} 行")
|
||||
return lines
|
||||
|
||||
except Exception as e:
|
||||
print(f"读取文件时出错: {e}")
|
||||
return None
|
||||
|
||||
def read_img_file(self, full_path):
|
||||
"""读取文件并返回 OpenCV 图像"""
|
||||
if not self.connect():
|
||||
return None
|
||||
|
||||
print(f"读取文件: {full_path}")
|
||||
|
||||
file_obj = None
|
||||
try:
|
||||
# 以二进制模式读取文件
|
||||
file_obj = open_file(full_path, mode='rb')
|
||||
content = b""
|
||||
|
||||
# 分块读取文件内容
|
||||
while True:
|
||||
chunk = file_obj.read(8192) # 8KB 块
|
||||
if not chunk:
|
||||
break
|
||||
content += chunk
|
||||
|
||||
print(f"成功读取 {len(content)} 字节")
|
||||
|
||||
# 解码图像
|
||||
if len(content) == 0:
|
||||
print("文件为空")
|
||||
return None
|
||||
|
||||
image_array = np.frombuffer(content, np.uint8)
|
||||
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
print("图像解码失败 - 可能不是有效的图像文件")
|
||||
return None
|
||||
|
||||
print(f"图像解码成功: {image.shape}")
|
||||
return image
|
||||
|
||||
except Exception as e:
|
||||
print(f"读取文件失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
if file_obj:
|
||||
file_obj.close()
|
||||
|
||||
def writeFile(self, share_path, file_path, data, chunk_size=8192):
|
||||
"""写入文件到 SMB 共享"""
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
full_path = self.build_full_path(share_path, file_path)
|
||||
file_obj = None
|
||||
|
||||
try:
|
||||
# 确保目录存在
|
||||
dir_path = os.path.dirname(full_path)
|
||||
try:
|
||||
makedirs(dir_path, exist_ok=True)
|
||||
except:
|
||||
pass # 目录可能已存在
|
||||
|
||||
file_obj = open_file(full_path, mode='wb')
|
||||
|
||||
if isinstance(data, bytes):
|
||||
total_size = len(data)
|
||||
written = 0
|
||||
|
||||
for i in range(0, total_size, chunk_size):
|
||||
chunk = data[i:i + chunk_size]
|
||||
file_obj.write(chunk)
|
||||
written += len(chunk)
|
||||
print(f"写入进度: {written}/{total_size} 字节 ({written/total_size*100:.1f}%)")
|
||||
|
||||
elif hasattr(data, '__iter__'):
|
||||
total_written = 0
|
||||
for chunk in data:
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode('utf-8')
|
||||
file_obj.write(chunk)
|
||||
total_written += len(chunk)
|
||||
print(f"已写入: {total_written} 字节")
|
||||
else:
|
||||
file_obj.write(bytes(data))
|
||||
|
||||
print(f"文件写入完成: {full_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"写入文件失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
if file_obj:
|
||||
file_obj.close()
|
||||
|
||||
def writeImageToFile(self, share_path, file_path, image, image_format='.jpg', quality=95):
|
||||
"""将 OpenCV 图像写入 SMB 文件"""
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
full_path = f"{file_path}{image_format}"
|
||||
file_obj = None
|
||||
|
||||
try:
|
||||
if image_format.lower() == '.jpg':
|
||||
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
||||
success, encoded_image = cv2.imencode(image_format, image, encode_params)
|
||||
else:
|
||||
success, encoded_image = cv2.imencode(image_format, image)
|
||||
|
||||
if not success:
|
||||
print("图像编码失败")
|
||||
return False
|
||||
|
||||
image_data = encoded_image.tobytes()
|
||||
return self.writeFile(share_path, f"{file_path}{image_format}", image_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"写入图像失败: {e}")
|
||||
return False
|
||||
|
||||
def _ensure_remote_directory(self, share_name, remote_dir):
|
||||
"""确保远程目录存在"""
|
||||
if not remote_dir:
|
||||
return
|
||||
|
||||
try:
|
||||
# 构建完整远程路径
|
||||
full_remote_path = self.build_full_path(share_name, remote_dir)
|
||||
|
||||
# 使用 makedirs 递归创建目录(如果不存在)
|
||||
makedirs(full_remote_path, exist_ok=True)
|
||||
print(f"确保远程目录存在: {remote_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建远程目录失败: {e}")
|
||||
raise
|
||||
|
||||
def upload_directory(self, local_dir, share_name, remote_dir="", overwrite=True):
|
||||
"""
|
||||
将本地目录推送到远程共享目录
|
||||
"""
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
print(f"开始上传目录: {local_dir} -> {share_name}/{remote_dir}")
|
||||
|
||||
if not os.path.exists(local_dir):
|
||||
print(f"本地目录不存在: {local_dir}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 确保远程目录存在
|
||||
self._ensure_remote_directory(share_name, remote_dir)
|
||||
|
||||
# 递归上传目录内容
|
||||
success = self._upload_directory_recursive(local_dir, share_name, remote_dir, overwrite)
|
||||
|
||||
if success:
|
||||
print("目录上传完成")
|
||||
else:
|
||||
print("目录上传过程中出现错误")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
print(f"上传目录失败: {e}")
|
||||
return False
|
||||
|
||||
def _upload_directory_recursive(self, local_path, share_name, remote_path, overwrite):
|
||||
"""递归上传目录内容"""
|
||||
try:
|
||||
success = True
|
||||
|
||||
for item_name in os.listdir(local_path):
|
||||
local_item_path = os.path.join(local_path, item_name)
|
||||
remote_item_path = f"{remote_path}/{item_name}" if remote_path else item_name
|
||||
|
||||
if os.path.isdir(local_item_path):
|
||||
# 处理子目录
|
||||
print(f"上传子目录: {item_name}")
|
||||
|
||||
# 确保远程子目录存在
|
||||
self._ensure_remote_directory(share_name, remote_item_path)
|
||||
|
||||
# 递归上传子目录
|
||||
sub_success = self._upload_directory_recursive(local_item_path, share_name, remote_item_path, overwrite)
|
||||
if not sub_success:
|
||||
success = False
|
||||
|
||||
else:
|
||||
# 上传文件
|
||||
file_success = self._upload_single_file(local_item_path, share_name, remote_item_path, overwrite)
|
||||
if not file_success:
|
||||
success = False
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
print(f"上传目录内容失败 {local_path}: {e}")
|
||||
return False
|
||||
|
||||
def _upload_single_file(self, local_file_path, share_name, remote_file_path, overwrite):
|
||||
"""上传单个文件"""
|
||||
file_obj = None
|
||||
try:
|
||||
# 构建远程完整路径
|
||||
full_remote_path = self.build_full_path(share_name, remote_file_path)
|
||||
|
||||
# 检查文件是否已存在
|
||||
if not overwrite:
|
||||
try:
|
||||
stat(full_remote_path)
|
||||
print(f"文件已存在,跳过: {remote_file_path}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
# 文件不存在,继续上传
|
||||
pass
|
||||
|
||||
# 上传文件
|
||||
print(f"上传文件: {os.path.basename(local_file_path)}")
|
||||
|
||||
# 读取本地文件
|
||||
with open(local_file_path, 'rb') as local_file:
|
||||
local_content = local_file.read()
|
||||
|
||||
# 写入远程文件
|
||||
with open_file(full_remote_path, mode='wb') as remote_file:
|
||||
remote_file.write(local_content)
|
||||
|
||||
file_size = len(local_content)
|
||||
print(f"文件上传成功: {remote_file_path} ({file_size} 字节)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"上传文件失败 {local_file_path}: {e}")
|
||||
return False
|
||||
|
||||
def upload_file(self, local_file_path, share_name, remote_file_path, overwrite=True):
|
||||
"""
|
||||
上传单个文件到远程共享目录
|
||||
"""
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
print(f"上传文件: {local_file_path} -> {share_name}/{remote_file_path}")
|
||||
|
||||
file_obj = None
|
||||
try:
|
||||
# 构建远程完整路径
|
||||
full_remote_path = self.build_full_path(share_name, remote_file_path)
|
||||
|
||||
# 检查文件是否已存在
|
||||
if not overwrite:
|
||||
try:
|
||||
stat(full_remote_path)
|
||||
print(f"文件已存在,跳过: {remote_file_path}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
# 文件不存在,继续上传
|
||||
pass
|
||||
|
||||
# 以二进制模式读取本地文件
|
||||
with open(local_file_path, 'rb') as local_file:
|
||||
content = b""
|
||||
|
||||
# 分块读取文件内容
|
||||
while True:
|
||||
chunk = local_file.read(8192) # 8KB 块
|
||||
if not chunk:
|
||||
break
|
||||
content += chunk
|
||||
|
||||
print(f"成功读取 {len(content)} 字节")
|
||||
|
||||
if len(content) == 0:
|
||||
print("文件为空")
|
||||
return False
|
||||
|
||||
# 写入远程文件
|
||||
with open_file(full_remote_path, mode='wb') as remote_file:
|
||||
remote_file.write(content)
|
||||
|
||||
print(f"文件上传成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"上传文件失败: {e}")
|
||||
return False
|
||||
|
||||
def find_folders_by_name(self, share_path, folder_name, start_dir="", max_depth=10):
|
||||
"""专门查找文件夹"""
|
||||
return self.find_items_by_name(
|
||||
share_path=share_path,
|
||||
target_name=folder_name,
|
||||
item_type="folder",
|
||||
start_dir=start_dir,
|
||||
max_depth=max_depth
|
||||
)
|
||||
|
||||
def find_files_by_name(self, share_path, file_name, start_dir="", max_depth=10):
|
||||
"""专门查找文件"""
|
||||
return self.find_items_by_name(
|
||||
share_path=share_path,
|
||||
target_name=file_name,
|
||||
item_type="file",
|
||||
start_dir=start_dir,
|
||||
max_depth=max_depth
|
||||
)
|
||||
|
||||
def find_items_by_name(self, share_path, target_name, item_type="both", start_dir="", max_depth=10):
|
||||
"""
|
||||
递归查找指定名称的文件夹和/或文件
|
||||
|
||||
Args:
|
||||
share_path: 共享名称
|
||||
target_name: 目标名称(支持通配符 * 和 ?)
|
||||
item_type: 查找类型 - "folder", "file", "both"
|
||||
start_dir: 起始目录
|
||||
max_depth: 最大搜索深度
|
||||
|
||||
Returns:
|
||||
list: 找到的完整路径列表
|
||||
"""
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
found_paths = []
|
||||
start_path = self.build_full_path(share_path, start_dir)
|
||||
|
||||
try:
|
||||
self._search_recursive(
|
||||
share_path=share_path,
|
||||
current_path=start_path,
|
||||
target_name=target_name,
|
||||
item_type=item_type,
|
||||
found_paths=found_paths,
|
||||
current_depth=0,
|
||||
max_depth=max_depth
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"搜索过程中出错: {e}")
|
||||
|
||||
return found_paths
|
||||
|
||||
def _search_recursive(self, share_path, current_path, target_name, item_type, found_paths, current_depth, max_depth):
|
||||
"""递归搜索文件夹和文件"""
|
||||
if current_depth > max_depth:
|
||||
return
|
||||
|
||||
try:
|
||||
for entry in scandir(current_path):
|
||||
try:
|
||||
# 检查文件夹
|
||||
if entry.is_dir():
|
||||
if self._is_match(entry.name, target_name) and item_type in ["both", "folder"]:
|
||||
found_paths.append(entry.path)
|
||||
print(f"找到目标文件夹: {entry.path}")
|
||||
|
||||
# 递归搜索子目录
|
||||
self._search_recursive(
|
||||
share_path=share_path,
|
||||
current_path=entry.path,
|
||||
target_name=target_name,
|
||||
item_type=item_type,
|
||||
found_paths=found_paths,
|
||||
current_depth=current_depth + 1,
|
||||
max_depth=max_depth
|
||||
)
|
||||
|
||||
# 检查文件
|
||||
elif entry.is_file():
|
||||
if self._is_match(entry.name, target_name) and item_type in ["both", "file"]:
|
||||
found_paths.append(entry.path)
|
||||
print(f"找到目标文件: {entry.path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理条目 {entry.path} 时出错: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索目录 {current_path} 时出错: {e}")
|
||||
|
||||
def _is_match(self, name, pattern):
|
||||
"""
|
||||
检查名称是否匹配模式(支持简单通配符)
|
||||
|
||||
Args:
|
||||
name: 实际名称
|
||||
pattern: 匹配模式(支持 * 和 ?)
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配
|
||||
"""
|
||||
# 如果没有通配符,直接比较
|
||||
if '*' not in pattern and '?' not in pattern:
|
||||
return name.lower() == pattern.lower()
|
||||
|
||||
# 通配符匹配
|
||||
import fnmatch
|
||||
return fnmatch.fnmatch(name.lower(), pattern.lower())
|
||||
|
||||
def list_directory(self, share_path, dir, recursive=False, max_depth=3):
|
||||
"""列出目录内容"""
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
full_path = f"\\\\{self.ip}\\{share_path}\\{dir}"
|
||||
print(f"开始遍历: {full_path}")
|
||||
result = []
|
||||
self._walk_directory(full_path, recursive, max_depth, 0, result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"遍历失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def _walk_directory(self, path, recursive, max_depth, current_depth, result):
|
||||
"""递归遍历目录"""
|
||||
if current_depth > max_depth:
|
||||
return
|
||||
|
||||
try:
|
||||
for entry in scandir(path):
|
||||
try:
|
||||
file_stat = stat(entry.path)
|
||||
indent = " " * current_depth
|
||||
# 创建条目信息字典
|
||||
item = {
|
||||
'name': entry.name,
|
||||
'path': entry.path,
|
||||
'depth': current_depth,
|
||||
'indent': indent,
|
||||
'is_dir': entry.is_dir(),
|
||||
'size': file_stat.st_size if not entry.is_dir() else 0,
|
||||
'modified_time': datetime.fromtimestamp(file_stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
if entry.is_dir():
|
||||
# print(f"{indent}文件夹:{entry.name}/")
|
||||
result.append(item)
|
||||
if recursive and current_depth < max_depth:
|
||||
sub_items = self._walk_directory(
|
||||
entry.path,
|
||||
recursive,
|
||||
max_depth,
|
||||
current_depth + 1
|
||||
)
|
||||
result.extend(sub_items)
|
||||
else:
|
||||
file_size = self._format_size(file_stat.st_size)
|
||||
mod_time = datetime.fromtimestamp(
|
||||
file_stat.st_mtime
|
||||
).strftime('%Y-%m-%d %H:%M:%S')
|
||||
# print(f"{indent}文件:{entry.name} [{file_size}] [{mod_time}]")
|
||||
item['formatted_size'] = file_size
|
||||
result.append(item)
|
||||
|
||||
except Exception as e:
|
||||
print(f"{indent} 无法访问: {entry.name} - {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"无法读取目录 {path}: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def _format_size(self, size_bytes):
|
||||
"""格式化文件大小"""
|
||||
if size_bytes == 0:
|
||||
return "0 B"
|
||||
|
||||
size_names = ["B", "KB", "MB", "GB", "TB"]
|
||||
i = 0
|
||||
while size_bytes >= 1024 and i < len(size_names) - 1:
|
||||
size_bytes /= 1024.0
|
||||
i += 1
|
||||
|
||||
return f"{size_bytes:.1f} {size_names[i]}"
|
||||
|
||||
def get_file_info(self, share_path, file_path):
|
||||
"""获取文件详细信息"""
|
||||
if not self.connect():
|
||||
return None
|
||||
|
||||
try:
|
||||
full_path = f"\\\\{self.ip}\\{share_path}\\{file_path}"
|
||||
file_stat = stat(full_path)
|
||||
|
||||
return {
|
||||
'name': os.path.basename(file_path),
|
||||
'path': full_path,
|
||||
'size': file_stat.st_size,
|
||||
'size_formatted': self._format_size(file_stat.st_size),
|
||||
'create_time': datetime.fromtimestamp(file_stat.st_ctime),
|
||||
'modify_time': datetime.fromtimestamp(file_stat.st_mtime),
|
||||
'access_time': datetime.fromtimestamp(file_stat.st_atime),
|
||||
'is_dir': False # 需要额外判断
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"获取文件信息失败: {e}")
|
||||
return None
|
||||
|
||||
def display_image(self, image, window_name="Image"):
|
||||
"""
|
||||
显示图像
|
||||
|
||||
Args:
|
||||
image: OpenCV图像
|
||||
window_name: 窗口名称
|
||||
"""
|
||||
# 创建窗口
|
||||
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
|
||||
|
||||
# 调整窗口大小适应屏幕
|
||||
screen_width = 1920 # 可根据实际屏幕调整
|
||||
screen_height = 1080
|
||||
|
||||
img_height, img_width = image.shape[:2]
|
||||
|
||||
# 计算缩放比例
|
||||
scale = min(screen_width / img_width, screen_height / img_height, 1.0)
|
||||
|
||||
if scale < 1.0:
|
||||
new_width = int(img_width * scale)
|
||||
new_height = int(img_height * scale)
|
||||
image = cv2.resize(image, (new_width, new_height))
|
||||
print(f"图像已缩放: {img_width}x{img_height} -> {new_width}x{new_height}")
|
||||
|
||||
# 显示图像
|
||||
cv2.imshow(window_name, image)
|
||||
print("图像显示中... 按任意键关闭窗口")
|
||||
|
||||
# 等待按键
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
print("窗口已关闭")
|
||||
|
||||
# 从传入的路径中提取ip,共享目录,目标访问目录
|
||||
def get_conf(zip_url, user_name, pwd) :
|
||||
zip_url = zip_url.replace('\\\\', '/')
|
||||
zip_url = zip_url.replace('\\', '/')
|
||||
if zip_url.startswith("/"):
|
||||
zip_url = zip_url.replace('/', '', 1)
|
||||
|
||||
parts = zip_url.split('/')
|
||||
if len(parts) < 2 :
|
||||
print(f"传入的共享目录格式错误: {zip_url}")
|
||||
return "", "fail"
|
||||
|
||||
dir = ''
|
||||
if len(parts) > 2:
|
||||
new_parts = parts[2:]
|
||||
dir = '/'.join(new_parts)
|
||||
|
||||
# 配置信息
|
||||
config = {
|
||||
'ip': parts[0],
|
||||
'username': user_name,
|
||||
'password': pwd,
|
||||
'domain': '', # 工作组留空
|
||||
'share': parts[1],
|
||||
'dir': dir
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_scanner(zip_url, user_name, pwd) :
|
||||
config = get_conf(zip_url, user_name, pwd)
|
||||
|
||||
# 创建扫描器
|
||||
scanner = SMBScanner(
|
||||
ip=config['ip'],
|
||||
username=config['username'],
|
||||
password=config['password'],
|
||||
domain=config['domain']
|
||||
)
|
||||
return scanner
|
||||
|
||||
# filename -> 桩号
|
||||
def get_pile_dict(dir,user_name,pwd) :
|
||||
config = get_conf(dir, user_name, pwd)
|
||||
scanner = get_scanner(dir, user_name=user_name, pwd=pwd)
|
||||
found_paths = scanner.find_files_by_name(
|
||||
share_path=config['share'],
|
||||
file_name='fileindex.txt'
|
||||
)
|
||||
print(f"\n找到 {len(found_paths)} 个 'fileindex.txt' 文件:")
|
||||
for i, path in enumerate(found_paths, 1):
|
||||
print(f"{i}. {path}")
|
||||
|
||||
lines = scanner.read_txt_by_line(full_path=found_paths[0])
|
||||
|
||||
pile_dict = {}
|
||||
for i, line in enumerate(lines, 1):
|
||||
parts = line.strip().split("->")
|
||||
if len(parts)>=4:
|
||||
pile_dict[parts[3]]=parts[1] # filename -> 桩号
|
||||
|
||||
return pile_dict
|
||||
|
||||
def main():
|
||||
# 配置信息
|
||||
config = {
|
||||
'ip': '192.168.110.114',
|
||||
'username': 'administrator',
|
||||
'password': 'abc@1234',
|
||||
'domain': '', # 工作组留空
|
||||
'share': 'share_File',
|
||||
'dir': '西南计算机'
|
||||
}
|
||||
|
||||
# 创建扫描器
|
||||
scanner = SMBScanner(
|
||||
ip=config['ip'],
|
||||
username=config['username'],
|
||||
password=config['password'],
|
||||
domain=config['domain']
|
||||
)
|
||||
|
||||
# 遍历共享目录
|
||||
# scanner.list_directory(
|
||||
# share_path=config['share'],
|
||||
# dir=config['dir'],
|
||||
# recursive=True, # 递归遍历
|
||||
# max_depth=9 # 最大深度
|
||||
# )
|
||||
|
||||
# 读取文件
|
||||
# full_path = scanner.build_full_path(
|
||||
# share_path=config['share'],
|
||||
# file_path= f"{config['dir']}/AA县/报送数据/图像类/CD45500155A/Images/20250508131651/01/20250508-131712-644.jpg"
|
||||
# )
|
||||
# image = scanner.read_img_file(full_path=full_path)
|
||||
|
||||
# scanner.display_image(image)
|
||||
|
||||
# # 写入文件
|
||||
# scanner.writeImageToFile(
|
||||
# share_path=config['share'],
|
||||
# file_path= f"{config['dir']}/AA县/报送数据/图像类_识别/CD45500155A/Images/20250508131651/01/20250508-131712-644.jpg",
|
||||
# image=image
|
||||
# )
|
||||
|
||||
# # 查找指定文件夹 报送数据
|
||||
# found_paths = scanner.find_folders_by_name(
|
||||
# share_path=config['share'],
|
||||
# folder_name='报送数据'
|
||||
# )
|
||||
# print(f"\n找到 {len(found_paths)} 个 '报送数据' 文件夹:")
|
||||
# for i, path in enumerate(found_paths, 1):
|
||||
# print(f"{i}. {path}")
|
||||
|
||||
|
||||
# # 查找指定目录中的所有图片
|
||||
# full_path = scanner.build_full_path(share_path=config['share'], file_path='西南计算机\\AA县\\报送数据')
|
||||
# imgPaths = scanner.get_smb_images(full_path)
|
||||
# for i, path in enumerate(imgPaths, 1):
|
||||
# print(f"{i}. {path}")
|
||||
|
||||
# # 读取excel
|
||||
# full_path = scanner.build_full_path(share_path=config['share'], file_path='西南计算机\\AA县\\24年年报.xlsx')
|
||||
# df = scanner.read_excel(full_path)
|
||||
# scanner.process_all_rows(df)
|
||||
|
||||
|
||||
# 读取txt
|
||||
# found_paths = scanner.find_files_by_name(
|
||||
# share_path=config['share'],
|
||||
# file_name='fileindex.txt'
|
||||
# )
|
||||
# print(f"\n找到 {len(found_paths)} 个 'fileindex.txt' 文件:")
|
||||
# for i, path in enumerate(found_paths, 1):
|
||||
# print(f"{i}. {path}")
|
||||
|
||||
# lines = scanner.read_txt_by_line(full_path=found_paths[0])
|
||||
# for i, line in enumerate(lines, 1):
|
||||
# print(f"{i}. {line}")
|
||||
|
||||
output_dir = "D:/devForBdzlWork/ai-train_platform/predictions"
|
||||
scanner.upload_directory(output_dir, config['share'], remote_dir="西南计算机/AA县/报送数据_识别")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
243
util/yolo2pix_new.py
Normal file
243
util/yolo2pix_new.py
Normal file
@ -0,0 +1,243 @@
|
||||
import os
|
||||
import zipfile
|
||||
import shutil
|
||||
import cv2
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
import smb
|
||||
|
||||
# ---------------- 常量 ----------------
|
||||
CELL_AREA = 0.0036 # 每格面积 (平方米)
|
||||
GRID_WIDTH = 108 # 网格像素宽
|
||||
GRID_HEIGHT = 102 # 网格像素高
|
||||
COVER_RATIO = 0.01 # mask 覆盖比例阈值
|
||||
|
||||
# ---------------- 路面类别映射 ----------------
|
||||
CLASS_MAP_ASPHALT = {
|
||||
"龟裂":0,"块状裂缝":1,"纵向裂缝":2,"横向裂缝":3,"沉陷":4,"车辙":5,"波浪拥包":6,"坑槽":7,"松散":8,"泛油":9,"修补":10
|
||||
}
|
||||
CLASS_MAP_CEMENT = {
|
||||
"破碎板":0,"裂缝":1,"板角断裂":2,"错台":3,"拱起":4,"边角剥落":5,"接缝料损坏":6,"坑洞":7,"唧泥":8,"露骨":9,"修补":10
|
||||
}
|
||||
CLASS_MAP_GRAVEL = {
|
||||
"坑槽":0,"沉陷":1,"车辙":2,"波浪搓板":3
|
||||
}
|
||||
|
||||
# ---------------- 工具函数 ----------------
|
||||
def num_to_coord(num, cols, cell_w, cell_h):
|
||||
n = num - 1
|
||||
r, c = divmod(n, cols)
|
||||
x1, y1 = c * cell_w, r * cell_h
|
||||
x2, y2 = x1 + cell_w, y1 + cell_h
|
||||
return x1, y1, x2, y2
|
||||
|
||||
def draw_grid_on_image(image_path, grid_cells, cell_size=(GRID_WIDTH, GRID_HEIGHT), save_path=None):
|
||||
image = cv2.imread(image_path)
|
||||
if image is None: return
|
||||
h, w = image.shape[:2]
|
||||
cell_w, cell_h = cell_size
|
||||
cols = w // cell_w
|
||||
overlay = image.copy()
|
||||
for cname, nums in grid_cells.items():
|
||||
color = (np.random.randint(64,255),np.random.randint(64,255),np.random.randint(64,255))
|
||||
for num in nums:
|
||||
x1,y1,x2,y2 = num_to_coord(num, cols, cell_w, cell_h)
|
||||
cv2.rectangle(overlay,(x1,y1),(x2,y2),color,-1)
|
||||
cv2.addWeighted(overlay,0.4,image,0.6,0,image)
|
||||
for i in range(0, w, cell_w):
|
||||
cv2.line(image,(i,0),(i,h),(100,100,100),1)
|
||||
for j in range(0, h, cell_h):
|
||||
cv2.line(image,(0,j),(w,j),(100,100,100),1)
|
||||
if save_path: cv2.imwrite(save_path,image)
|
||||
return image
|
||||
|
||||
def detect_road_type_from_content(label_file):
|
||||
"""根据标签内容判断路面类型"""
|
||||
try:
|
||||
with open(label_file,'r',encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except:
|
||||
return "gravel"
|
||||
for kw in CLASS_MAP_ASPHALT.keys():
|
||||
if kw in content: return "asphalt"
|
||||
for kw in CLASS_MAP_CEMENT.keys():
|
||||
if kw in content: return "cement"
|
||||
for kw in CLASS_MAP_GRAVEL.keys():
|
||||
if kw in content: return "gravel"
|
||||
return "gravel"
|
||||
|
||||
def yoloseg_to_grid(image_path,label_file,cover_ratio=COVER_RATIO):
|
||||
"""将YOLO-Seg标签转换成格子编号和类别"""
|
||||
road_type = detect_road_type_from_content(label_file)
|
||||
if road_type=="asphalt": class_map = CLASS_MAP_ASPHALT
|
||||
elif road_type=="cement": class_map = CLASS_MAP_CEMENT
|
||||
else: class_map = CLASS_MAP_GRAVEL
|
||||
class_names = list(class_map.keys())
|
||||
|
||||
img = cv2.imread(image_path)
|
||||
if img is None: return "", {}
|
||||
h, w = img.shape[:2]
|
||||
cols = max(1, w//GRID_WIDTH)
|
||||
rows = max(1, h//GRID_HEIGHT)
|
||||
|
||||
result_lines = []
|
||||
all_class_cells = {}
|
||||
with open(label_file,'r',encoding='utf-8') as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts)<5: continue
|
||||
cls_id = int(parts[0])
|
||||
coords = [float(x) for x in parts[1:]]
|
||||
if len(coords)%2!=0: coords=coords[:-1]
|
||||
if len(coords)<6: continue
|
||||
poly = np.array(coords,dtype=np.float32).reshape(-1,2)
|
||||
poly[:,0]*=w
|
||||
poly[:,1]*=h
|
||||
mask = np.zeros((h,w),dtype=np.uint8)
|
||||
cv2.fillPoly(mask,[poly.astype(np.int32)],255)
|
||||
covered_cells=[]
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
x1,y1 = c*GRID_WIDTH, r*GRID_HEIGHT
|
||||
x2,y2 = min(w,x1+GRID_WIDTH), min(h,y1+GRID_HEIGHT)
|
||||
region = mask[y1:y2, x1:x2]
|
||||
if np.count_nonzero(region)/region.size>cover_ratio:
|
||||
covered_cells.append(r*cols+c+1)
|
||||
if not covered_cells: continue
|
||||
cname = class_names[cls_id] if cls_id<len(class_names) else str(cls_id)
|
||||
ids_str = '-'.join(map(str,sorted(covered_cells)))+'-'
|
||||
result_lines.append(f"{cname} {ids_str}")
|
||||
if cname not in all_class_cells: all_class_cells[cname]=set()
|
||||
all_class_cells[cname].update(covered_cells)
|
||||
return '\n'.join(result_lines), all_class_cells, road_type
|
||||
|
||||
def generate_header(road_type):
|
||||
if road_type=="asphalt": return "起点桩号(km),识别宽度(m),破损率DR(%),龟裂,块状裂缝,纵向裂缝,横向裂缝,沉陷,车辙,波浪拥包,坑槽,松散,泛油,修补"
|
||||
if road_type=="cement": return "起点桩号(km),识别宽度(m),破损率DR(%),破碎板,裂缝,板角断裂,错台,拱起,边角剥落,接缝料损坏,坑洞,唧泥,露骨,修补"
|
||||
if road_type=="gravel": return "起点桩号(km),识别宽度(m),破损率DR(%),坑槽,沉陷,车辙,波浪搓板"
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------- 主函数-共享目录 ----------------
|
||||
def process_dir(pile_dict,dir="output",cell_area=CELL_AREA,grid_width=GRID_WIDTH,grid_height=GRID_HEIGHT):
|
||||
os.makedirs(dir,exist_ok=True)
|
||||
# 解压
|
||||
# 读取桩号映射
|
||||
# 遍历图片
|
||||
summary_data = []
|
||||
for root,_,files in os.walk(dir):
|
||||
for file in files:
|
||||
if file.lower().endswith((".jpg",".png",".jpeg",".bmp")) :
|
||||
image_path = os.path.join(root,file)
|
||||
label_file = os.path.splitext(image_path)[0]+".txt"
|
||||
if not os.path.exists(label_file):
|
||||
print(f"⚠️ 找不到标签: {label_file}")
|
||||
continue
|
||||
out_txt, class_cells, road_type = yoloseg_to_grid(image_path,label_file)
|
||||
# 写每张图独立 _grid.txt
|
||||
grid_txt_path = os.path.splitext(image_path)[0]+"_grid.txt"
|
||||
with open(grid_txt_path,'w',encoding='utf-8') as f:
|
||||
f.write(out_txt)
|
||||
# 生成网格可视化
|
||||
draw_grid_on_image(image_path,class_cells,save_path=os.path.splitext(image_path)[0]+"_grid.jpg")
|
||||
# 统计各类面积
|
||||
counts = {k:len(v)*cell_area for k,v in class_cells.items()}
|
||||
total_area = sum(counts.values())
|
||||
# 桩号
|
||||
pile_no = pile_dict.get(file,"未知")
|
||||
# 破损率 DR (%) = total_area / 总面积
|
||||
DR = total_area/ (total_area if total_area>0 else 1) *100 # 简化为100%或者0
|
||||
summary_data.append((pile_no, DR, counts, road_type))
|
||||
|
||||
# 写桩号问题列表.txt
|
||||
if summary_data:
|
||||
road_type = summary_data[0][3]
|
||||
out_file = os.path.join(dir,"桩号问题列表.txt")
|
||||
header = generate_header(road_type)
|
||||
with open(out_file,'w',encoding='utf-8') as f:
|
||||
f.write(header+'\n')
|
||||
for pile_no,DR,counts,rt in summary_data:
|
||||
row = [pile_no,"3.6",f"{DR:.2f}"]
|
||||
if road_type=="asphalt":
|
||||
keys = list(CLASS_MAP_ASPHALT.keys())
|
||||
elif road_type=="cement":
|
||||
keys = list(CLASS_MAP_CEMENT.keys())
|
||||
else:
|
||||
keys = list(CLASS_MAP_GRAVEL.keys())
|
||||
for k in keys:
|
||||
row.append(f"{counts.get(k,0):.2f}")
|
||||
f.write(','.join(row)+'\n')
|
||||
print(f"✅ 输出完成: {out_file}")
|
||||
|
||||
# ---------------- 主函数 ----------------
|
||||
def process_zip(zip_path,pile_map_file,output_dir="output",cell_area=CELL_AREA,grid_width=GRID_WIDTH,grid_height=GRID_HEIGHT):
|
||||
if not os.path.exists(zip_path):
|
||||
raise FileNotFoundError(f"{zip_path} 不存在")
|
||||
os.makedirs(output_dir,exist_ok=True)
|
||||
# 解压
|
||||
with zipfile.ZipFile(zip_path,'r') as zip_ref:
|
||||
zip_ref.extractall(output_dir)
|
||||
|
||||
# 读取桩号映射
|
||||
pile_dict = {}
|
||||
with open(pile_map_file,'r',encoding='utf-8') as f:
|
||||
for line in f:
|
||||
parts = line.strip().split("->")
|
||||
if len(parts)>=4:
|
||||
pile_dict[parts[3]]=parts[1] # filename -> 桩号
|
||||
|
||||
# 遍历图片
|
||||
summary_data = []
|
||||
for root,_,files in os.walk(output_dir):
|
||||
for file in files:
|
||||
if file.lower().endswith((".jpg",".png",".jpeg",".bmp")) :
|
||||
image_path = os.path.join(root,file)
|
||||
label_file = os.path.splitext(image_path)[0]+".txt"
|
||||
if not os.path.exists(label_file):
|
||||
print(f"⚠️ 找不到标签: {label_file}")
|
||||
continue
|
||||
out_txt, class_cells, road_type = yoloseg_to_grid(image_path,label_file)
|
||||
# 写每张图独立 _grid.txt
|
||||
grid_txt_path = os.path.splitext(image_path)[0]+"_grid.txt"
|
||||
with open(grid_txt_path,'w',encoding='utf-8') as f:
|
||||
f.write(out_txt)
|
||||
# 生成网格可视化
|
||||
draw_grid_on_image(image_path,class_cells,save_path=os.path.splitext(image_path)[0]+"_grid.jpg")
|
||||
# 统计各类面积
|
||||
counts = {k:len(v)*cell_area for k,v in class_cells.items()}
|
||||
total_area = sum(counts.values())
|
||||
# 桩号
|
||||
pile_no = pile_dict.get(file,"未知")
|
||||
# 破损率 DR (%) = total_area / 总面积
|
||||
DR = total_area/ (total_area if total_area>0 else 1) *100 # 简化为100%或者0
|
||||
summary_data.append((pile_no, DR, counts, road_type))
|
||||
|
||||
# 写桩号问题列表.txt
|
||||
if summary_data:
|
||||
road_type = summary_data[0][3]
|
||||
out_file = os.path.join(output_dir,"桩号问题列表.txt")
|
||||
header = generate_header(road_type)
|
||||
with open(out_file,'w',encoding='utf-8') as f:
|
||||
f.write(header+'\n')
|
||||
for pile_no,DR,counts,rt in summary_data:
|
||||
row = [pile_no,"3.6",f"{DR:.2f}"]
|
||||
if road_type=="asphalt":
|
||||
keys = list(CLASS_MAP_ASPHALT.keys())
|
||||
elif road_type=="cement":
|
||||
keys = list(CLASS_MAP_CEMENT.keys())
|
||||
else:
|
||||
keys = list(CLASS_MAP_GRAVEL.keys())
|
||||
for k in keys:
|
||||
row.append(f"{counts.get(k,0):.2f}")
|
||||
f.write(','.join(row)+'\n')
|
||||
print(f"✅ 输出完成: {out_file}")
|
||||
|
||||
# ---------------- 示例调用 ----------------
|
||||
if __name__=="__main__":
|
||||
# zip_path = "D:/devForBdzlWork/ai-train_platform/predict/inferenceResult.zip" # 输入 ZIP 文件
|
||||
# pile_map_file = "D:/devForBdzlWork/ai-train_platform/predict/pile_map.txt" # 图片名 -> 桩号
|
||||
# process_zip(zip_path=zip_path,pile_map_file=pile_map_file,output_dir="output")
|
||||
|
||||
output_dir = "D:/devForBdzlWork/ai-train_platform/predictions/1"
|
||||
pile_dict = smb.get_pile_dict("192.168.110.114/share_File/西南计算机", "administrator", "abc@1234")
|
||||
process_dir(pile_dict, output_dir)
|
||||
47
yolo_api.py
47
yolo_api.py
@ -178,6 +178,7 @@ from pathlib import Path
|
||||
|
||||
from download_train import download_train
|
||||
from predict.predict_yolo11seg import predict_images
|
||||
from predict.predict_yolo11seg import predict_images_share_dir
|
||||
from query_process_status import get_process_status
|
||||
|
||||
|
||||
@ -487,6 +488,52 @@ async def query_train_task(request: Request):
|
||||
|
||||
|
||||
|
||||
# 接收前端实时流,进行任务推理-共享目录
|
||||
@app.post("/ai/project/inference4ShareDir")
|
||||
async def start_inference_share_dir(request):
|
||||
try:
|
||||
# 解析并验证请求数据
|
||||
request_json = request.json
|
||||
task_id = request_json["task_id"]
|
||||
pt_name = request_json["pt_name"]
|
||||
zip_url = request_json["zip_url"]
|
||||
user_name = request_json["user_name"]
|
||||
pwd = request_json["pwd"]
|
||||
time_ns = time.time_ns()
|
||||
# pt_name = f"{time_ns}-{task_id}.pt"
|
||||
# model_path=r"pt_save\best.pt"
|
||||
print(f"task_id {task_id}")
|
||||
|
||||
if user_name == "":
|
||||
user_name = "administrator"
|
||||
if pwd == "":
|
||||
pwd = "abc@1234"
|
||||
|
||||
output_dir = f"predictions/{task_id}"
|
||||
inference_zip_url,message=predict_images_share_dir(pt_name, zip_url, user_name, pwd, output_dir=output_dir, conf_threshold=0.25, save_json=False)
|
||||
if inference_zip_url:
|
||||
return response.json({
|
||||
"status": "success",
|
||||
"task_id": task_id,
|
||||
"inference_zip_url":inference_zip_url,
|
||||
"message": "predict request successfully"
|
||||
})
|
||||
else:
|
||||
return response.json({
|
||||
"status": "fail",
|
||||
"task_id": task_id,
|
||||
"inference_zip_url":inference_zip_url,
|
||||
"message": message
|
||||
})
|
||||
except ValueError as e:
|
||||
print(f"Validation error: {str(e)}")
|
||||
return response.json({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {str(e)}")
|
||||
return response.json({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
|
||||
|
||||
# 接收前端实时流,进行任务推理
|
||||
@app.post("/ai/project/inference")
|
||||
async def start_inference(request):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user