Skip to content

Commit 9158d3c

Browse files
committed
skip if no cuda
1 parent 174cce6 commit 9158d3c

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
quantize_,
2929
)
3030
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
31-
from torchao.testing.utils import skip_if_rocm
31+
from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
3232
from torchao.utils import (
3333
TORCH_VERSION_AT_LEAST_2_5,
3434
TORCH_VERSION_AT_LEAST_2_6,
@@ -308,7 +308,8 @@ def test_alias(self, device, dtype):
308308
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
309309
_ = dummy.weight[...]
310310

311-
@common_utils.parametrize("device", ["cuda"] if torch.cuda.is_available() else [])
311+
@skip_if_no_cuda
312+
@common_utils.parametrize("device", ["cuda"])
312313
@common_utils.parametrize("dtype", [torch.bfloat16])
313314
def test_slice(self, device, dtype):
314315
# in_feature not divisible by 1024

torchao/testing/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,20 @@ def wrapper(*args, **kwargs):
9090
return decorator
9191

9292

93+
def skip_if_no_cuda():
94+
import unittest
95+
96+
def decorator(test_func):
97+
def wrapper(*args, **kwargs):
98+
if not torch.cuda.is_available():
99+
raise unittest.SkipTest("No cuda available")
100+
return test_func(*args, **kwargs)
101+
102+
return wrapper
103+
104+
return decorator
105+
106+
93107
# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
94108
def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902
95109
for name, value in my_cls.__dict__.items():

0 commit comments

Comments
 (0)