209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
|
|
import os
|
|||
|
|
import pandas as pd
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
import rasterio
|
|||
|
|
from PIL import Image
|
|||
|
|
from collections import defaultdict
|
|||
|
|
import numpy as np
|
|||
|
|
# Setting MAX_IMAGE_PIXELS to None disables Pillow's decompression-bomb size
# check, so very large images can be opened without a warning or error.
Image.MAX_IMAGE_PIXELS = None  # disable the pixel-limit warning
|
|||
|
|
class PredictionAggregator:
    """Collect per-patch predictions and stitch them into full-size masks.

    Each patch is stored under its source image name together with its
    top-left ``(x, y)`` offset. The full ``(H, W)`` of each source image is
    read from ``img_dir`` the first time a patch for that image arrives and
    is cached in ``self.sizes``.
    """

    def __init__(self, img_dir=None):
        """
        Args:
            img_dir: Directory holding the original images, used to look up
                each image's full size. It may be None only if the caller
                fills ``self.sizes`` before ``add_patch`` sees that image.
        """
        self.preds = defaultdict(list)  # fname -> [(x, y, cropped_patch), ...]
        self.sizes = dict()             # fname -> (H, W) of the original image
        self.img_dir = img_dir

    def _get_image_size(self, fname):
        """Return ``(H, W)`` of ``img_dir/fname``, picking the reader by extension.

        ``.tif``/``.tiff`` files are opened with rasterio; all others with PIL.
        """
        path = os.path.join(self.img_dir, fname)
        ext = os.path.splitext(fname)[1].lower()
        if ext in ['.tif', '.tiff']:
            with rasterio.open(path) as src:
                return src.height, src.width  # rasterio already reports (H, W)
        with Image.open(path) as img:
            return img.size[1], img.size[0]  # PIL reports (W, H); swap to (H, W)

    def add_patch(self, fname, x, y, pred_patch, valid_h, valid_w):
        """Store one predicted patch, cropping away any padding.

        Args:
            fname: Original image file name (e.g. ``image_1.png``).
            x, y: Column / row of the patch's top-left corner in the full image.
            pred_patch: 2-D array of shape ``[H_pad, W_pad]``.
            valid_h, valid_w: Height / width of the patch before zero-padding;
                only ``pred_patch[:valid_h, :valid_w]`` is kept.

        Raises:
            ValueError: If the image size is not cached and ``img_dir`` is None.
        """
        if fname not in self.sizes:
            if self.img_dir is None:
                raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
            self.sizes[fname] = self._get_image_size(fname)

        # Drop the zero-padding so the patch never spills past the image edge.
        self.preds[fname].append((x, y, pred_patch[:valid_h, :valid_w]))

    def _stitch(self, fname, patches):
        """Paste ``patches`` onto a zero-filled uint8 canvas sized like ``fname``."""
        H, W = self.sizes[fname]
        canvas = np.zeros((H, W), dtype=np.uint8)
        # Patches are pasted in insertion order, so where they overlap the
        # later patch wins. Values are cast to uint8 on assignment; this
        # assumes predictions are label ids in 0..255 -- TODO confirm.
        for x, y, patch in patches:
            h, w = patch.shape
            canvas[y:y + h, x:x + w] = patch
        return canvas

    @staticmethod
    def _save_canvas(canvas, save_dir, fname):
        """Write ``canvas`` to ``save_dir`` under the original name; return the path."""
        # NOTE(review): the original extension is kept, so a ``.jpg`` source
        # yields a lossy JPEG mask -- confirm that is acceptable.
        save_path = os.path.join(save_dir, str(fname))
        Image.fromarray(canvas).save(save_path)
        return save_path

    def save_all(self, save_dir):
        """Stitch and save one mask per accumulated image into ``save_dir``.

        Accumulated patches are kept, so a later call rewrites the same files.

        Returns:
            Path of the last file written, or ``""`` if nothing was accumulated.
        """
        os.makedirs(save_dir, exist_ok=True)
        out_save_path = ""
        for fname, patches in self.preds.items():
            out_save_path = self._save_canvas(self._stitch(fname, patches), save_dir, fname)
        return out_save_path

    def save_one_pic(self, save_dir):
        """Like :meth:`save_all`, then discard the accumulated patches.

        Cached image sizes in ``self.sizes`` are kept.

        Returns:
            Path of the last file written, or ``""`` if nothing was accumulated.
        """
        out_save_path = self.save_all(save_dir)
        self.preds.clear()
        return out_save_path
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# class PredictionAggregator:
|
|||
|
|
# def __init__(self, img_dir=None):
|
|||
|
|
# """
|
|||
|
|
# img_dir: 原始图像路径(用于自动获取 full_size)
|
|||
|
|
# """
|
|||
|
|
# self.preds = defaultdict(list)
|
|||
|
|
# self.sizes = dict()
|
|||
|
|
# self.img_dir = img_dir
|
|||
|
|
|
|||
|
|
# def _get_image_size(self, fname):
|
|||
|
|
# """
|
|||
|
|
# 自动根据图像格式获取 (H, W)
|
|||
|
|
# """
|
|||
|
|
# path = os.path.join(self.img_dir, fname)
|
|||
|
|
# ext = os.path.splitext(fname)[1].lower()
|
|||
|
|
# if ext in ['.tif', '.tiff']:
|
|||
|
|
# with rasterio.open(path) as src:
|
|||
|
|
# return src.height, src.width # rasterio: (H, W)
|
|||
|
|
# else:
|
|||
|
|
# with Image.open(path) as img:
|
|||
|
|
# return img.size[1], img.size[0] # PIL: (W, H) → (H, W)
|
|||
|
|
|
|||
|
|
# def add_patch(self, fname, x, y, pred_patch):
|
|||
|
|
# """
|
|||
|
|
# 添加一个 patch 到对应位置
|
|||
|
|
# fname: 原始图像名(如 image_1.png)
|
|||
|
|
# x, y: patch 左上角坐标
|
|||
|
|
# pred_patch: numpy array, shape=[H, W]
|
|||
|
|
# """
|
|||
|
|
# if fname not in self.sizes:
|
|||
|
|
# if self.img_dir is None:
|
|||
|
|
# raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
|
|||
|
|
# self.sizes[fname] = self._get_image_size(fname)
|
|||
|
|
|
|||
|
|
# self.preds[fname].append((x, y, pred_patch))
|
|||
|
|
|
|||
|
|
# def save_all(self, save_dir):
|
|||
|
|
# os.makedirs(save_dir, exist_ok=True)
|
|||
|
|
# for fname, patches in self.preds.items():
|
|||
|
|
# H, W = self.sizes[fname]
|
|||
|
|
# canvas = np.zeros((H, W), dtype=np.uint8)
|
|||
|
|
|
|||
|
|
# for x, y, patch in patches:
|
|||
|
|
# h, w = patch.shape
|
|||
|
|
# canvas[y:y+h, x:x+w] = patch
|
|||
|
|
|
|||
|
|
# save_name = f"pred_{fname}" # 保留原图扩展名
|
|||
|
|
# save_path = os.path.join(save_dir, save_name)
|
|||
|
|
# Image.fromarray(canvas).save(save_path)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PatchIndexer:
    """Enumerate square patch windows over every image in a directory.

    For each image, windows of ``patch_size`` x ``patch_size`` are laid out on
    a grid with step ``stride``. When the stride grid would leave the bottom
    or right border uncovered, an extra row/column flush with that edge is
    added. One record per window is collected in ``coord_records`` and
    written to CSV by :meth:`index_all`.
    """

    def __init__(self, img_dir, patch_size=256, stride=256, save_dir=None):
        """
        Args:
            img_dir: Directory containing the source images.
            patch_size: Side length of each square patch, in pixels.
            stride: Step between consecutive patch origins, in pixels.
            save_dir: Where the coordinate CSV is written; defaults to the
                parent directory of ``img_dir``.

        Raises:
            ValueError: If ``patch_size`` or ``stride`` is not positive.
        """
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got {patch_size} and {stride}"
            )
        self.img_dir = img_dir
        self.patch_size = patch_size
        self.stride = stride
        self.coord_records = []  # one dict per patch window, across all indexed images

        if save_dir is None:
            self.save_dir = os.path.dirname(os.path.abspath(img_dir))
        else:
            self.save_dir = save_dir

    @staticmethod
    def _axis_starts(length, patch_size, stride):
        """Patch origins along one axis: the stride grid plus a final edge-flush start.

        Requires ``length >= patch_size`` and ``stride > 0``.
        """
        last = length - patch_size
        starts = list(range(0, last + 1, stride))  # always contains 0
        if starts[-1] != last:
            starts.append(last)  # cover the remainder the stride grid missed
        return starts

    def index_all(self, save_name=None):
        """Index every supported image in ``img_dir`` and write the records to CSV.

        Supported extensions (case-insensitive): .tif, .tiff, .png, .jpg, .jpeg.
        Files are processed in sorted name order. Records accumulate in
        ``self.coord_records`` and are never reset, so calling this twice on
        the same instance writes duplicate rows.

        Args:
            save_name: CSV file name inside ``save_dir``; defaults to ``coords.csv``.
        """
        filelist = sorted(
            f for f in os.listdir(self.img_dir)
            if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))
        )
        print(f"[INFO] Found {len(filelist)} images.")

        for fname in tqdm(filelist, desc='Indexing patches'):
            self.index_one(fname)

        if save_name is None:
            save_name = 'coords.csv'
        save_path = os.path.join(self.save_dir, save_name)

        df = pd.DataFrame(self.coord_records)
        if df.empty:
            print("[⚠️] No patch coordinates were generated!")
        else:
            # A user-supplied save_dir may not exist yet; "" means the cwd.
            if self.save_dir:
                os.makedirs(self.save_dir, exist_ok=True)
            df.to_csv(save_path, index=False)
            print(f"[✅] Saved {len(df)} patch records to {save_path}")

    def index_one(self, fname):
        """Append one record per patch window of ``img_dir/fname`` to ``coord_records``.

        Images that cannot be opened, or that are smaller than ``patch_size``
        in either dimension, are reported and skipped.
        """
        path = os.path.join(self.img_dir, fname)
        name, ext = os.path.splitext(fname)
        ext = ext.lower()

        try:
            if ext in ['.tif', '.tiff']:
                with rasterio.open(path) as src:
                    H, W = src.height, src.width
            else:
                with Image.open(path) as img:
                    W, H = img.size  # PIL reports (W, H)
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole run.
            print(f"[ERROR] Failed to open {fname}: {e}")
            return

        ps, st = self.patch_size, self.stride
        if H < ps or W < ps:
            print(f"[WARNING] Skip {fname}: too small ({W}x{H})")
            return

        x_starts = self._axis_starts(W, ps, st)
        for y in self._axis_starts(H, ps, st):
            for x in x_starts:
                self.coord_records.append({
                    'orig_name': fname,
                    'x': x,
                    'y': y,
                    'h': ps,
                    'w': ps,
                    'patch_name': f"{name}_{x}_{y}{ext}",
                })
|