import os
from collections import defaultdict

import numpy as np
import pandas as pd
import rasterio
from PIL import Image
from tqdm import tqdm

# Disable Pillow's pixel-count (decompression-bomb) limit so very large
# images can be opened without a warning/error.
Image.MAX_IMAGE_PIXELS = None

# Extensions whose size is read with rasterio; everything else goes through Pillow.
_RASTERIO_EXTS = ('.tif', '.tiff')
# Extensions picked up by PatchIndexer.index_all (compared case-insensitively).
_IMAGE_EXTS = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')


def _read_image_hw(path):
    """Return ``(height, width)`` of the image at ``path``.

    TIFF files are opened with rasterio; all other formats with Pillow.
    Any error raised while opening the file propagates to the caller.
    """
    if os.path.splitext(path)[1].lower() in _RASTERIO_EXTS:
        with rasterio.open(path) as src:
            return src.height, src.width
    with Image.open(path) as img:
        width, height = img.size  # Pillow reports (W, H)
        return height, width


def _window_starts(length, size, stride):
    """Return start offsets of ``size``-long windows stepping by ``stride``.

    Windows start at 0, 0 + stride, ... while they fit in ``length``.
    If the last regular window does not end exactly at ``length``, one
    extra window flush with the end (start ``length - size``) is added, so
    the whole extent is covered.
    Requires ``length >= size`` and ``stride > 0``.
    """
    last = length - size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


class PredictionAggregator:
    """Stitch per-patch predictions back into full-size single-channel images.

    Patches are registered with :meth:`add_patch` and written out with
    :meth:`save_all` (keeps the collected patches) or :meth:`save_one_pic`
    (clears the collected patches after writing).
    """

    def __init__(self, img_dir=None):
        """
        Args:
            img_dir: Directory holding the original images. It is used to
                read each image's full (H, W) the first time the image is
                seen in ``add_patch``.
        """
        self.preds = defaultdict(list)  # fname -> list of (x, y, cropped patch)
        self.sizes = {}  # fname -> (H, W) of the original image
        self.img_dir = img_dir

    def _get_image_size(self, fname):
        """Return ``(H, W)`` of ``fname`` inside ``img_dir``.

        The reader is chosen by file extension (rasterio for TIFF, Pillow
        otherwise).
        """
        return _read_image_hw(os.path.join(self.img_dir, fname))

    def add_patch(self, fname, x, y, pred_patch, valid_h, valid_w):
        """Register one predicted patch, cropping away its zero padding.

        Args:
            fname: Original image file name (e.g. ``image_1.png``).
            x, y: Column and row of the patch's top-left corner in the
                full image.
            pred_patch: 2-D array of shape ``[H_pad, W_pad]``.
            valid_h, valid_w: True height and width of the patch before
                zero padding was applied.

        Raises:
            ValueError: if the image size is not cached yet and ``img_dir``
                was not given, so the size cannot be looked up.
        """
        if fname not in self.sizes:
            if self.img_dir is None:
                raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
            self.sizes[fname] = self._get_image_size(fname)
        # Keep only the real (unpadded) region of the prediction.
        patch_crop = pred_patch[:valid_h, :valid_w]
        self.preds[fname].append((x, y, patch_crop))

    def _render(self, fname, patches):
        """Paste ``patches`` onto a zero-filled uint8 canvas of fname's size."""
        height, width = self.sizes[fname]
        canvas = np.zeros((height, width), dtype=np.uint8)
        for x, y, patch in patches:
            h, w = patch.shape
            # Where patches overlap, the one added later wins.
            canvas[y:y + h, x:x + w] = patch
        return canvas

    def _write_all(self, save_dir):
        """Render and save every collected image; return the last saved path.

        Each output file keeps the original file name and extension. Returns
        ``""`` if nothing has been collected.
        """
        os.makedirs(save_dir, exist_ok=True)
        out_save_path = ""
        for fname, patches in self.preds.items():
            save_path = os.path.join(save_dir, f"{fname}")
            Image.fromarray(self._render(fname, patches)).save(save_path)
            out_save_path = save_path
        return out_save_path

    def save_all(self, save_dir):
        """Write all stitched images to ``save_dir`` and keep the patches.

        Returns:
            Path of the last image written, or ``""`` if none were collected.
        """
        return self._write_all(save_dir)

    def save_one_pic(self, save_dir):
        """Write all stitched images to ``save_dir``, then clear the patches.

        The size cache (``self.sizes``) is kept.

        Returns:
            Path of the last image written, or ``""`` if none were collected.
        """
        out_save_path = self._write_all(save_dir)
        self.preds.clear()
        return out_save_path


class PatchIndexer:
    """Enumerate sliding-window patch coordinates for every image in a folder.

    The coordinates are written to a CSV file.
    """

    def __init__(self, img_dir, patch_size=256, stride=256, save_dir=None):
        """
        Args:
            img_dir: Folder containing the images to index.
            patch_size: Side length of the square patches (must be > 0).
            stride: Step between patch origins (must be > 0).
            save_dir: Where the CSV is written. Defaults to the parent
                directory of ``img_dir``.

        Raises:
            ValueError: if ``patch_size`` or ``stride`` is not positive.
        """
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, "
                f"got patch_size={patch_size}, stride={stride}"
            )
        self.img_dir = img_dir
        self.patch_size = patch_size
        self.stride = stride
        # Accumulates across calls; calling index_all twice on one instance
        # appends the coordinates again.
        self.coord_records = []
        if save_dir is None:
            self.save_dir = os.path.dirname(os.path.abspath(img_dir))
        else:
            self.save_dir = save_dir

    def index_all(self, save_name=None):
        """Index every supported image in ``img_dir`` and write the CSV.

        Args:
            save_name: CSV file name inside ``save_dir``
                (default ``coords.csv``).
        """
        filelist = sorted(
            f for f in os.listdir(self.img_dir) if f.lower().endswith(_IMAGE_EXTS)
        )
        print(f"[INFO] Found {len(filelist)} images.")
        for fname in tqdm(filelist, desc='Indexing patches'):
            self.index_one(fname)

        if save_name is None:
            save_name = 'coords.csv'
        save_path = os.path.join(self.save_dir, save_name)
        df = pd.DataFrame(self.coord_records)
        if df.empty:
            print("[⚠️] No patch coordinates were generated!")
        else:
            # A user-supplied save_dir may not exist yet.
            if self.save_dir:
                os.makedirs(self.save_dir, exist_ok=True)
            df.to_csv(save_path, index=False)
            print(f"[✅] Saved {len(df)} patch records to {save_path}")

    def index_one(self, fname):
        """Append patch records for a single image to ``coord_records``.

        Images that cannot be opened, or that are smaller than
        ``patch_size`` in either dimension, are reported and skipped.
        """
        path = os.path.join(self.img_dir, fname)
        name, ext = os.path.splitext(fname)
        ext = ext.lower()
        try:
            H, W = _read_image_hw(path)
        except Exception as e:  # best-effort: log and skip unreadable files
            print(f"[ERROR] Failed to open {fname}: {e}")
            return

        ps, st = self.patch_size, self.stride
        if H < ps or W < ps:
            print(f"[WARNING] Skip {fname}: too small ({W}x{H})")
            return

        y_list = _window_starts(H, ps, st)
        x_list = _window_starts(W, ps, st)
        for y in y_list:
            for x in x_list:
                self.coord_records.append({
                    'orig_name': fname,
                    'x': x,
                    'y': y,
                    'h': ps,
                    'w': ps,
                    'patch_name': f"{name}_{x}_{y}{ext}",
                })