diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 9e9144c601..43d57b7d12 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -17,9 +17,11 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
-
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
+)
 
 
 def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
@@ -42,7 +44,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )
 
-    if is_cuda_8_9:
+    if is_sm_at_least_89():
         base_functions.append(float8_weight_only())
 
     return base_functions
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 74c130dc5e..4d8312b427 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -37,13 +37,14 @@
     MappingType,
     choose_qparams_affine,
 )
+from torchao.utils import (
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)
 
 random.seed(0)
 torch.manual_seed(0)
 
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
@@ -59,12 +60,14 @@ def forward(self, x):
 
 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
-        "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
     )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
@@ -134,12 +137,16 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"
 
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")
 
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
@@ -147,7 +154,9 @@ def test_mismatched_granularity(self):
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
 
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
@@ -158,7 +167,9 @@ class UnsupportedGranularity:
             )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_per_row_with_float32(self):
         with pytest.raises(
             AssertionError,
@@ -170,7 +181,9 @@ def test_per_row_with_float32(self):
             )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
@@ -240,7 +253,9 @@ def test_serialization(self, mode: str):
                 ), f"Scales do not match for {layer_name}"
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
         model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 245abe0d02..f61ff3738f 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -14,7 +14,11 @@
 import torch
 import torch.nn as nn
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -60,10 +64,6 @@
 torch.manual_seed(0)
 
 
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
-
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -219,7 +219,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -333,7 +333,9 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input",
@@ -415,7 +417,9 @@ def test_linear_from_recipe(
             config,
         )
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -462,7 +466,9 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -523,7 +529,7 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s
 
-    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -531,7 +537,7 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
-    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
     def test_quantize(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -549,7 +555,7 @@
 
 class TestScaledMM:
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -594,7 +600,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
             atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
-    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -630,7 +636,7 @@ def test_different_configs_error(self):
             a @ b
 
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index ced5db7ff3..6d21686e32 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -11,7 +11,11 @@
 
 import pytest
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -46,10 +50,6 @@
 from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
-# TODO(future PR): standardize IS_H100 with the rest of the codebase
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-
 
 def _test_compile_base(
     backend: str,
@@ -99,7 +99,7 @@ def _test_compile_base(
     "scaling_type_grad_output",
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -126,7 +126,7 @@ def test_eager_only(
 
 
 @pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
@@ -177,7 +177,7 @@ def test_aot_eager(
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -215,7 +215,9 @@ def test_inductor_from_config_params(
         Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
     ],
 )
-@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(
+    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+)
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
     config = recipe_name_to_linear_config(recipe_name)
@@ -253,7 +255,7 @@ def forward(self, x):
 
     # TODO(future): figure out why the test below fails on CUDA capability 8.9
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
+        not torch.cuda.is_available() or not is_sm_at_least_90(),
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
@@ -269,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         torch.testing.assert_close(y_eager, y_compiled)
 
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
@@ -293,7 +295,7 @@ def to_float(x):
         torch.testing.assert_close(y2_eager, y2_compiled)
 
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_output(self):
@@ -323,7 +325,7 @@ def test_float8_graph_output(self):
 
 
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func():
@@ -364,7 +366,7 @@ def __exit__(self, *args):
 
 
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func_cuda_graph_success():
@@ -396,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success():
 
 
 @unittest.skipIf(
-    not is_cuda_8_9,
+    not is_sm_at_least_89(),
     "CUDA not available",
 )
 @pytest.mark.parametrize(
diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py
index c3e31816ad..fbe5c9b508 100644
--- a/test/float8/test_fsdp2/test_fsdp2.py
+++ b/test/float8/test_fsdp2/test_fsdp2.py
@@ -6,7 +6,7 @@
 
 import pytest
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -40,8 +40,7 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
 
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
 
 
diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py
index d5c0d7b853..d2e9a51c7f 100644
--- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py
+++ b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -30,8 +30,7 @@
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only
 
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
 
 
diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py
index e9028c8712..311964d831 100644
--- a/test/float8/test_numerics_integration.py
+++ b/test/float8/test_numerics_integration.py
@@ -11,7 +11,11 @@
 
 import pytest
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -34,9 +38,6 @@
 from torchao.float8.float8_utils import IS_ROCM, compute_error
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
 torch.manual_seed(0)
 
 
@@ -176,7 +177,9 @@ def _test_impl(self, config: Float8LinearConfig) -> None:
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
-    @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
+    @pytest.mark.skipif(
+        not is_sm_at_least_89(), reason="requires SM89 compatible machine"
+    )
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_config_params(
         self,
@@ -199,7 +202,9 @@ def test_encoder_fw_bw_from_config_params(
             Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
         ],
     )
-    @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine")
+    @pytest.mark.skipif(
+        not is_sm_at_least_90(), reason="requires SM90 compatible machine"
+    )
    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_recipe(
         self,
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index df20c5f03b..10f2d157f9 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -91,7 +91,8 @@
     TORCH_VERSION_AT_LEAST_2_6,
     unwrap_tensor_subclass,
     is_fbcode,
-    benchmark_model
+    benchmark_model,
+    is_sm_at_least_90,
 )
 
 from torchao.dtypes.utils import is_device
@@ -105,7 +106,6 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -779,7 +779,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
-    @unittest.skipIf(not is_H100, "Need H100 to run")
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
     def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
@@ -799,7 +799,7 @@ def test_autoquantizable_flatten_unflatten(self):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
-    @unittest.skipIf(not is_H100, "Need H100 to run")
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
     def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"):
@@ -813,7 +813,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
-    @unittest.skipIf(not is_H100, "Need H100 to run")
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
     def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py
index 4ed0974172..3e8c9b0a04 100644
--- a/test/kernel/test_autotuner.py
+++ b/test/kernel/test_autotuner.py
@@ -13,10 +13,10 @@
 import pytest
 import torch
 from parameterized import parameterized
+from torchao.utils import is_sm_at_least_90
 
 logging.basicConfig(level=logging.INFO)
 
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
 class TestQuantFlow(unittest.TestCase):
@@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype):
             ("cuda", torch.float16),
         ]
     )
-    @unittest.skipIf(not is_H100, "Needs H100")
+    @unittest.skipIf(not is_sm_at_least_90(), "Needs H100")
     def test_int_mm_float8(self, device, dtype):
         from torchao.kernel import intmm
 
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index bc9b02deb5..4cac940313 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -20,11 +20,8 @@
 )
 
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
 
-# trying to outsmart flake8
-__has_cuda = torch.cuda.is_available()
-IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9)
 
 
 torch.manual_seed(2)
@@ -102,7 +99,7 @@ def test_linear_compile(elem_dtype, bias):
     Verify that compile does not change numerics of MX linear fw + bw
     """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not IS_CUDA_GE_89:
+        if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
     input_shape = (2, 4)
     grad_shape = (2, 6)
@@ -173,7 +170,7 @@ def test_inference_compile_simple(elem_dtype):
     Smoke test for inference compile
    """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not IS_CUDA_GE_89:
+        if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
     m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 964a575411..522785ae6f 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -24,11 +24,8 @@
 )
 
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
 
-# trying to outsmart flake8
-__has_cuda = torch.cuda.is_available()
-IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9)
 
 
 torch.manual_seed(2)
@@ -225,7 +222,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     Verifies that compile does not change numerics of MX casts
     """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not IS_CUDA_GE_89:
+        if not is_sm_at_least_89():
             # separate ifs because flake8 is outsmarting me
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 60a7341e39..96ccb1889c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -53,8 +53,8 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_MI300,
-    is_sm_89,
-    is_sm_90,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -864,11 +864,11 @@ def _normalize_granularity(
     for _granularity in processed_granularity:
         if isinstance(_granularity, PerTensor):
             assert (
-                is_sm_89() or is_MI300()
+                is_sm_at_least_89() or is_MI300()
             ), "PerTensor quantization only works for CUDA>=8.9 and MI300+"
         elif isinstance(_granularity, PerRow):
             assert (
-                is_sm_90() or is_MI300()
+                is_sm_at_least_90() or is_MI300()
             ), "PerRow quantization only works for CUDA>=9.0 and MI300+"
         else:
             raise ValueError(f"Invalid granularity type: {_granularity}")
@@ -966,7 +966,7 @@ def float8_dynamic_activation_float8_weight(
 
     """
     assert (
-        is_sm_89() or is_MI300()
+        is_sm_at_least_89() or is_MI300()
     ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
@@ -1023,7 +1023,7 @@ def float8_static_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
     assert (
-        is_sm_89() or is_MI300()
+        is_sm_at_least_89() or is_MI300()
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
diff --git a/torchao/utils.py b/torchao/utils.py
index ba91fb3fe0..d56191ed6b 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -33,8 +33,8 @@
     "TORCH_VERSION_AFTER_2_4",
     "TORCH_VERSION_AFTER_2_5",
     "is_MI300",
-    "is_sm_89",
-    "is_sm_90",
+    "is_sm_at_least_89",
+    "is_sm_at_least_90",
 ]
 
 
@@ -612,7 +612,7 @@ def is_MI300():
         return False
 
 
-def is_sm_89():
+def is_sm_at_least_89():
     return (
         torch.cuda.is_available()
         and torch.version.cuda
@@ -620,7 +620,7 @@ def is_sm_89():
     )
 
 
-def is_sm_90():
+def is_sm_at_least_90():
     return (
         torch.cuda.is_available()
         and torch.version.cuda