@@ -155,12 +155,13 @@ class FiveCrop(Transform):
155155 """
156156 Example:
157157 >>> class BatchMultiCrop(transforms.Transform):
158- ... def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
159- ... images, labels = sample
160- ... batch_size = len(images)
161- ... images = features.Image.wrap_like(images[0], torch.stack(images))
158+ ... def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
159+ ... images_or_videos, labels = sample
160+ ... batch_size = len(images_or_videos)
161+ ... image_or_video = images_or_videos[0]
162+ ... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
162163 ... labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
163-  ...         return images, labels
164+  ...         return images_or_videos, labels
164165 ...
165166 >>> image = features.Image(torch.rand(3, 256, 256))
166167 >>> label = features.Label(0)
@@ -172,15 +173,21 @@ class FiveCrop(Transform):
172173 torch.Size([5])
173174 """
174175
175-      _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
176+      _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
176177
177178     def __init__(self, size: Union[int, Sequence[int]]) -> None:
178179         super().__init__()
179180         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
180181
181182     def _transform(
182-          self, inpt: features.ImageType, params: Dict[str, Any]
183-      ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
183+          self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
184+      ) -> Tuple[
185+          features.ImageOrVideoType,
186+          features.ImageOrVideoType,
187+          features.ImageOrVideoType,
188+          features.ImageOrVideoType,
189+          features.ImageOrVideoType,
190+      ]:
184191         return F.five_crop(inpt, self.size)
185192
186193     def forward(self, *inputs: Any) -> Any:
@@ -194,14 +201,14 @@ class TenCrop(Transform):
194201 See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
195202 """
196203
197-      _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
204+      _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
198205
199206     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
200207         super().__init__()
201208         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
202209         self.vertical_flip = vertical_flip
203210
204-      def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
211+      def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
205212         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
206213
207214     def forward(self, *inputs: Any) -> Any:
0 commit comments