209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
import os
|
||
import pandas as pd
|
||
from tqdm import tqdm
|
||
import rasterio
|
||
from PIL import Image
|
||
from collections import defaultdict
|
||
import numpy as np
|
||
# Remove PIL's decompression-bomb pixel limit so very large images can be
# opened without DecompressionBombWarning/Error. This is a process-wide
# setting on the PIL Image module and also disables that safety check for
# untrusted files.
Image.MAX_IMAGE_PIXELS = None


class PredictionAggregator:
    """Stitch per-patch predictions back into full-size single-channel images.

    Call ``add_patch`` once per predicted patch, then ``save_all`` (keeps the
    buffered patches) or ``save_one_pic`` (writes, then empties the patch
    buffer). Each output is a ``uint8`` (H, W) array saved with PIL under the
    original file name. Where patches overlap, the patch added last wins.
    """

    def __init__(self, img_dir=None):
        """
        Args:
            img_dir: Directory holding the original images. It is used to read
                each image's full (H, W) the first time a patch for it is
                added, so it must be set before ``add_patch`` is called.
        """
        # fname -> list of (x, y, cropped patch), in the order they were added.
        self.preds = defaultdict(list)
        # fname -> (H, W) of the original image, cached after the first read.
        self.sizes = dict()
        self.img_dir = img_dir

    def _get_image_size(self, fname):
        """Return (H, W) of ``img_dir/fname``.

        ``.tif``/``.tiff`` files (case-insensitive) are opened with rasterio;
        every other extension is opened with PIL.
        """
        path = os.path.join(self.img_dir, fname)
        ext = os.path.splitext(fname)[1].lower()
        if ext in ('.tif', '.tiff'):
            with rasterio.open(path) as src:
                return src.height, src.width
        # PIL reports size as (W, H); swap to (H, W).
        with Image.open(path) as img:
            return img.size[1], img.size[0]

    def add_patch(self, fname, x, y, pred_patch, valid_h, valid_w):
        """Buffer one predicted patch for ``fname``, dropping its padding.

        Args:
            fname: Original image file name (e.g. ``image_1.png``), relative
                to ``img_dir``.
            x, y: Column / row of the patch's top-left corner in the image.
            pred_patch: 2-D array of shape (H_pad, W_pad).
            valid_h, valid_w: Height / width of the patch before it was
                zero-padded; only ``pred_patch[:valid_h, :valid_w]`` is kept.

        Raises:
            ValueError: If the image size is not cached yet and ``img_dir``
                is None.
        """
        if fname not in self.sizes:
            if self.img_dir is None:
                raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
            self.sizes[fname] = self._get_image_size(fname)

        # Copy the valid region. A bare slice would be a view that keeps the
        # whole padded (possibly batch-sized) buffer alive, and it would change
        # silently if the caller reuses that buffer for the next prediction.
        patch_crop = np.array(pred_patch[:valid_h, :valid_w], copy=True)
        self.preds[fname].append((x, y, patch_crop))

    def _stitch(self, fname):
        """Return a uint8 (H, W) canvas with every buffered patch of ``fname``.

        Patches are pasted in insertion order, so later patches overwrite
        earlier ones where they overlap. Any part of a patch that falls
        outside the image is clipped instead of raising a broadcasting
        ``ValueError``. ``x`` and ``y`` are assumed to be non-negative.
        """
        H, W = self.sizes[fname]
        canvas = np.zeros((H, W), dtype=np.uint8)
        for x, y, patch in self.preds[fname]:
            h = max(0, min(patch.shape[0], H - y))
            w = max(0, min(patch.shape[1], W - x))
            canvas[y:y + h, x:x + w] = patch[:h, :w]
        return canvas

    def _write_all(self, save_dir):
        """Stitch and save every buffered image into ``save_dir``.

        Returns the path of the last file written, or "" if nothing is
        buffered.
        """
        os.makedirs(save_dir, exist_ok=True)
        out_save_path = ""
        for fname in self.preds:
            # The output keeps the original file name, so its extension picks
            # the PIL writer. NOTE: if save_dir is img_dir, this overwrites
            # the source image.
            save_path = os.path.join(save_dir, f"{fname}")
            Image.fromarray(self._stitch(fname)).save(save_path)
            out_save_path = save_path
        return out_save_path

    def save_all(self, save_dir):
        """Write every buffered image to ``save_dir`` and keep the buffer.

        Returns:
            Path of the last image written ("" if there was nothing to save).
        """
        return self._write_all(save_dir)

    def save_one_pic(self, save_dir):
        """Write every buffered image to ``save_dir``, then clear the buffer.

        The name suggests this is meant to be called once all patches of a
        single image have been added, so buffered patches do not accumulate
        across images. Cached image sizes in ``self.sizes`` are kept.

        Returns:
            Path of the last image written ("" if there was nothing to save).
        """
        out_save_path = self._write_all(save_dir)
        self.preds.clear()
        return out_save_path


# Previous version of PredictionAggregator, kept for reference: its add_patch
# took no valid_h/valid_w (patches were stored without cropping padding) and
# its save_all wrote files as "pred_<fname>".
# class PredictionAggregator:
#     def __init__(self, img_dir=None):
#         """
#         img_dir: 原始图像路径(用于自动获取 full_size)
#         """
#         self.preds = defaultdict(list)
#         self.sizes = dict()
#         self.img_dir = img_dir
#
#     def _get_image_size(self, fname):
#         """
#         自动根据图像格式获取 (H, W)
#         """
#         path = os.path.join(self.img_dir, fname)
#         ext = os.path.splitext(fname)[1].lower()
#         if ext in ['.tif', '.tiff']:
#             with rasterio.open(path) as src:
#                 return src.height, src.width  # rasterio: (H, W)
#         else:
#             with Image.open(path) as img:
#                 return img.size[1], img.size[0]  # PIL: (W, H) → (H, W)
#
#     def add_patch(self, fname, x, y, pred_patch):
#         """
#         添加一个 patch 到对应位置
#         fname: 原始图像名(如 image_1.png)
#         x, y: patch 左上角坐标
#         pred_patch: numpy array, shape=[H, W]
#         """
#         if fname not in self.sizes:
#             if self.img_dir is None:
#                 raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
#             self.sizes[fname] = self._get_image_size(fname)
#
#         self.preds[fname].append((x, y, pred_patch))
#
#     def save_all(self, save_dir):
#         os.makedirs(save_dir, exist_ok=True)
#         for fname, patches in self.preds.items():
#             H, W = self.sizes[fname]
#             canvas = np.zeros((H, W), dtype=np.uint8)
#
#             for x, y, patch in patches:
#                 h, w = patch.shape
#                 canvas[y:y+h, x:x+w] = patch
#
#             save_name = f"pred_{fname}"  # 保留原图扩展名
#             save_path = os.path.join(save_dir, save_name)
#             Image.fromarray(canvas).save(save_path)


class PatchIndexer:
    """Enumerate square patch windows over every image in a directory.

    For each image, window origins are placed at multiples of ``stride``. If
    that leaves an uncovered strip at the right/bottom edge, one extra window
    flush with that edge is added. Records are collected in
    ``coord_records`` and written to a CSV by ``index_all``.
    """

    def __init__(self, img_dir, patch_size=256, stride=256, save_dir=None):
        """
        Args:
            img_dir: Directory containing the images to index.
            patch_size: Side length of the square windows, in pixels.
            stride: Step between consecutive window origins, in pixels.
            save_dir: Directory for the CSV; defaults to the parent directory
                of ``img_dir``.

        Raises:
            ValueError: If ``patch_size`` or ``stride`` is not positive.
        """
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got "
                f"patch_size={patch_size}, stride={stride}"
            )
        self.img_dir = img_dir
        self.patch_size = patch_size
        self.stride = stride
        # One dict per window; see index_one for the keys.
        self.coord_records = []

        if save_dir is None:
            self.save_dir = os.path.dirname(os.path.abspath(img_dir))
        else:
            self.save_dir = save_dir

    def index_all(self, save_name=None):
        """Index every image in ``img_dir`` and write the records to a CSV.

        Files are processed in sorted name order. Only ``.tif``, ``.tiff``,
        ``.png``, ``.jpg`` and ``.jpeg`` names (case-insensitive) are
        considered. Records accumulate in ``coord_records`` across calls, so
        calling this twice on the same instance writes duplicate rows.

        Args:
            save_name: CSV file name inside ``save_dir``; defaults to
                ``coords.csv``.
        """
        filelist = sorted(
            f for f in os.listdir(self.img_dir)
            if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))
        )
        print(f"[INFO] Found {len(filelist)} images.")

        for fname in tqdm(filelist, desc='Indexing patches'):
            self.index_one(fname)

        if save_name is None:
            save_name = 'coords.csv'
        save_path = os.path.join(self.save_dir, save_name)

        df = pd.DataFrame(self.coord_records)
        if df.empty:
            print("[⚠️] No patch coordinates were generated!")
        else:
            # A user-supplied save_dir may not exist yet; to_csv would
            # otherwise fail with FileNotFoundError.
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
            df.to_csv(save_path, index=False)
            print(f"[✅] Saved {len(df)} patch records to {save_path}")

    def index_one(self, fname):
        """Append one record per patch window of ``img_dir/fname``.

        Each record has the keys ``orig_name``, ``x`` and ``y`` (window
        top-left), ``h`` and ``w`` (always ``patch_size``), and
        ``patch_name`` (``<stem>_<x>_<y><ext>``, with the extension
        lower-cased). Images that cannot be opened, or that are smaller than
        ``patch_size`` in either dimension, are reported and skipped.
        """
        path = os.path.join(self.img_dir, fname)
        name, ext = os.path.splitext(fname)
        ext = ext.lower()

        try:
            if ext in ('.tif', '.tiff'):
                with rasterio.open(path) as src:
                    H, W = src.height, src.width
            else:
                with Image.open(path) as img:
                    W, H = img.size  # PIL reports (W, H)
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole run.
            print(f"[ERROR] Failed to open {fname}: {e}")
            return

        ps, st = self.patch_size, self.stride
        if H < ps or W < ps:
            print(f"[WARNING] Skip {fname}: too small ({W}x{H})")
            return

        y_list = self._axis_starts(H, ps, st)
        x_list = self._axis_starts(W, ps, st)
        # Row-major order: y outer, x inner.
        self.coord_records.extend(
            {
                'orig_name': fname,
                'x': x,
                'y': y,
                'h': ps,
                'w': ps,
                'patch_name': f"{name}_{x}_{y}{ext}",
            }
            for y in y_list
            for x in x_list
        )

    @staticmethod
    def _axis_starts(length, ps, st):
        """Return window origins along one axis of size ``length`` (>= ``ps``).

        The origins are ``0, st, 2*st, ...`` up to ``length - ps``. A final
        origin at ``length - ps`` is appended when the stride does not land
        on it, so the last window ends exactly at the image border.
        """
        last = length - ps
        starts = list(range(0, last + 1, st))
        if starts[-1] != last:
            starts.append(last)
        return starts