diff --git a/torchvision/models/video/s3d.py b/torchvision/models/video/s3d.py index f80d849683c..90861e57191 100644 --- a/torchvision/models/video/s3d.py +++ b/torchvision/models/video/s3d.py @@ -160,6 +160,7 @@ class S3D_Weights(WeightsEnum): resize_size=(256, 256), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + channel_order=(2, 1, 0), # RGB to BGR ), meta={ "min_size": (224, 224), diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 33b94d01c9d..1b41f68d672 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -89,6 +89,7 @@ def __init__( mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645), std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989), interpolation: InterpolationMode = InterpolationMode.BILINEAR, + channel_order: Optional[Tuple[int, int, int]] = None, ) -> None: super().__init__() self.crop_size = list(crop_size) @@ -96,6 +97,7 @@ def __init__( self.mean = list(mean) self.std = list(std) self.interpolation = interpolation + self.channel_order = channel_order def forward(self, vid: Tensor) -> Tensor: need_squeeze = False @@ -109,6 +111,8 @@ def forward(self, vid: Tensor) -> Tensor: vid = F.center_crop(vid, self.crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self.mean, std=self.std) + if self.channel_order is not None: + vid = vid[:, self.channel_order] H, W = self.crop_size vid = vid.view(N, T, C, H, W) vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) @@ -124,17 +128,21 @@ def __repr__(self) -> str: format_string += f"\n mean={self.mean}" format_string += f"\n std={self.std}" format_string += f"\n interpolation={self.interpolation}" + format_string += f"\n channel_order={self.channel_order}" format_string += "\n)" return format_string def describe(self) -> str: - return ( + s = ( "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. " f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " - f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output " - "dimensions are permuted to ``(..., C, T, H, W)`` tensors." + f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. " ) + if self.channel_order is not None: + s += f"Remaps the order within the channels dimension using ``channel_order={self.channel_order}``. " + s += "Finally the output dimensions are permuted to ``(..., C, T, H, W)`` tensors." + return s class SemanticSegmentation(nn.Module):