ai_project_v1/map/uav_seg/data/transforms.py

52 lines
1.3 KiB
Python
Raw Normal View History

import albumentations as A
from albumentations.pytorch import ToTensorV2
def get_train_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
"""获取训练数据增强和预处理
Args:
mean: 图像均值
std: 图像标准差
Returns:
Compose: Albumentations变换组合
"""
return A.Compose([
A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
A.RandomBrightnessContrast(p=0.2),
A.Normalize(mean=mean, std=std),
ToTensorV2(),
])
def get_val_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
"""获取验证/测试数据预处理
Args:
mean: 图像均值
std: 图像标准差
Returns:
Compose: Albumentations变换组合
"""
return A.Compose([
A.Normalize(mean=mean, std=std),
ToTensorV2(),
])
def get_predict_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
"""获取预测数据预处理
Args:
mean: 图像均值
std: 图像标准差
Returns:
Compose: Albumentations变换组合
"""
return A.Compose([
A.Normalize(mean=mean, std=std),
ToTensorV2(),
])