@@ -662,6 +662,8 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
662
662
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
663
663
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
664
664
def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
665
+ if device == "cpu" :
666
+ self .skipTest (f"Temporarily skipping for { device } " )
665
667
if dtype != torch .bfloat16 :
666
668
self .skipTest ("Currently only supports bfloat16." )
667
669
for test_shape in ([(16 , 1024 , 16 )] + ([(1 , 1024 , 8 )] if device == 'cuda' else [])):
@@ -673,6 +675,8 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
673
675
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
674
676
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
675
677
def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
678
+ if device == "cpu" :
679
+ self .skipTest (f"Temporarily skipping for { device } " )
676
680
if dtype != torch .bfloat16 :
677
681
self .skipTest ("Currently only supports bfloat16." )
678
682
m_shapes = [16 , 256 ] + ([1 ] if device == "cuda" else [])
@@ -815,6 +819,8 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
815
819
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
816
820
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
817
821
def test_int4_weight_only_quant_subclass (self , device , dtype ):
822
+ if device == "cpu" :
823
+ self .skipTest (f"Temporarily skipping for { device } " )
818
824
if dtype != torch .bfloat16 :
819
825
self .skipTest (f"Fails for { dtype } " )
820
826
for test_shape in ([(16 , 1024 , 16 )] + ([(1 , 1024 , 8 )] if device == 'cuda' else [])):
@@ -908,6 +914,8 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
908
914
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
909
915
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
910
916
def test_int4_weight_only_quant_subclass_api (self , device , dtype ):
917
+ if device == "cpu" :
918
+ self .skipTest (f"Temporarily skipping for { device } " )
911
919
if dtype != torch .bfloat16 :
912
920
self .skipTest (f"Fails for { dtype } " )
913
921
for test_shape in ([(16 , 1024 , 16 )] + ([(1 , 1024 , 256 )] if device == 'cuda' else [])):
@@ -923,6 +931,8 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
923
931
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
924
932
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
925
933
def test_int4_weight_only_quant_subclass_api_grouped (self , device , dtype ):
934
+ if device == "cpu" :
935
+ self .skipTest (f"Temporarily skipping for { device } " )
926
936
if dtype != torch .bfloat16 :
927
937
self .skipTest (f"Fails for { dtype } " )
928
938
for test_shape in ([(256 , 256 , 16 )] + ([(256 , 256 , 8 )] if device == 'cuda' else [])):
0 commit comments