@@ -354,24 +354,34 @@ def __init__(
354
354
v2_transforms .ElasticTransform ,
355
355
legacy_transforms .ElasticTransform ,
356
356
[
357
- ArgsKwargs (),
358
357
ArgsKwargs (alpha = 20.0 ),
359
358
ArgsKwargs (alpha = (15.3 , 27.2 )),
360
359
ArgsKwargs (sigma = 3.0 ),
361
360
ArgsKwargs (sigma = (2.5 , 3.9 )),
362
361
ArgsKwargs (interpolation = v2_transforms .InterpolationMode .NEAREST ),
363
- ArgsKwargs (interpolation = v2_transforms .InterpolationMode .BICUBIC ),
364
362
ArgsKwargs (interpolation = PIL .Image .NEAREST ),
365
- ArgsKwargs (interpolation = PIL .Image .BICUBIC ),
366
363
ArgsKwargs (fill = 1 ),
364
+ * extra_args_kwargs ,
367
365
],
368
366
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
369
367
make_images_kwargs = dict (DEFAULT_MAKE_IMAGES_KWARGS , sizes = [(163 , 163 ), (72 , 333 ), (313 , 95 )], dtypes = [dt ]),
370
368
# We updated gaussian blur kernel generation with a faster and numerically more stable version
371
369
# This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
372
370
closeness_kwargs = ckw ,
373
371
)
374
- for dt , ckw in [(torch .uint8 , {"rtol" : 1e-1 , "atol" : 1 }), (torch .float32 , {"rtol" : 1e-2 , "atol" : 1e-3 })]
372
+ for dt , ckw , extra_args_kwargs in [
373
+ (torch .uint8 , {"rtol" : 1e-1 , "atol" : 1 }, []),
374
+ (
375
+ torch .float32 ,
376
+ {"rtol" : 1e-2 , "atol" : 1e-3 },
377
+ [
378
+ # These proved to be flaky on uint8 inputs so we only run them on float32
379
+ ArgsKwargs (),
380
+ ArgsKwargs (interpolation = v2_transforms .InterpolationMode .BICUBIC ),
381
+ ArgsKwargs (interpolation = PIL .Image .BICUBIC ),
382
+ ],
383
+ ),
384
+ ]
375
385
],
376
386
ConsistencyConfig (
377
387
v2_transforms .GaussianBlur ,
0 commit comments