52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
import albumentations as A
|
|
from albumentations.pytorch import ToTensorV2
|
|
|
|
def get_train_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
|
|
"""获取训练数据增强和预处理
|
|
|
|
Args:
|
|
mean: 图像均值
|
|
std: 图像标准差
|
|
|
|
Returns:
|
|
Compose: Albumentations变换组合
|
|
"""
|
|
return A.Compose([
|
|
A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0)),
|
|
A.HorizontalFlip(p=0.5),
|
|
A.VerticalFlip(p=0.5),
|
|
A.RandomRotate90(p=0.5),
|
|
A.RandomBrightnessContrast(p=0.2),
|
|
A.Normalize(mean=mean, std=std),
|
|
ToTensorV2(),
|
|
])
|
|
|
|
def get_val_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
|
|
"""获取验证/测试数据预处理
|
|
|
|
Args:
|
|
mean: 图像均值
|
|
std: 图像标准差
|
|
|
|
Returns:
|
|
Compose: Albumentations变换组合
|
|
"""
|
|
return A.Compose([
|
|
A.Normalize(mean=mean, std=std),
|
|
ToTensorV2(),
|
|
])
|
|
|
|
def get_predict_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
|
|
"""获取预测数据预处理
|
|
|
|
Args:
|
|
mean: 图像均值
|
|
std: 图像标准差
|
|
|
|
Returns:
|
|
Compose: Albumentations变换组合
|
|
"""
|
|
return A.Compose([
|
|
A.Normalize(mean=mean, std=std),
|
|
ToTensorV2(),
|
|
]) |