import albumentations as A from albumentations.pytorch import ToTensorV2 def get_train_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): """获取训练数据增强和预处理 Args: mean: 图像均值 std: 图像标准差 Returns: Compose: Albumentations变换组合 """ return A.Compose([ # A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0)), A.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0)), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5), A.RandomBrightnessContrast(p=0.2), A.Normalize(mean=mean, std=std), ToTensorV2(), ]) def get_val_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): """获取验证/测试数据预处理 Args: mean: 图像均值 std: 图像标准差 Returns: Compose: Albumentations变换组合 """ return A.Compose([ A.Normalize(mean=mean, std=std), ToTensorV2(), ]) def get_predict_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): """获取预测数据预处理 Args: mean: 图像均值 std: 图像标准差 Returns: Compose: Albumentations变换组合 """ return A.Compose([ A.Normalize(mean=mean, std=std), ToTensorV2(), ])