diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index f08ba7aa72..88e133ccf8 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -2,6 +2,7 @@
 import unittest

 import torch
+from test_utils import skip_if_rocm
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -89,6 +90,7 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)

+    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
         "apply_quant", get_quantization_functions(True, True, "cuda", True)
     )
@@ -168,6 +170,7 @@ def apply_uint6_weight_only_quant(linear):
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)

+    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
@@ -180,6 +183,7 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]

+    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, device, dtype):
diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py
index 8bb39b2cc8..ea30edfe38 100644
--- a/test/dtypes/test_floatx.py
+++ b/test/dtypes/test_floatx.py
@@ -2,6 +2,7 @@
 import unittest

 import torch
+from test_utils import skip_if_rocm
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
@@ -108,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
+    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 3e894c02b9..c20920fb9f 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -24,6 +24,8 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

+from test_utils import skip_if_rocm
+
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -423,6 +425,7 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skip_if_rocm("ROCm development in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 381886d594..4c85ee2c30 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -1,6 +1,7 @@
 import unittest

 import torch
+from test_utils import skip_if_rocm

 from torchao.quantization import (
     MappingType,
@@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )

+    @skip_if_rocm("ROCm development in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 3d51ed048f..53b2d3be22 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -90,6 +90,8 @@
 except ModuleNotFoundError:
     has_gemlite = False

+from test_utils import skip_if_rocm
+
 logger = logging.getLogger("INFO")

 torch.manual_seed(0)
@@ -566,6 +568,7 @@ def test_per_token_linear_cpu(self):
             self._test_per_token_linear_impl("cpu", dtype)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -684,6 +687,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -703,6 +707,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -896,6 +901,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -915,6 +921,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py
index bab65fc2fb..d7f8102f9f 100644
--- a/test/kernel/test_galore_downproj.py
+++ b/test/kernel/test_galore_downproj.py
@@ -8,6 +8,7 @@
 import torch
 from galore_test_utils import make_data
+from test_utils import skip_if_rocm

 from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
 from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
@@ -29,6 +30,7 @@
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
+@skip_if_rocm("ROCm development in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index 1b91983bc0..3843d0e0cd 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -10,6 +10,8 @@
 if TORCH_VERSION_AT_LEAST_2_3:
     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_

+from test_utils import skip_if_rocm
+

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -113,6 +115,7 @@ def test_awq_loading(device, qdtype):

 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_rocm("ROCm development in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index acc7576e56..8f5dccdac5 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -42,6 +42,7 @@
 except ImportError:
     lpmm = None

+from test_utils import skip_if_rocm

 _DEVICES = get_available_devices()
@@ -112,6 +113,7 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
+    @skip_if_rocm("ROCm development in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
             if not TORCH_VERSION_AT_LEAST_2_4:
diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py
index 48793ba907..cd90408644 100644
--- a/test/prototype/test_splitk.py
+++ b/test/prototype/test_splitk.py
@@ -13,6 +13,8 @@
 except ImportError:
     triton_available = False

+from test_utils import skip_if_rocm
+
 from torchao.utils import skip_if_compute_capability_less_than
@@ -20,6 +22,7 @@
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
+    @skip_if_rocm("ROCm development in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn
diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py
index 3eb9b0a2c5..47020d6b26 100644
--- a/test/quantization/test_galore_quant.py
+++ b/test/quantization/test_galore_quant.py
@@ -13,6 +13,7 @@
     dequantize_blockwise,
     quantize_blockwise,
 )
+from test_utils import skip_if_rocm

 from torchao.prototype.galore.kernels import (
     triton_dequant_blockwise,
@@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
+@skip_if_rocm("ROCm development in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py
index ebdf2281e0..c21922b631 100644
--- a/test/quantization/test_marlin_qqq.py
+++ b/test/quantization/test_marlin_qqq.py
@@ -3,6 +3,7 @@
 import pytest
 import torch
+from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
@@ -45,6 +46,7 @@ def setUp(self):
         )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq(self):
         output_ref = self.model(self.input)
         for group_size in [-1, 128]:
@@ -66,6 +68,7 @@ def test_marlin_qqq(self):

     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq_compile(self):
         model_copy = copy.deepcopy(self.model)
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py
index 4da7304a24..a78940656b 100644
--- a/test/sparsity/test_marlin.py
+++ b/test/sparsity/test_marlin.py
@@ -2,6 +2,7 @@
 import pytest
 import torch
+from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
@@ -37,6 +38,7 @@ def setUp(self):
         )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -48,13 +50,13 @@ def test_quant_sparse_marlin_layout_eager(self):
         # Sparse + quantized
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
-
         assert torch.allclose(
             dense_result, sparse_result, atol=3e-1
         ), "Results are not close"

     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
diff --git a/test/test_ops.py b/test/test_ops.py
index 26671ddf40..5a60a50e00 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -19,6 +19,9 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we don't have TARGET file for kernels"
diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py
index 6510adaea3..93f842b2d8 100644
--- a/test/test_s8s4_linear_cutlass.py
+++ b/test/test_s8s4_linear_cutlass.py
@@ -7,6 +7,9 @@
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 from torchao.utils import compute_max_diff

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
 S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
 S8S4_LINEAR_CUTLASS_SIZE_MNK = [
diff --git a/test/test_utils.py b/test/test_utils.py
index 77a8b39aae..d4bcb7ffe0 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,11 +1,40 @@
+import functools
 import unittest
 from unittest.mock import patch

+import pytest
 import torch

 from torchao.utils import TorchAOBaseTensor, torch_version_at_least


+def skip_if_rocm(message=None):
+    """Decorator to skip tests on ROCm platform with custom message.
+
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if torch.version.hip is not None: + skip_message = "Skipping the test in ROCm" + if message: + skip_message += f": {message}" + pytest.skip(skip_message) + return func(*args, **kwargs) + + return wrapper + + # Handle both @skip_if_rocm and @skip_if_rocm() syntax + if callable(message): + func = message + message = None + return decorator(func) + return decorator + + class TestTorchVersionAtLeast(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [