@@ -5932,3 +5932,86 @@ def test_errors_functional(self):
5932
5932
5933
5933
with pytest .raises (ValueError , match = "bouding_boxes must be a tv_tensors.BoundingBoxes instance or a" ):
5934
5934
F .sanitize_bounding_boxes (good_bbox .tolist ())
5935
+
5936
+
5937
+ class TestJPEG :
5938
+ @pytest .mark .parametrize ("quality" , [5 , 75 ])
5939
+ @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
5940
+ def test_kernel_image (self , quality , color_space ):
5941
+ check_kernel (F .jpeg_image , make_image (color_space = color_space ), quality = quality )
5942
+
5943
+ def test_kernel_video (self ):
5944
+ check_kernel (F .jpeg_video , make_video (), quality = 5 )
5945
+
5946
+ @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image , make_video ])
5947
+ def test_functional (self , make_input ):
5948
+ check_functional (F .jpeg , make_input (), quality = 5 )
5949
+
5950
+ @pytest .mark .parametrize (
5951
+ ("kernel" , "input_type" ),
5952
+ [
5953
+ (F .jpeg_image , torch .Tensor ),
5954
+ (F ._jpeg_image_pil , PIL .Image .Image ),
5955
+ (F .jpeg_image , tv_tensors .Image ),
5956
+ (F .jpeg_video , tv_tensors .Video ),
5957
+ ],
5958
+ )
5959
+ def test_functional_signature (self , kernel , input_type ):
5960
+ check_functional_kernel_signature_match (F .jpeg , kernel = kernel , input_type = input_type )
5961
+
5962
+ @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image , make_video ])
5963
+ @pytest .mark .parametrize ("quality" , [5 , (10 , 20 )])
5964
+ @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
5965
+ def test_transform (self , make_input , quality , color_space ):
5966
+ check_transform (transforms .JPEG (quality = quality ), make_input (color_space = color_space ))
5967
+
5968
+ @pytest .mark .parametrize ("quality" , [5 ])
5969
+ def test_functional_image_correctness (self , quality ):
5970
+ image = make_image ()
5971
+
5972
+ actual = F .jpeg (image , quality = quality )
5973
+ expected = F .to_image (F .jpeg (F .to_pil_image (image ), quality = quality ))
5974
+
5975
+ # NOTE: this will fail if torchvision and Pillow use different JPEG encoder/decoder
5976
+ torch .testing .assert_close (actual , expected , rtol = 0 , atol = 1 )
5977
+
5978
+ @pytest .mark .parametrize ("quality" , [5 , (10 , 20 )])
5979
+ @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
5980
+ @pytest .mark .parametrize ("seed" , list (range (5 )))
5981
+ def test_transform_image_correctness (self , quality , color_space , seed ):
5982
+ image = make_image (color_space = color_space )
5983
+
5984
+ transform = transforms .JPEG (quality = quality )
5985
+
5986
+ with freeze_rng_state ():
5987
+ torch .manual_seed (seed )
5988
+ actual = transform (image )
5989
+
5990
+ torch .manual_seed (seed )
5991
+ expected = F .to_image (transform (F .to_pil_image (image )))
5992
+
5993
+ torch .testing .assert_close (actual , expected , rtol = 0 , atol = 1 )
5994
+
5995
+ @pytest .mark .parametrize ("quality" , [5 , (10 , 20 )])
5996
+ @pytest .mark .parametrize ("seed" , list (range (10 )))
5997
+ def test_transform_get_params_bounds (self , quality , seed ):
5998
+ transform = transforms .JPEG (quality = quality )
5999
+
6000
+ with freeze_rng_state ():
6001
+ torch .manual_seed (seed )
6002
+ params = transform ._get_params ([])
6003
+
6004
+ if isinstance (quality , int ):
6005
+ assert params ["quality" ] == quality
6006
+ else :
6007
+ assert quality [0 ] <= params ["quality" ] <= quality [1 ]
6008
+
6009
+ @pytest .mark .parametrize ("quality" , [[0 ], [0 , 0 , 0 ]])
6010
+ def test_transform_sequence_len_error (self , quality ):
6011
+ with pytest .raises (ValueError , match = "quality should be a sequence of length 2" ):
6012
+ transforms .JPEG (quality = quality )
6013
+
6014
+ @pytest .mark .parametrize ("quality" , [- 1 , 0 , 150 ])
6015
+ def test_transform_invalid_quality_error (self , quality ):
6016
+ with pytest .raises (ValueError , match = "quality must be an integer from 1 to 100" ):
6017
+ transforms .JPEG (quality = quality )
0 commit comments