@@ -1051,38 +1051,35 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
10511051 return value
10521052
10531053 @staticmethod
1054- @torch .jit .unused
1055- def get_params (brightness , contrast , saturation , hue ):
1056- """Get a randomized transform to be applied on image.
1054+ def get_params (brightness : Optional [List [float ]],
1055+ contrast : Optional [List [float ]],
1056+ saturation : Optional [List [float ]],
1057+ hue : Optional [List [float ]]
1058+ ) -> Tuple [Tensor , Optional [float ], Optional [float ], Optional [float ], Optional [float ]]:
1059+ """Get the parameters for the randomized transform to be applied on image.
10571060
1058- Arguments are same as that of __init__.
1061+ Args:
1062+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
1063+ uniformly. Pass None to turn off the transformation.
1064+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
1065+ uniformly. Pass None to turn off the transformation.
1066+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
1067+ uniformly. Pass None to turn off the transformation.
1068+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
1069+ Pass None to turn off the transformation.
10591070
10601071 Returns:
1061- Transform which randomly adjusts brightness, contrast and
1062- saturation in a random order.
1072+ tuple: The parameters used to apply the randomized transform
1073+ along with their random order.
10631074 """
1064- transforms = []
1065-
1066- if brightness is not None :
1067- brightness_factor = random .uniform (brightness [0 ], brightness [1 ])
1068- transforms .append (Lambda (lambda img : F .adjust_brightness (img , brightness_factor )))
1069-
1070- if contrast is not None :
1071- contrast_factor = random .uniform (contrast [0 ], contrast [1 ])
1072- transforms .append (Lambda (lambda img : F .adjust_contrast (img , contrast_factor )))
1073-
1074- if saturation is not None :
1075- saturation_factor = random .uniform (saturation [0 ], saturation [1 ])
1076- transforms .append (Lambda (lambda img : F .adjust_saturation (img , saturation_factor )))
1077-
1078- if hue is not None :
1079- hue_factor = random .uniform (hue [0 ], hue [1 ])
1080- transforms .append (Lambda (lambda img : F .adjust_hue (img , hue_factor )))
1075+ fn_idx = torch .randperm (4 )
10811076
1082- random .shuffle (transforms )
1083- transform = Compose (transforms )
1077+ b = None if brightness is None else float (torch .empty (1 ).uniform_ (brightness [0 ], brightness [1 ]))
1078+ c = None if contrast is None else float (torch .empty (1 ).uniform_ (contrast [0 ], contrast [1 ]))
1079+ s = None if saturation is None else float (torch .empty (1 ).uniform_ (saturation [0 ], saturation [1 ]))
1080+ h = None if hue is None else float (torch .empty (1 ).uniform_ (hue [0 ], hue [1 ]))
10841081
1085- return transform
1082+ return fn_idx , b , c , s , h
10861083
10871084 def forward (self , img ):
10881085 """
@@ -1092,26 +1089,17 @@ def forward(self, img):
10921089 Returns:
10931090 PIL Image or Tensor: Color jittered image.
10941091 """
1095- fn_idx = torch .randperm (4 )
1092+ fn_idx , brightness_factor , contrast_factor , saturation_factor , hue_factor = \
1093+ self .get_params (self .brightness , self .contrast , self .saturation , self .hue )
1094+
10961095 for fn_id in fn_idx :
1097- if fn_id == 0 and self .brightness is not None :
1098- brightness = self .brightness
1099- brightness_factor = torch .tensor (1.0 ).uniform_ (brightness [0 ], brightness [1 ]).item ()
1096+ if fn_id == 0 and brightness_factor is not None :
11001097 img = F .adjust_brightness (img , brightness_factor )
1101-
1102- if fn_id == 1 and self .contrast is not None :
1103- contrast = self .contrast
1104- contrast_factor = torch .tensor (1.0 ).uniform_ (contrast [0 ], contrast [1 ]).item ()
1098+ elif fn_id == 1 and contrast_factor is not None :
11051099 img = F .adjust_contrast (img , contrast_factor )
1106-
1107- if fn_id == 2 and self .saturation is not None :
1108- saturation = self .saturation
1109- saturation_factor = torch .tensor (1.0 ).uniform_ (saturation [0 ], saturation [1 ]).item ()
1100+ elif fn_id == 2 and saturation_factor is not None :
11101101 img = F .adjust_saturation (img , saturation_factor )
1111-
1112- if fn_id == 3 and self .hue is not None :
1113- hue = self .hue
1114- hue_factor = torch .tensor (1.0 ).uniform_ (hue [0 ], hue [1 ]).item ()
1102+ elif fn_id == 3 and hue_factor is not None :
11151103 img = F .adjust_hue (img , hue_factor )
11161104
11171105 return img
0 commit comments