209 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pandas as pd
from tqdm import tqdm
import rasterio
from PIL import Image
from collections import defaultdict
import numpy as np
Image.MAX_IMAGE_PIXELS = None  # Disable Pillow's max-pixel-count check (and its warning) so very large images can be opened
class PredictionAggregator:
    """Stitch per-patch predictions back into full-size single-channel images.

    Patches are buffered per source image with ``add_patch`` and written out
    with ``save_all`` (keeps the buffer) or ``save_one_pic`` (clears the
    buffer afterwards). Pixels not covered by any patch stay 0; where patches
    overlap, the one added last wins.
    """

    def __init__(self, img_dir=None):
        """
        Args:
            img_dir: directory containing the original images. It is used to
                look up each image's full (H, W) the first time a patch for
                that image is added.
        """
        self.preds = defaultdict(list)  # fname -> [(x, y, cropped patch), ...]
        self.sizes = dict()  # fname -> (H, W) of the original image
        self.img_dir = img_dir

    def _get_image_size(self, fname):
        """Return ``(H, W)`` of ``img_dir/fname``.

        TIFF files are read with rasterio, everything else with PIL.
        """
        path = os.path.join(self.img_dir, fname)
        ext = os.path.splitext(fname)[1].lower()
        if ext in ('.tif', '.tiff'):
            with rasterio.open(path) as src:
                return src.height, src.width
        with Image.open(path) as img:
            width, height = img.size  # PIL reports (W, H)
        return height, width

    def add_patch(self, fname, x, y, pred_patch, valid_h, valid_w):
        """Buffer one predicted patch, cropping away its zero padding.

        Args:
            fname: original image name (e.g. ``image_1.png``).
            x, y: column / row of the patch's top-left corner in the full image.
            pred_patch: 2-D prediction array of shape (H_pad, W_pad).
            valid_h, valid_w: height / width of the patch before zero padding;
                only ``pred_patch[:valid_h, :valid_w]`` is kept.

        Raises:
            ValueError: if the image size is not yet known and ``img_dir`` is None.
        """
        if fname not in self.sizes:
            if self.img_dir is None:
                raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
            self.sizes[fname] = self._get_image_size(fname)
        self.preds[fname].append((x, y, pred_patch[:valid_h, :valid_w]))

    def _stitch(self, fname):
        """Paste every buffered patch of ``fname`` onto a zeroed (H, W) uint8 canvas."""
        H, W = self.sizes[fname]
        canvas = np.zeros((H, W), dtype=np.uint8)
        # .get() avoids inserting an empty entry into the defaultdict.
        for x, y, patch in self.preds.get(fname, ()):
            h, w = patch.shape
            canvas[y:y + h, x:x + w] = patch
        return canvas

    def _save_stitched(self, save_dir):
        """Write every buffered image into ``save_dir``.

        Returns the path of the last file written, or "" if nothing is buffered.
        """
        os.makedirs(save_dir, exist_ok=True)
        out_save_path = ""
        for fname in self.preds:
            # Keep the original file name, including its extension.
            # NOTE(review): for .jpg/.jpeg originals the mask is JPEG-compressed,
            # which is lossy and may alter label values — confirm this is intended.
            save_path = os.path.join(save_dir, f"{fname}")
            Image.fromarray(self._stitch(fname)).save(save_path)
            out_save_path = save_path
        return out_save_path

    def save_all(self, save_dir):
        """Save all buffered images; the patch buffer is kept.

        Returns the path of the last file written, or "" if nothing is buffered.
        """
        return self._save_stitched(save_dir)

    def save_one_pic(self, save_dir):
        """Save all buffered images, then clear the patch buffer.

        Returns the path of the last file written, or "" if nothing is buffered.
        """
        out_save_path = self._save_stitched(save_dir)
        self.preds.clear()
        return out_save_path
# class PredictionAggregator:
# def __init__(self, img_dir=None):
# """
# img_dir: 原始图像路径(用于自动获取 full_size
# """
# self.preds = defaultdict(list)
# self.sizes = dict()
# self.img_dir = img_dir
# def _get_image_size(self, fname):
# """
# 自动根据图像格式获取 (H, W)
# """
# path = os.path.join(self.img_dir, fname)
# ext = os.path.splitext(fname)[1].lower()
# if ext in ['.tif', '.tiff']:
# with rasterio.open(path) as src:
# return src.height, src.width # rasterio: (H, W)
# else:
# with Image.open(path) as img:
# return img.size[1], img.size[0] # PIL: (W, H) → (H, W)
# def add_patch(self, fname, x, y, pred_patch):
# """
# 添加一个 patch 到对应位置
# fname: 原始图像名(如 image_1.png
# x, y: patch 左上角坐标
# pred_patch: numpy array, shape=[H, W]
# """
# if fname not in self.sizes:
# if self.img_dir is None:
# raise ValueError("img_dir 未指定,无法自动获取图像尺寸")
# self.sizes[fname] = self._get_image_size(fname)
# self.preds[fname].append((x, y, pred_patch))
# def save_all(self, save_dir):
# os.makedirs(save_dir, exist_ok=True)
# for fname, patches in self.preds.items():
# H, W = self.sizes[fname]
# canvas = np.zeros((H, W), dtype=np.uint8)
# for x, y, patch in patches:
# h, w = patch.shape
# canvas[y:y+h, x:x+w] = patch
# save_name = f"pred_{fname}" # 保留原图扩展名
# save_path = os.path.join(save_dir, save_name)
# Image.fromarray(canvas).save(save_path)
class PatchIndexer:
    """Enumerate sliding-window patch coordinates for every image in a folder.

    Windows are ``patch_size`` x ``patch_size`` and advance by ``stride``. On
    each axis, an extra window flush with the far edge is added when the
    stride does not land on it exactly, so the whole image is covered.
    Records accumulate in ``coord_records`` (across calls) and are written to
    CSV by ``index_all``.
    """

    _IMAGE_EXTS = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')
    _TIFF_EXTS = ('.tif', '.tiff')

    def __init__(self, img_dir, patch_size=256, stride=256, save_dir=None):
        """
        Args:
            img_dir: folder containing the images to index.
            patch_size: side length of each square patch, in pixels.
            stride: step between consecutive patch origins, in pixels.
            save_dir: folder for the CSV; defaults to the parent of ``img_dir``.

        Raises:
            ValueError: if ``patch_size`` or ``stride`` is not positive.
        """
        # A zero stride would make range() raise deep inside index_one, and a
        # negative one would silently produce wrong windows.
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got {patch_size} and {stride}"
            )
        self.img_dir = img_dir
        self.patch_size = patch_size
        self.stride = stride
        self.coord_records = []
        self.save_dir = (
            os.path.dirname(os.path.abspath(img_dir)) if save_dir is None else save_dir
        )

    def index_all(self, save_name=None):
        """Index every supported image in ``img_dir`` and save the records as CSV.

        Args:
            save_name: CSV file name inside ``save_dir``; defaults to ``coords.csv``.
        """
        filelist = sorted(
            f for f in os.listdir(self.img_dir) if f.lower().endswith(self._IMAGE_EXTS)
        )
        print(f"[INFO] Found {len(filelist)} images.")
        for fname in tqdm(filelist, desc='Indexing patches'):
            self.index_one(fname)

        if save_name is None:
            save_name = 'coords.csv'
        save_path = os.path.join(self.save_dir, save_name)
        df = pd.DataFrame(self.coord_records)
        if df.empty:
            print("[⚠️] No patch coordinates were generated!")
            return
        # A user-supplied save_dir may not exist yet; create it instead of
        # failing after all images have already been indexed.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        df.to_csv(save_path, index=False)
        print(f"[✅] Saved {len(df)} patch records to {save_path}")

    def index_one(self, fname):
        """Append the patch records for a single image in ``img_dir``.

        Images that cannot be opened, or that are smaller than ``patch_size``
        on either axis, are reported and skipped.
        """
        path = os.path.join(self.img_dir, fname)
        name, ext = os.path.splitext(fname)
        ext = ext.lower()
        try:
            if ext in self._TIFF_EXTS:
                with rasterio.open(path) as src:
                    H, W = src.height, src.width
            else:
                with Image.open(path) as img:
                    W, H = img.size  # PIL reports (W, H)
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole run.
            print(f"[ERROR] Failed to open {fname}: {e}")
            return

        ps, st = self.patch_size, self.stride
        if H < ps or W < ps:
            print(f"[WARNING] Skip {fname}: too small ({W}x{H})")
            return

        x_starts = self._grid_starts(W, ps, st)
        for y in self._grid_starts(H, ps, st):
            for x in x_starts:
                self.coord_records.append({
                    'orig_name': fname,
                    'x': x,
                    'y': y,
                    'h': ps,
                    'w': ps,
                    'patch_name': f"{name}_{x}_{y}{ext}",
                })

    @staticmethod
    def _grid_starts(length, ps, st):
        """Return window start offsets along one axis of size ``length``.

        Offsets start at 0 and step by ``st``. If the last regular window does
        not end exactly at ``length``, one extra window flush with the end is
        appended. Requires ``length >= ps`` and ``st > 0``.
        """
        last = length - ps
        starts = list(range(0, last + 1, st))
        if starts[-1] != last:
            starts.append(last)
        return starts