
Commit 22bec74

Update hardware check conditions (#1356)
1 parent aeb1944 commit 22bec74

13 files changed (+106 / -84 lines)
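
The pattern is the same across every file shown below: ad-hoc, per-file flags such as is_cuda_8_9 and is_H100 are replaced with shared helpers imported from torchao.utils. The helpers themselves are defined outside this excerpt; based only on the inline expressions they replace, they presumably reduce to something like the following sketch (names taken from the diff, bodies assumed):

# Sketch of the torchao.utils helpers, inferred from the checks they replace.
# The actual implementation is not shown in this diff.
import torch


def is_sm_at_least_89() -> bool:
    # Compute capability 8.9 or newer (Ada-class GPUs); the tests use this as
    # the gate for native float8 support.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_at_least_90() -> bool:
    # Compute capability 9.0 or newer (Hopper / H100); the tests use this for
    # rowwise (PerRow) float8 granularity and axiswise gemms.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)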

test/dtypes/test_affine_quantized.py

Lines changed: 6 additions & 4 deletions
@@ -17,9 +17,11 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
-
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
+)


 def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
@@ -42,7 +44,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )

-    if is_cuda_8_9:
+    if is_sm_at_least_89():
         base_functions.append(float8_weight_only())

     return base_functions
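
Outside the test suite the new helper reads the same way. A hypothetical user-side sketch (quantize_, float8_weight_only, and int8_weight_only are assumed to be importable from torchao.quantization, as the test above suggests; this snippet is not part of the commit):

import torch
from torchao.quantization import float8_weight_only, int8_weight_only, quantize_
from torchao.utils import is_sm_at_least_89

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda()

# Prefer float8 weight-only quantization on SM 8.9+ hardware, otherwise fall
# back to int8 weight-only, mirroring the capability gate used in the test.
config = float8_weight_only() if is_sm_at_least_89() else int8_weight_only()
quantize_(model, config)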

test/dtypes/test_affine_quantized_float.py

Lines changed: 26 additions & 11 deletions
@@ -37,13 +37,14 @@
     MappingType,
     choose_qparams_affine,
 )
+from torchao.utils import (
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 random.seed(0)
 torch.manual_seed(0)

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
@@ -59,12 +60,14 @@ def forward(self, x):

 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
-        "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
     )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
@@ -134,20 +137,26 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
@@ -158,7 +167,9 @@ class UnsupportedGranularity:
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_per_row_with_float32(self):
         with pytest.raises(
             AssertionError,
@@ -170,7 +181,9 @@ def test_per_row_with_float32(self):
         )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
@@ -240,7 +253,9 @@ def test_serialization(self, mode: str):
         ), f"Scales do not match for {layer_name}"

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
         model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights
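
Which branch of these skips and parametrizations a machine takes depends only on the compute capability the CUDA driver reports. A quick local check using standard PyTorch APIs:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # (8, 9) on Ada-generation GPUs enables the float8 tests; (9, 0) on H100
    # additionally enables the PerRow granularity cases parametrized above.
    print(f"compute capability: {major}.{minor}")
else:
    print("No CUDA device; these float8 tests are skipped.")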

test/float8/test_base.py

Lines changed: 20 additions & 14 deletions
@@ -14,7 +14,11 @@
 import torch
 import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -60,10 +64,6 @@
 torch.manual_seed(0)


-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
-
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -219,7 +219,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -333,7 +333,9 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input",
@@ -415,7 +417,9 @@ def test_linear_from_recipe(
             config,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -462,7 +466,9 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -523,15 +529,15 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s

-    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         with torch.inference_mode(mode=True):
             m(x)

-    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
     def test_quantize(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -549,7 +555,7 @@ def test_quantize(self):

 class TestScaledMM:
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -594,7 +600,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
             atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -630,7 +636,7 @@ def test_different_configs_error(self):
             a @ b

     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(

test/float8/test_compile.py

Lines changed: 17 additions & 15 deletions
@@ -11,7 +11,11 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -46,10 +50,6 @@
 from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

-# TODO(future PR): standardize IS_H100 with the rest of the codebase
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-

 def _test_compile_base(
     backend: str,
@@ -99,7 +99,7 @@ def _test_compile_base(
     "scaling_type_grad_output",
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -126,7 +126,7 @@ def test_eager_only(


 @pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
@@ -177,7 +177,7 @@ def test_aot_eager(
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -215,7 +215,9 @@ def test_inductor_from_config_params(
         Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
     ],
 )
-@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(
+    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+)
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
     config = recipe_name_to_linear_config(recipe_name)
@@ -253,7 +255,7 @@ def forward(self, x):

     # TODO(future): figure out why the test below fails on CUDA capability 8.9
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
+        not torch.cuda.is_available() or not is_sm_at_least_90(),
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
@@ -269,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         torch.testing.assert_close(y_eager, y_compiled)

     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
@@ -293,7 +295,7 @@ def to_float(x):
         torch.testing.assert_close(y2_eager, y2_compiled)

     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_output(self):
@@ -323,7 +325,7 @@ def test_float8_graph_output(self):


 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func():
@@ -364,7 +366,7 @@ def __exit__(self, *args):


 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func_cuda_graph_success():
@@ -396,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success():


 @unittest.skipIf(
-    not is_cuda_8_9,
+    not is_sm_at_least_89(),
     "CUDA not available",
 )
 @pytest.mark.parametrize(

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 2 additions & 3 deletions
@@ -6,7 +6,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -40,8 +40,7 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp

-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -30,8 +30,7 @@
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only

-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
