We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f45d714 · commit 6c7da47 (copy full SHA for 6c7da47)
torchvision/transforms/_presets.py
@@ -113,13 +113,15 @@ def forward(self, vid: Tensor) -> Tensor:
113
if self.crop_size is not None:
114
if self.crop_mode == "center_crop":
115
vid = F.center_crop(vid, self.crop_size)
116
- else:
+ elif vid.shape[-2] >= self.crop_size[-2] and vid.shape[-1] >= self.crop_size[-1]:
117
crops = (
118
list(F.five_crop(vid, self.crop_size))
119
if self.crop_mode == "five_crop"
120
else list(F.ten_crop(vid, self.crop_size))
121
)
122
vid = torch.cat(crops)
123
+ else:
124
+ vid = F.resize(vid, self.crop_size, interpolation=self.interpolation)
125
vid = F.convert_image_dtype(vid, torch.float)
126
vid = F.normalize(vid, mean=self.mean, std=self.std)
127
H, W = vid.shape[-2:]
0 commit comments