@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
44
44
img = F .equalize (img )
45
45
elif op_name == "Invert" :
46
46
img = F .invert (img )
47
+ elif op_name == "Identity" :
48
+ pass
47
49
else :
48
50
raise ValueError ("The provided operator {} is not recognized." .format (op_name ))
49
51
return img
@@ -353,6 +355,7 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod
353
355
def _augmentation_space (self , num_bins : int ) -> Dict [str , Tuple [Tensor , bool ]]:
354
356
return {
355
357
# op_name: (magnitudes, signed)
358
+ "Identity" : (torch .tensor (0.0 ), False ),
356
359
"ShearX" : (torch .linspace (0.0 , 0.99 , num_bins ), True ),
357
360
"ShearY" : (torch .linspace (0.0 , 0.99 , num_bins ), True ),
358
361
"TranslateX" : (torch .linspace (0.0 , 32.0 , num_bins ), True ),
@@ -366,7 +369,6 @@ def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
366
369
"Solarize" : (torch .linspace (256.0 , 0.0 , num_bins ), False ),
367
370
"AutoContrast" : (torch .tensor (0.0 ), False ),
368
371
"Equalize" : (torch .tensor (0.0 ), False ),
369
- "Invert" : (torch .tensor (0.0 ), False ),
370
372
}
371
373
372
374
def forward (self , img : Tensor ) -> Tensor :
0 commit comments