939 lines
33 KiB
Python
939 lines
33 KiB
Python
import logging
|
||
import shutil
|
||
import zipfile
|
||
from os.path import exists
|
||
|
||
import torch
|
||
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
import time
|
||
import glob
|
||
from typing import List, Tuple, Optional, Dict
|
||
from ultralytics import YOLO
|
||
import matplotlib.pyplot as plt
|
||
|
||
from ultralytics import YOLO
|
||
|
||
from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder
|
||
from util.yolo2pix_new import *
|
||
from util.smb import *
|
||
import threading
|
||
|
||
# Drawing colours expressed in OpenCV's BGR channel order.
RED, WHITE, BLUE = (0, 0, 255), (255, 255, 255), (255, 0, 0)

# Module-wide logging configuration: INFO level with timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def unzip_file(zip_path, extract_dir):
    """Extract every member of the ZIP archive at ``zip_path`` into ``extract_dir``.

    Raises:
        zipfile.BadZipFile: ``zip_path`` is not a valid ZIP archive (logged first).
        Exception: any other extraction error is logged and re-raised unchanged.
    """
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_dir)
    except zipfile.BadZipFile:
        logger.error("无效的ZIP文件")
        raise
    except Exception as exc:
        logger.error(f"解压文件时出错: {exc}")
        raise
    else:
        logger.info(f"文件已解压到: {extract_dir}")


def create_result_zip(result_dir, output_path):
    """Pack the whole ``result_dir`` tree into a deflate-compressed ZIP at ``output_path``.

    Member names are stored relative to ``result_dir`` so the archive unpacks
    without the leading directory. Returns ``output_path``; any error is logged
    and re-raised.
    """

    def _iter_members():
        # Yield (path on disk, name inside the archive) for every file in the tree.
        for folder, _subdirs, names in os.walk(result_dir):
            for name in names:
                full_path = os.path.join(folder, name)
                yield full_path, os.path.relpath(full_path, result_dir)

    try:
        with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            for disk_path, member_name in _iter_members():
                archive.write(disk_path, member_name)
        logger.info(f"结果已压缩到: {output_path}")
        return output_path
    except Exception as exc:
        logger.error(f"压缩结果时出错: {exc}")
        raise


def get_image_paths(folder_path, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
    """Recursively collect the paths of all image files under ``folder_path``.

    Args:
        folder_path: root directory to walk (a missing directory yields ``[]``,
            because ``os.walk`` produces nothing for it).
        extensions: file suffixes treated as images, matched case-insensitively.
            Both ``.tif`` and ``.tiff`` are included by default so TIFF images are
            not silently dropped whichever spelling they use.

    Returns:
        List of file paths (joined onto ``folder_path``) in ``os.walk`` order.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(folder_path)
        for name in files
        if name.lower().endswith(suffixes)
    ]


class InferenceResult:
    """Container for everything produced by one segmentation inference run."""

    def __init__(self, image_path: str):
        # Location of the source image; later used to name output files.
        self.image_path = image_path
        # Image data, filled in by the inference and drawing steps.
        self.original_image = None   # RGB copy of the input image
        self.result_image = None     # image with masks/boxes drawn on it
        # Per-detection outputs (parallel sequences, one entry per object).
        self.masks, self.boxes, self.classes = [], [], []
        self.scores, self.class_names = [], []
        # Duration of the model call, in seconds.
        self.inference_time = 0.0


class YOLOSegmentationInference:
    """YOLO segmentation inference helper.

    Wraps an Ultralytics ``YOLO`` segmentation model. It loads the model and runs
    inference on local files or SMB-shared images. It then draws masks and boxes,
    and saves result images, per-object mask PNGs and polygon label files.
    """

    # File suffixes treated as images by process_image_directory (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
    # Weight of the original pixel inside a mask; the class colour gets 1 - MASK_ALPHA.
    MASK_ALPHA = 0.7

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Store configuration; the model itself is loaded by load_model().

        Args:
            model_path: path to the YOLO segmentation weights file.
            device: device handed to ``model.to()`` ('cpu', 'cuda', a torch.device, ...),
                or None to keep Ultralytics' default placement.
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.class_names = []

        # Per-class drawing colours. Drawing happens on RGB images, so these are RGB.
        self.colors = [
            (255, 0, 0),      # red
            (0, 255, 0),      # green
            (0, 0, 255),      # blue
            (255, 255, 0),    # yellow
            (255, 0, 255),    # magenta
            (0, 255, 255),    # cyan
            (128, 0, 0),      # dark red
            (0, 128, 0),      # dark green
            (0, 0, 128),      # dark blue
            (128, 128, 0),    # olive
            (128, 0, 128),    # purple
            (0, 128, 128),    # teal
            (192, 192, 192),  # light grey
            (128, 128, 128),  # grey
            (64, 64, 64),     # dark grey
        ]

    def load_model(self) -> bool:
        """Load the YOLO segmentation model from ``self.model_path``.

        Returns:
            True on success; False (after printing the error) on any failure.
        """
        try:
            print(f"正在加载模型: {self.model_path}")
            self.model = YOLO(self.model_path)

            # Move to the requested device, or report where Ultralytics put the model.
            if self.device is not None:
                self.model.to(self.device)
            elif self.model.device.type == 'cuda':
                print("使用GPU加速")
            else:
                print("使用CPU")

            self.class_names = list(self.model.names.values())
            print(f"模型加载成功,包含 {len(self.class_names)} 个类别")
            return True
        except Exception as e:
            print(f"加载模型失败: {e}")
            return False

    def preprocess_image(self, image_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Read ``image_path`` from disk and convert it to RGB.

        Returns:
            ``(rgb_image, rgb_image)`` (the same array twice, kept for the historical
            two-value interface), or ``(None, None)`` if the path is not a file or
            cannot be decoded.
        """
        try:
            if not os.path.isfile(image_path):
                raise ValueError(f"路径 {image_path} 不是有效的文件")

            original_image = cv2.imread(image_path)
            if original_image is None:
                raise ValueError(f"无法读取图片: {image_path}")

            original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            return original_image_rgb, original_image_rgb
        except PermissionError:
            print(f"权限错误: 无法访问文件 {image_path}")
            return None, None
        except Exception as e:
            print(f"图片预处理失败: {e}")
            return None, None

    def _run_model(self, result: "InferenceResult", source, conf_threshold: float,
                   iou_threshold: float) -> None:
        """Run the model on ``source`` and copy its outputs into ``result``.

        ``source`` is anything the Ultralytics predictor accepts (here either an
        image path or a BGR numpy array). ``inference_time`` is always filled in.
        masks/boxes/classes/scores/class_names are filled only when the model
        returned masks.
        """
        start_time = time.time()
        predictions = self.model(source, conf=conf_threshold, iou=iou_threshold)[0]
        result.inference_time = time.time() - start_time

        if predictions.masks is not None:
            # NOTE(review): masks.data is at model-input resolution and may include
            # letterbox padding, so stretching it to the image size can misalign
            # slightly. Consider Results.masks.xy / retina_masks=True — TODO confirm.
            result.masks = predictions.masks.data.cpu().numpy()
            result.boxes = predictions.boxes.data.cpu().numpy()
            result.classes = predictions.boxes.cls.cpu().numpy().astype(int)
            result.scores = predictions.boxes.conf.cpu().numpy()
            result.class_names = [self.model.names[i] for i in result.classes]

        print(f"检测到 {len(result.masks)} 个对象,推理时间: {result.inference_time:.3f} 秒")

    def perform_inference_share_dir(self, image, image_path, conf_threshold: float = 0.25,
                                    iou_threshold: float = 0.5) -> "InferenceResult":
        """Run inference on an already-loaded image (e.g. read from an SMB share).

        Args:
            image: decoded image in OpenCV BGR order (as returned by cv2-based readers).
            image_path: original location of the image, used for naming/logging only.
            conf_threshold: confidence threshold passed to the model.
            iou_threshold: NMS IoU threshold passed to the model.

        Returns:
            An InferenceResult. On any failure it is returned partially filled,
            after the error has been printed.
        """
        result = InferenceResult(image_path)
        try:
            if self.model is None:
                raise ValueError("模型未加载,请先调用load_model()")
            if image is None:
                raise ValueError(f"无法读取图片: {image_path}")

            # The RGB copy is kept for drawing and saving only.
            result.original_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            print(f"正在处理图片: {os.path.basename(image_path)}")
            # Ultralytics expects numpy inputs in BGR channel order (like cv2.imread),
            # so feed the untouched BGR array. Passing the RGB copy swapped channels.
            self._run_model(result, image, conf_threshold, iou_threshold)
        except Exception as e:
            print(f"推理失败: {e}")
        return result

    def perform_inference(self, image_path: str, conf_threshold: float = 0.25,
                          iou_threshold: float = 0.5) -> "InferenceResult":
        """Run inference on the image file at ``image_path``.

        Returns:
            An InferenceResult. If the image cannot be read or the model fails,
            it is returned partially filled (errors are printed).
        """
        result = InferenceResult(image_path)
        try:
            if self.model is None:
                raise ValueError("模型未加载,请先调用load_model()")

            original_image_rgb, _ = self.preprocess_image(image_path)
            if original_image_rgb is None:
                return result
            result.original_image = original_image_rgb

            print(f"正在处理图片: {os.path.basename(image_path)}")
            # Let Ultralytics read the file itself (it loads BGR as it expects).
            self._run_model(result, image_path, conf_threshold, iou_threshold)
        except Exception as e:
            print(f"推理失败: {e}")
        return result

    def draw_results(self, result: "InferenceResult", conf_threshold: float = 0.25) -> Optional[np.ndarray]:
        """Draw the masks, boxes and labels of ``result`` onto a copy of its image.

        Detections scoring below ``conf_threshold`` are skipped. The drawn image
        is stored in ``result.result_image`` and returned. When there is no image
        or no mask, ``result.original_image`` is returned unchanged and
        ``result_image`` is left untouched.
        """
        if result.original_image is None or result.masks is None or len(result.masks) == 0:
            return result.original_image

        output_image = result.original_image.copy()
        height, width = output_image.shape[:2]

        for i, mask in enumerate(result.masks):
            score = result.scores[i]
            if score < conf_threshold:
                continue

            class_id = result.classes[i]
            color = self.colors[class_id % len(self.colors)]

            # Stretch the mask to image size and binarise it.
            full_mask = cv2.resize(mask, (width, height)) > 0.5
            # Blend only the masked pixels. The previous addWeighted() over a zero-filled
            # overlay dimmed every pixel of the frame by 30% for each detected object.
            blended = (output_image[full_mask].astype(np.float32) * self.MASK_ALPHA
                       + np.asarray(color, dtype=np.float32) * (1.0 - self.MASK_ALPHA))
            output_image[full_mask] = np.clip(np.rint(blended), 0, 255).astype(output_image.dtype)

            if i < len(result.boxes):
                x1, y1, x2, y2 = result.boxes[i][:4].astype(int)
                cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                # Label on a filled background just above the box's top-left corner.
                # NOTE(review): Hershey fonts cannot render non-ASCII class names.
                label = f"{result.class_names[i]}: {score:.2f}"
                (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                cv2.rectangle(output_image, (x1, y1 - text_height), (x1 + text_width, y1), color, -1)
                cv2.putText(output_image, label, (x1, y1),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        result.result_image = output_image
        return output_image

    @staticmethod
    def _mask_to_polygon(mask: np.ndarray) -> List[float]:
        """Return the largest external contour of ``mask`` as flat normalised [x, y, ...].

        Coordinates are divided by the mask's own width/height. An empty list is
        returned when the mask has no foreground.
        """
        contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []
        contour = max(contours, key=cv2.contourArea)
        # Light simplification: tolerance is 0.1% of the contour perimeter.
        approx = cv2.approxPolyDP(contour, 0.001 * cv2.arcLength(contour, True), True)
        h, w = mask.shape[:2]
        points = []
        for x, y in approx.reshape(-1, 2):
            points.extend([x / w, y / h])
        return points

    def save_results(self, result: "InferenceResult", output_dir: str, save_mask: bool = False,
                     save_label: bool = False, result_save: Optional[List[dict]] = None) -> None:
        """Write the drawn result image (and optionally masks/labels) into ``output_dir``.

        Nothing is written when ``result.result_image`` is None, i.e. when
        draw_results did not run.

        Args:
            result: inference result to persist.
            output_dir: directory receiving every file (created if missing).
            save_mask: also write one binary PNG per detected object.
            save_label: also write ``<image stem>.txt``. Each line has the format
                ``class_id x1 y1 ... xn yn score``, with polygon coordinates
                normalised to the mask size.
            result_save: optional list. A dict ``{"origin_img", "label", "result_dir"}``
                describing this image is appended to it. ``None`` (the default)
                skips the bookkeeping instead of failing.
        """
        if result.result_image is None:
            return

        try:
            base_name = os.path.splitext(os.path.basename(result.image_path))[0]
            # Every image's artefacts go directly into output_dir (no per-image subfolder).
            os.makedirs(output_dir, exist_ok=True)

            redir_obj = {
                "origin_img": result.image_path,
                "label": None,
                "result_dir": output_dir
            }

            result_path = os.path.join(output_dir, f"{base_name}.jpg")
            result_image_bgr = cv2.cvtColor(result.result_image, cv2.COLOR_RGB2BGR)
            # cv2.imwrite reports failure through its return value, not an exception.
            if cv2.imwrite(result_path, result_image_bgr):
                print(f"结果图片已保存: {result_path}")
            else:
                print(f"保存结果图片失败: {result_path}")

            has_masks = result.masks is not None and len(result.masks) > 0

            if save_mask and has_masks:
                height, width = result.original_image.shape[:2]
                for i, mask in enumerate(result.masks):
                    binary = (cv2.resize(mask, (width, height)) > 0.5).astype(np.uint8) * 255
                    mask_path = os.path.join(output_dir, f"{base_name}_mask_{i}_{result.class_names[i]}.png")
                    cv2.imwrite(mask_path, binary)
                print(f"共保存 {len(result.masks)} 个掩码文件到: {output_dir}")

            if save_label and has_masks and len(result.boxes) > 0:
                label_path = os.path.join(output_dir, f"{base_name}.txt")
                with open(label_path, 'w') as f:
                    for i, mask in enumerate(result.masks):
                        points = self._mask_to_polygon(mask)
                        if len(points) >= 6:  # a polygon needs at least 3 points
                            coords = ' '.join(f'{v:.6f}' for v in points)
                            f.write(f"{result.classes[i]} {coords} {result.scores[i]:.6f}\n")
                print(f"标签文件已保存: {label_path}")
                redir_obj["label"] = label_path

            if result_save is not None:
                result_save.append(redir_obj)

        except PermissionError:
            print(f"权限错误: 无法写入到目录 {output_dir}")
        except Exception as e:
            print(f"保存结果失败: {e}")

    def show_results(self, result: "InferenceResult") -> None:
        """Display ``result.result_image`` in a matplotlib window (no-op if not drawn)."""
        if result.result_image is None:
            return

        # Try several common CJK-capable fonts so the Chinese title renders.
        plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei', 'SimHei', 'Microsoft YaHei',
                                           'Arial Unicode MS', 'sans-serif']
        plt.rcParams['axes.unicode_minus'] = False

        plt.figure(figsize=(12, 8))
        plt.imshow(result.result_image)
        plt.axis('off')

        title = f"推理结果: {os.path.basename(result.image_path)}\n"
        title += f"检测到 {len(result.masks)} 个对象,推理时间: {result.inference_time:.3f} 秒"
        plt.title(title)

        plt.show()

    def process_single_image(self, image_path: str, output_dir: Optional[str] = None,
                             conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                             save_mask: bool = False, save_label: bool = False, show: bool = True,
                             result_save: Optional[List[dict]] = None) -> "InferenceResult":
        """Infer, draw and (if ``output_dir`` is given) save one local image.

        ``show`` is accepted for compatibility but currently unused. ``result_save``
        is forwarded to save_results.
        """
        result = self.perform_inference(image_path, conf_threshold, iou_threshold)

        if result.masks is not None and len(result.masks) > 0:
            self.draw_results(result, conf_threshold)

        if output_dir is not None:
            self.save_results(result, output_dir, save_mask, save_label, result_save)

        return result

    def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None,
                                       conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                                       save_mask: bool = False, save_label: bool = False, show: bool = True,
                                       result_save: Optional[List[dict]] = None) -> "InferenceResult":
        """Like process_single_image, but reads the image through the SMB scanner.

        ``user_name``/``pwd`` are the share credentials passed to ``get_scanner``.
        Exceptions raised while reading the image propagate to the caller.
        """
        scanner = get_scanner(image_path, user_name=user_name, pwd=pwd)
        image = scanner.read_img_file(image_path)
        result = self.perform_inference_share_dir(image, image_path, conf_threshold, iou_threshold)

        if result.masks is not None and len(result.masks) > 0:
            self.draw_results(result, conf_threshold)

        if output_dir is not None:
            self.save_results(result, output_dir, save_mask, save_label, result_save)

        return result

    def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None,
                                conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                                save_mask: bool = False, save_label: bool = False, show: bool = False,
                                result_save: Optional[List[dict]] = None) -> List["InferenceResult"]:
        """Run process_single_image on every image directly inside ``input_dir``.

        Only the top level is scanned (not recursive). Hidden files are skipped
        (as the previous glob did). Suffixes in IMAGE_EXTENSIONS are matched
        case-insensitively, so ``a.JPG`` and ``.tif`` files are no longer missed.

        Returns:
            The list of results; it is empty (after a printed message) when the
            directory is invalid, unreadable or has no images.
        """
        results = []
        try:
            if not os.path.isdir(input_dir):
                print(f"错误: {input_dir} 不是有效的目录")
                return results

            image_files = sorted(
                os.path.join(input_dir, name)
                for name in os.listdir(input_dir)
                if not name.startswith('.')
                and name.lower().endswith(self.IMAGE_EXTENSIONS)
                and os.path.isfile(os.path.join(input_dir, name))
            )

            if not image_files:
                print(f"在目录 {input_dir} 中未找到图片文件")
                return results

            print(f"找到 {len(image_files)} 个图片文件")

            for image_path in image_files:
                results.append(self.process_single_image(
                    image_path=image_path,
                    output_dir=output_dir,
                    conf_threshold=conf_threshold,
                    iou_threshold=iou_threshold,
                    save_mask=save_mask,
                    save_label=save_label,
                    show=show,
                    result_save=result_save
                ))
            return results

        except PermissionError:
            print(f"权限错误: 无法访问目录 {input_dir}")
            return results
        except Exception as e:
            print(f"处理目录失败: {e}")
            return results

    def process_image_directory_share_dir(self, task_id, input_dir, user_name, pwd, output_dir: Optional[str] = None,
                                          conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                                          save_mask: bool = False, save_label: bool = False, show: bool = False,
                                          result_save: Optional[List[dict]] = None) -> List["InferenceResult"]:
        """Process every image of an SMB directory, then run the road/pile post-processing.

        Steps:
            1. Each image from ``scanner.get_smb_images(input_dir)`` goes through
               process_single_image_share_dir. A failing image is reported and
               skipped instead of aborting the whole directory.
            2. ``process_dir(road_dict, pile_dict, output_dir)`` is called.

        ``task_id`` is currently only relevant to the disabled upload step (see TODO).

        Returns:
            The list of results; it is empty (after a printed message) on errors.
        """
        results = []
        try:
            config = get_conf(zip_url=input_dir, user_name=user_name, pwd=pwd)
            scanner = get_scanner(zip_url=input_dir, user_name=user_name, pwd=pwd)
            if not scanner.directory_exists(input_dir):
                print(f"错误: {input_dir} 不是有效的目录")
                return results

            image_files = scanner.get_smb_images(input_dir)
            if not image_files:
                print(f"在目录 {input_dir} 中未找到图片文件")
                return results

            print(f"找到 {len(image_files)} 个图片文件")

            for image_path in image_files:
                try:
                    result = self.process_single_image_share_dir(
                        image_path=image_path,
                        user_name=user_name,
                        pwd=pwd,
                        output_dir=output_dir,
                        conf_threshold=conf_threshold,
                        iou_threshold=iou_threshold,
                        save_mask=save_mask,
                        save_label=save_label,
                        show=show,
                        result_save=result_save
                    )
                except Exception as e:
                    print(f"处理图片失败 {image_path}: {e}")
                    continue
                results.append(result)

            # Push the recognition data to the shared directory.
            pile_dict = get_pile_dict(input_dir, user_name, pwd)
            road_dict = get_road_dict(f"{config['ip']}/{config['share']}", user_name, pwd)
            process_dir(road_dict, pile_dict, output_dir)

            # TODO: upload is disabled. When re-enabling it, the intended target was
            # f"{config['dir']}_识别/{task_id}/<YYYYmmddHHMMSS>" via
            # scanner.upload_directory(output_dir, config['share'], remote_dir=...).
            return results

        except PermissionError:
            print(f"权限错误: 无法访问目录 {input_dir}")
            return results
        except Exception as e:
            print(f"处理目录失败: {e}")
            return results


def _flatten_images_into(src_dir, dst_dir):
    """Move every image found (recursively) under ``src_dir`` directly into ``dst_dir``.

    Hidden files (e.g. macOS ``._x.jpg`` resource forks) are left behind.
    Name clashes between images from different sub-folders get a numeric
    suffix (``a_1.jpg``) instead of making ``shutil.move`` raise.
    """
    for pic_path in get_image_paths(src_dir):
        file_name = os.path.basename(pic_path)
        if file_name.startswith('.'):
            continue
        stem, ext = os.path.splitext(file_name)
        target = os.path.join(dst_dir, file_name)
        suffix = 1
        while os.path.exists(target):
            target = os.path.join(dst_dir, f"{stem}_{suffix}{ext}")
            suffix += 1
        shutil.move(pic_path, target)


def _export_grid_results(result_save, pix_root, class_names):
    """Re-package each labelled prediction into ``pix_root/<label stem>/``.

    For every entry of ``result_save`` that has a label file:
        1. Move the source image and its label into a dedicated folder.
        2. Write ``<stem>_grid.txt`` and ``<stem>_grid.jpg`` via the grid helpers.
        3. Delete the label file, because the on-site delivery only wants the
           image plus the grid outputs.

    The prediction output directories are NOT removed here. Deleting the shared
    directory inside this loop used to wipe the labels of the remaining images.
    """
    for redir_obj in result_save:
        label = redir_obj["label"]
        if not label:
            continue

        origin_img = redir_obj["origin_img"]
        stem = os.path.splitext(os.path.basename(label))[0]
        pix_result_dir = os.path.join(os.path.abspath(pix_root), stem)
        os.makedirs(pix_result_dir, exist_ok=True)

        new_origin_img = os.path.join(pix_result_dir, os.path.basename(origin_img))
        new_mask_label = os.path.join(pix_result_dir, os.path.basename(label))
        shutil.move(origin_img, new_origin_img)
        shutil.move(label, new_mask_label)

        output_str, class_cells = yoloseg_to_grid_cells_fixed_v5(
            new_origin_img, new_mask_label, class_names=class_names
        )

        with open(os.path.join(pix_result_dir, stem + "_grid.txt"), 'w', encoding='utf-8') as f:
            f.write(output_str)

        draw_grid_on_image(new_origin_img, class_cells,
                           save_path=os.path.join(pix_result_dir, stem + "_grid.jpg"))

        if os.path.exists(new_mask_label):
            os.remove(new_mask_label)


def predict_images(pt_name, zip_url, output_dir="predictions", conf_threshold=0.25, save_json=False):
    """Segment every image of a ZIP stored in object storage and upload grid results.

    Pipeline:
        1. Size check, then download and unzip.
        2. Flatten all images into one folder.
        3. YOLO segmentation.
        4. Grid conversion per labelled image.
        5. Zip and upload.
    All temporary directories of this call are namespaced by ``time.time_ns()``
    and removed on exit, so concurrent calls do not share working directories.

    Args:
        pt_name: currently unused. NOTE(review): the model path is hard-coded below — confirm.
        zip_url: object key of the ZIP, passed to ``check_zip_size``/``downFile``.
        output_dir: currently ignored. Predictions go to ``dataset/predictions/<time_ns>``.
        conf_threshold: confidence threshold for the model.
        save_json: currently unused.

    Returns:
        ``(uploaded_path, "success")`` on success, or ``(None, error_message)`` when
        the ZIP is too large or the model cannot be loaded. Other errors are
        printed and re-raised.
    """
    zip_save_path = "dataset/zip_file"
    predictions_root = "dataset/predictions"
    pix_root = "dataset/pix"
    max_zip_mb = 100
    local_model_path = r"D:\project\ai-train_platform\pt_save\road_crack.pt"
    classes = ['裂缝', '横向裂缝', '纵向裂缝', "修补", "坑洞"]

    time_ns = time.time_ns()
    zip_local_dir = os.path.join(zip_save_path, str(time_ns))             # raw unzip target
    zip_local_dir_save = os.path.join(zip_save_path, f"{time_ns}save")    # flattened images
    run_output_dir = os.path.join(predictions_root, str(time_ns))         # raw predictions
    result_pix = os.path.join(pix_root, str(time_ns))                     # grid deliverables
    zip_dir_path = os.path.abspath(result_pix) + ".zip"

    try:
        zip_size_mb = check_zip_size(zip_url) / 1024 / 1024
        if zip_size_mb > max_zip_mb:
            return None, f"zip_file is bigger than {max_zip_mb}MB"

        zip_local_path = downFile(zip_url)
        os.makedirs(zip_local_dir_save, exist_ok=True)
        try:
            unzip_file(zip_local_path, zip_local_dir)
        finally:
            if os.path.exists(zip_local_path):
                os.remove(zip_local_path)

        _flatten_images_into(zip_local_dir, zip_local_dir_save)
        shutil.rmtree(zip_local_dir, ignore_errors=True)

        print(f"正在加载模型: {local_model_path}")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        inference = YOLOSegmentationInference(model_path=local_model_path, device=device)
        if not inference.load_model():
            return None, "model load failed"
        print(f"模型已加载到 {device}")

        result_save = []
        # Mask PNGs are not needed: the prediction dir is discarded after grid export.
        inference.process_image_directory(
            input_dir=zip_local_dir_save,
            output_dir=run_output_dir,
            conf_threshold=conf_threshold,
            iou_threshold=0.5,
            save_mask=False,
            save_label=True,
            show=False,
            result_save=result_save,
        )

        os.makedirs(result_pix, exist_ok=True)
        _export_grid_results(result_save, result_pix, classes)

        create_result_zip(result_pix, zip_dir_path)
        file_save_dir, _file_type = upload_file(zip_dir_path, None)
        return file_save_dir, "success"

    except Exception as e:
        print(f"处理目录失败: {e}")
        raise
    finally:
        # Best-effort removal of every temporary artefact created by this call.
        for path in (zip_local_dir, zip_local_dir_save, run_output_dir, result_pix):
            shutil.rmtree(path, ignore_errors=True)
        if os.path.exists(zip_dir_path):
            os.remove(zip_dir_path)


def predict_images_share_dir(task_id, pt_name, zip_url, user_name, pwd, output_dir="predictions", conf_threshold=0.25, save_json=False):
    """Start background segmentation of the ``图像类/Images`` folder on an SMB share.

    Steps:
        1. Find the first ``图像类`` folder under the configured share dir of ``zip_url``.
        2. Find the first ``Images`` folder inside it.
        3. Run ``process_image_directory_share_dir`` on a worker thread.
    This function returns as soon as the thread has started.

    Args:
        task_id: identifier forwarded to the worker and used in the returned path.
        pt_name: currently unused. NOTE(review): the model path is hard-coded — confirm.
        zip_url: SMB location used to build the share config and scanner.
        user_name: share user name.
        pwd: share password.
        output_dir: local directory for prediction outputs.
            NOTE(review): this directory is shared by concurrent tasks.
        conf_threshold: confidence threshold for the model.
        save_json: currently unused.

    Returns:
        ``(f"{target_path}_识别/{task_id}", "success")`` once the worker is running,
        or ``(None, error_message)`` when the model or the image folder cannot be found.
    """
    local_model_path = r"../pt_save/road_crack.pt"

    print(f"正在加载模型: {local_model_path}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inference = YOLOSegmentationInference(model_path=local_model_path, device=device)
    if not inference.load_model():
        return None, "model load failed"
    print(f"模型已加载到 {device}")

    config = get_conf(zip_url, user_name, pwd)
    scanner = get_scanner(zip_url, user_name, pwd)

    # Locate the "图像类" (image category) marker folder first.
    flag_dirs = scanner.find_folders_by_name(
        share_path=config['share'],
        folder_name='图像类',
        start_dir=config['dir']
    )
    if not flag_dirs:
        message = f"错误: 输入 {zip_url} 不是有效的文件或目录"
        print(message)
        return None, message

    # Then the "Images" folder underneath it holds the pictures to recognise.
    flag_config = get_conf(flag_dirs[0], user_name, pwd)
    image_dirs = scanner.find_folders_by_name(
        share_path=config['share'],
        folder_name='Images',
        start_dir=flag_config['dir']
    )
    if not image_dirs:
        message = f"错误: 在 {flag_dirs[0]} 下未找到 Images 目录"
        print(message)
        return None, message

    target_path = image_dirs[0]

    # Inference runs in the background; the caller only gets the result location hint.
    # NOTE(review): the worker's (currently disabled) upload target also appends a
    # timestamp, so this hint may not match it exactly — confirm with callers.
    worker = threading.Thread(
        target=inference.process_image_directory_share_dir,
        kwargs=dict(
            task_id=task_id,
            input_dir=target_path,
            user_name=user_name,
            pwd=pwd,
            output_dir=output_dir,
            conf_threshold=conf_threshold,
            iou_threshold=0.5,
            save_mask=False,
            save_label=True,
            show=True,
            result_save=[],
        ),
        name=f"predict-share-{task_id}",
    )
    worker.start()

    return f"{target_path}_识别/{task_id}", "success"