Skip Unit Tests for ROCm CI #1563

Merged
merged 2 commits into from Jan 17, 2025
Empty file added test/__init__.py
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -2,6 +2,7 @@
import unittest

import torch
from test_utils import skip_if_rocm
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
@@ -89,6 +90,7 @@ def test_tensor_core_layout_transpose(self):
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

@skip_if_rocm("ROCm development in progress")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"apply_quant", get_quantization_functions(True, True, "cuda", True)
@@ -168,6 +170,7 @@ def apply_uint6_weight_only_quant(linear):

deregister_aqt_quantized_linear_dispatch(dispatch_condition)

@skip_if_rocm("ROCm development in progress")
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
@@ -180,6 +183,7 @@ class TestAffineQuantizedBasic(TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DTYPES = [torch.bfloat16]

@skip_if_rocm("ROCm development in progress")
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_flatten_unflatten(self, device, dtype):
2 changes: 2 additions & 0 deletions test/dtypes/test_floatx.py
@@ -2,6 +2,7 @@
import unittest

import torch
from test_utils import skip_if_rocm
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
@@ -108,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("bias", [False, True])
@parametrize("dtype", [torch.half, torch.bfloat16])
@skip_if_rocm("ROCm development in progress")
@unittest.skipIf(is_fbcode(), reason="broken in fbcode")
def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
N, OC, IC = 4, 256, 64
3 changes: 3 additions & 0 deletions test/float8/test_base.py
@@ -24,6 +24,8 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from test_utils import skip_if_rocm

from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
@@ -423,6 +425,7 @@ def test_linear_from_config_params(
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_bias", [True, False])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skip_if_rocm("ROCm development in progress")
def test_linear_from_recipe(
self,
recipe_name,
2 changes: 2 additions & 0 deletions test/hqq/test_hqq_affine.py
@@ -1,6 +1,7 @@
import unittest

import torch
from test_utils import skip_if_rocm

from torchao.quantization import (
MappingType,
@@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self):
ref_dot_product_error=0.000704,
)

@skip_if_rocm("ROCm development in progress")
def test_hqq_plain_4bit(self):
self._test_hqq(
dtype=torch.uint4,
7 changes: 7 additions & 0 deletions test/integration/test_integration.py
@@ -90,6 +90,8 @@
except ModuleNotFoundError:
has_gemlite = False

from test_utils import skip_if_rocm

logger = logging.getLogger("INFO")

torch.manual_seed(0)
@@ -566,6 +568,7 @@ def test_per_token_linear_cpu(self):
self._test_per_token_linear_impl("cpu", dtype)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm development in progress")
def test_per_token_linear_cuda(self):
for dtype in (torch.float32, torch.float16, torch.bfloat16):
self._test_per_token_linear_impl("cuda", dtype)
@@ -684,6 +687,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm development in progress")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -703,6 +707,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm development in progress")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -896,6 +901,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm development in progress")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -915,6 +921,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm development in progress")

Review comment on this line: "ROCm enablement in progress" would be a better comment.

def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
2 changes: 2 additions & 0 deletions test/kernel/test_galore_downproj.py
@@ -8,6 +8,7 @@

import torch
from galore_test_utils import make_data
from test_utils import skip_if_rocm

from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
@@ -29,6 +30,7 @@

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
@skip_if_rocm("ROCm development in progress")
def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
3 changes: 3 additions & 0 deletions test/prototype/test_awq.py
@@ -10,6 +10,8 @@
if TORCH_VERSION_AT_LEAST_2_3:
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_

from test_utils import skip_if_rocm


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
@@ -113,6 +115,7 @@ def test_awq_loading(device, qdtype):

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_rocm("ROCm development in progress")
def test_save_weights_only():
dataset_size = 100
l1, l2, l3 = 512, 256, 128
2 changes: 2 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -42,6 +42,7 @@
except ImportError:
lpmm = None

from test_utils import skip_if_rocm

_DEVICES = get_available_devices()

@@ -112,6 +113,7 @@ class TestOptim(TestCase):
)
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
@skip_if_rocm("ROCm development in progress")
def test_optim_smoke(self, optim_name, dtype, device):
if optim_name.endswith("Fp8") and device == "cuda":
if not TORCH_VERSION_AT_LEAST_2_4:
3 changes: 3 additions & 0 deletions test/prototype/test_splitk.py
@@ -13,13 +13,16 @@
except ImportError:
triton_available = False

from test_utils import skip_if_rocm

from torchao.utils import skip_if_compute_capability_less_than


@unittest.skipIf(not triton_available, "Triton is required but not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestFP8Gemm(TestCase):
@skip_if_compute_capability_less_than(9.0)
@skip_if_rocm("ROCm development in progress")
def test_gemm_split_k(self):
dtype = torch.float16
qdtype = torch.float8_e4m3fn
2 changes: 2 additions & 0 deletions test/quantization/test_galore_quant.py
@@ -13,6 +13,7 @@
dequantize_blockwise,
quantize_blockwise,
)
from test_utils import skip_if_rocm

from torchao.prototype.galore.kernels import (
triton_dequant_blockwise,
@@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
"dim1,dim2,dtype,signed,blocksize",
TEST_CONFIGS,
)
@skip_if_rocm("ROCm development in progress")
def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

3 changes: 3 additions & 0 deletions test/quantization/test_marlin_qqq.py
@@ -3,6 +3,7 @@

import pytest
import torch
from test_utils import skip_if_rocm
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

@@ -45,6 +46,7 @@ def setUp(self):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@skip_if_rocm("ROCm development in progress")
def test_marlin_qqq(self):
output_ref = self.model(self.input)
for group_size in [-1, 128]:
@@ -66,6 +68,7 @@ def test_marlin_qqq(self):

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@skip_if_rocm("ROCm development in progress")
def test_marlin_qqq_compile(self):
model_copy = copy.deepcopy(self.model)
model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
4 changes: 3 additions & 1 deletion test/sparsity/test_marlin.py
@@ -2,6 +2,7 @@

import pytest
import torch
from test_utils import skip_if_rocm
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

@@ -37,6 +38,7 @@ def setUp(self):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@skip_if_rocm("ROCm development in progress")
def test_quant_sparse_marlin_layout_eager(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)
@@ -48,13 +50,13 @@ def test_quant_sparse_marlin_layout_eager(self):
# Sparse + quantized
quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
sparse_result = self.model(self.input)

assert torch.allclose(
dense_result, sparse_result, atol=3e-1
), "Results are not close"

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@skip_if_rocm("ROCm development in progress")
def test_quant_sparse_marlin_layout_compile(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)
3 changes: 3 additions & 0 deletions test/test_ops.py
@@ -19,6 +19,9 @@
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)

if is_fbcode():
pytest.skip(
"Skipping the test in fbcode since we don't have TARGET file for kernels"
3 changes: 3 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -7,6 +7,9 @@
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import compute_max_diff

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)

S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
S8S4_LINEAR_CUTLASS_SIZE_MNK = [
29 changes: 29 additions & 0 deletions test/test_utils.py
@@ -1,11 +1,40 @@
import functools
import unittest
from unittest.mock import patch

import pytest
import torch

from torchao.utils import TorchAOBaseTensor, torch_version_at_least


def skip_if_rocm(message=None):
"""Decorator to skip tests on ROCm platform with custom message.

Args:
message (str, optional): Additional information about why the test is skipped.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if torch.version.hip is not None:
skip_message = "Skipping the test in ROCm"
if message:
skip_message += f": {message}"
pytest.skip(skip_message)
return func(*args, **kwargs)

return wrapper

# Handle both @skip_if_rocm and @skip_if_rocm() syntax
if callable(message):
func = message
message = None
return decorator(func)
return decorator


class TestTorchVersionAtLeast(unittest.TestCase):
def test_torch_version_at_least(self):
test_cases = [
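
For reference, a minimal usage sketch of the skip_if_rocm helper added in test/test_utils.py, assuming the suite is collected and run with pytest (the decorator calls pytest.skip internally, as shown above). The ExampleCudaTests class and its test bodies below are hypothetical, not part of this PR; they only illustrate the message-carrying and bare forms of the decorator.

# Hypothetical test module showing how skip_if_rocm is applied.
import unittest

import torch
from test_utils import skip_if_rocm


class ExampleCudaTests(unittest.TestCase):
    @skip_if_rocm("ROCm development in progress")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_cuda_only_op(self):
        # On ROCm builds (torch.version.hip is not None) the decorator
        # raises pytest's skip exception before the test body runs;
        # on CUDA builds the body executes normally.
        x = torch.randn(4, 4, device="cuda")
        self.assertEqual(x.shape, (4, 4))

    @skip_if_rocm  # bare form without parentheses is also handled
    def test_cpu_path(self):
        self.assertEqual(torch.ones(2).sum().item(), 2.0)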