11import dataclasses
2- from typing import Callable , Dict , Type
2+ from typing import Callable , Dict , Sequence , Type
33
44import pytest
55import torchvision .prototype .transforms .functional as F
6- from prototype_transforms_kernel_infos import KERNEL_INFOS
6+ from prototype_transforms_kernel_infos import KERNEL_INFOS , Skip
77from torchvision .prototype import features
88
99__all__ = ["DispatcherInfo" , "DISPATCHER_INFOS" ]
1515class DispatcherInfo :
1616 dispatcher : Callable
1717 kernels : Dict [Type , Callable ]
18+ skips : Sequence [Skip ] = dataclasses .field (default_factory = list )
19+ _skips_map : Dict [str , Skip ] = dataclasses .field (default = None , init = False )
20+
21+ def __post_init__ (self ):
22+ self ._skips_map = {skip .test_name : skip for skip in self .skips }
1823
1924 def sample_inputs (self , * types ):
2025 for type in types or self .kernels .keys ():
@@ -23,6 +28,11 @@ def sample_inputs(self, *types):
2328
2429 yield from KERNEL_SAMPLE_INPUTS_FN_MAP [self .kernels [type ]]()
2530
31+ def maybe_skip (self , * , test_name , args_kwargs , device ):
32+ skip = self ._skips_map .get (test_name )
33+ if skip and skip .condition (args_kwargs , device ):
34+ pytest .skip (skip .reason )
35+
2636
2737DISPATCHER_INFOS = [
2838 DispatcherInfo (
@@ -97,6 +107,14 @@ def sample_inputs(self, *types):
97107 features .Mask : F .perspective_mask ,
98108 },
99109 ),
110+ DispatcherInfo (
111+ F .elastic ,
112+ kernels = {
113+ features .Image : F .elastic_image_tensor ,
114+ features .BoundingBox : F .elastic_bounding_box ,
115+ features .Mask : F .elastic_mask ,
116+ },
117+ ),
100118 DispatcherInfo (
101119 F .center_crop ,
102120 kernels = {
@@ -153,4 +171,66 @@ def sample_inputs(self, *types):
153171 features .Image : F .erase_image_tensor ,
154172 },
155173 ),
174+ DispatcherInfo (
175+ F .adjust_brightness ,
176+ kernels = {
177+ features .Image : F .adjust_brightness_image_tensor ,
178+ },
179+ ),
180+ DispatcherInfo (
181+ F .adjust_contrast ,
182+ kernels = {
183+ features .Image : F .adjust_contrast_image_tensor ,
184+ },
185+ ),
186+ DispatcherInfo (
187+ F .adjust_gamma ,
188+ kernels = {
189+ features .Image : F .adjust_gamma_image_tensor ,
190+ },
191+ ),
192+ DispatcherInfo (
193+ F .adjust_hue ,
194+ kernels = {
195+ features .Image : F .adjust_hue_image_tensor ,
196+ },
197+ ),
198+ DispatcherInfo (
199+ F .adjust_saturation ,
200+ kernels = {
201+ features .Image : F .adjust_saturation_image_tensor ,
202+ },
203+ ),
204+ DispatcherInfo (
205+ F .five_crop ,
206+ kernels = {
207+ features .Image : F .five_crop_image_tensor ,
208+ },
209+ skips = [
210+ Skip (
211+ "test_scripted_smoke" ,
212+ condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs ["size" ], int ),
213+ reason = "Integer size is not supported when scripting five_crop_image_tensor." ,
214+ ),
215+ ],
216+ ),
217+ DispatcherInfo (
218+ F .ten_crop ,
219+ kernels = {
220+ features .Image : F .ten_crop_image_tensor ,
221+ },
222+ skips = [
223+ Skip (
224+ "test_scripted_smoke" ,
225+ condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs ["size" ], int ),
226+ reason = "Integer size is not supported when scripting ten_crop_image_tensor." ,
227+ ),
228+ ],
229+ ),
230+ DispatcherInfo (
231+ F .normalize ,
232+ kernels = {
233+ features .Image : F .normalize_image_tensor ,
234+ },
235+ ),
156236]
0 commit comments