
Commit 1cfe69e

Enable float8 CI on sm89 (#587)
1 parent f6595ac commit 1cfe69e

7 files changed: +76 −31 lines changed


.github/workflows/float8_test.yml

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+name: Run Float8 Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+concurrency:
+  group: float8_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+  cancel-in-progress: true
+
+env:
+  HF_TOKEN: ${{ secrets.HF_TOKEN }}
+
+jobs:
+  test:
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - name: SM-89
+            runs-on: amz2023.linux.g6.4xlarge.experimental.nvidia.gpu
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.1"
+
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      timeout: 60
+      runner: ${{ matrix.runs-on }}
+      gpu-arch-type: ${{ matrix.gpu-arch-type }}
+      gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      script: |
+        conda create -n venv python=3.9 -y
+        conda activate venv
+        echo "::group::Install newer objcopy that supports --set-section-alignment"
+        yum install -y devtoolset-10-binutils
+        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+        python -m pip install --upgrade pip
+        pip install ${{ matrix.torch-spec }}
+        pip install -r dev-requirements.txt
+        pip install .
+        pytest test/float8 --verbose -s
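
Note: the new job targets g6 runners, whose NVIDIA L4 GPUs are compute capability 8.9 (SM89), the minimum for PyTorch's native float8 matmul path. A minimal sketch of the capability probe the test files below rely on; the helper name `has_sm89_float8` is illustrative and not part of this commit:

import torch

def has_sm89_float8() -> bool:
    # Hypothetical helper mirroring the `is_cuda_8_9` check added to the tests:
    # native float8 matmul needs compute capability (8, 9) or newer
    # (Ada, e.g. L4 / RTX 4090, and Hopper, e.g. H100).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

if __name__ == "__main__":
    if has_sm89_float8():
        print("GPU supports native float8 matmul; the full test matrix will run.")
    else:
        print("No SM89+ GPU detected; float8 tests run emulated or are skipped.")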

test/float8/test_base.py

Lines changed: 7 additions & 7 deletions
@@ -54,8 +54,8 @@
 random.seed(0)
 torch.manual_seed(0)
 
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._data == b._data).item(), "scales are not identical"
@@ -223,7 +223,7 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
@@ -271,7 +271,7 @@ def test_linear(
             config,
         )
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -325,7 +325,7 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         emulate = (
@@ -393,7 +393,7 @@ def test_repr(self):
 
 class TestScaledMM:
     @unittest.skipIf(
-        not is_H100,
+        not is_cuda_8_9,
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -437,7 +437,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         atol, rtol = 2e-3, 2e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
-    @unittest.skipIf(not is_H100, "CUDA not available")
+    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -473,7 +473,7 @@ def test_different_configs_error(self):
             a @ b
 
     @unittest.skipIf(
-        not is_H100,
+        not is_cuda_8_9,
         "CUDA not available",
     )
     @pytest.mark.parametrize(
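
The pattern above, repeated throughout these files, keeps the suite runnable on older GPUs: below SM89 only the emulated path is parametrized, and tests that need native float8 matmul are skipped. A self-contained sketch of that pattern, with placeholder test bodies that are not code from this commit:

import pytest
import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

# On SM89+ hardware both the native and emulated paths run; otherwise only emulation.
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
def test_float8_roundtrip_placeholder(emulate: bool):
    # Placeholder body: a real test would build a float8 module with
    # emulate=emulate and compare it against a reference matmul.
    assert emulate in (True, False)

@pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
def test_needs_native_float8_placeholder():
    # Only collected and run where the GPU supports native float8 kernels.
    assert torch.cuda.get_device_capability() >= (8, 9)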

test/float8/test_compile.py

Lines changed: 8 additions & 8 deletions
@@ -33,7 +33,7 @@
 from torch._dynamo.testing import CompileCounterWithBackend
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _test_compile_base(
     backend: str,
@@ -77,7 +77,7 @@ def _test_compile_base(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -104,7 +104,7 @@ def test_eager_only(
 
 
 @pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
@@ -150,7 +150,7 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
     fullgraph,
@@ -210,7 +210,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
         torch.testing.assert_close(y_eager, y_compiled)
 
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
     def test_float8_graph_input(self):
         """Test that having Float8Tensor object as a graph input"""
 
@@ -231,7 +231,7 @@ def to_float(x):
         )
         torch.testing.assert_close(y2_eager, y2_compiled)
 
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
     def test_float8_graph_output(self):
         """Test that having Float8Tensor object as a graph output works"""
         cnts = CompileCounterWithBackend("inductor")
@@ -258,7 +258,7 @@ def test_float8_graph_output(self):
         )
 
 
-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 def test_sync_amax_func():
     torch._dynamo.reset()
     cnts = CompileCounterWithBackend("inductor")
@@ -296,7 +296,7 @@ def __exit__(self, *args):
         sys.stderr = self.sys_stderr
 
 
-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 def test_sync_amax_func_cuda_graph_success():
     torch._dynamo.reset()
     with capture_stderr() as stderr:

test/float8/test_fsdp.py

Lines changed: 2 additions & 2 deletions
@@ -197,9 +197,9 @@ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
     if not torch.cuda.is_available():
         warnings.warn("CUDA not available, running in emulation_mode")
         emulate = True
-    elif torch.cuda.get_device_capability() < (9, 0):
+    elif torch.cuda.get_device_capability() < (8, 9):
         warnings.warn(
-            f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
+            f"CUDA capability {torch.cuda.get_device_capability()} < (8.9), running in emulation mode"
         )
         emulate = True
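
Unlike the pytest gates above, test_fsdp.py degrades at runtime: below SM89 it warns and switches to emulation instead of skipping. A small sketch of that fallback, assuming a hypothetical `resolve_emulate()` helper that is not part of this commit:

import warnings
import torch

def resolve_emulate(requested_emulate: bool = False) -> bool:
    # Hypothetical helper mirroring the gating in run(): force emulation when
    # there is no CUDA device or the device is older than SM89.
    if not torch.cuda.is_available():
        warnings.warn("CUDA not available, running in emulation_mode")
        return True
    if torch.cuda.get_device_capability() < (8, 9):
        warnings.warn(
            f"CUDA capability {torch.cuda.get_device_capability()} < (8.9), "
            "running in emulation mode"
        )
        return True
    return requested_emulate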

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@
     TransformerBlock,
 )
 
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-if not is_H100:
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+if not is_cuda_8_9:
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
 
 class TestFloat8Common:
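
test_fsdp2 takes a third approach: rather than parametrizing emulation or skipping individual tests, it skips the entire module on unsupported devices via `pytest.skip(..., allow_module_level=True)`. A minimal illustration of that mechanism; the file and test names are placeholders, not code from this commit:

# placeholder_fsdp2_like_test.py — hypothetical module name
import pytest
import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
if not is_cuda_8_9:
    # Stops collection of every test in this file on pre-SM89 devices.
    pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

def test_placeholder_requires_sm89():
    # Only reached on SM89+ hardware, so this assertion always holds here.
    assert torch.cuda.get_device_capability() >= (8, 9)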

test/float8/test_inference_flows.py

Lines changed: 9 additions & 9 deletions
@@ -38,7 +38,7 @@
 torch.manual_seed(0)
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 class FeedForward(nn.Module):
     def __init__(self) -> None:
@@ -74,8 +74,8 @@ def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
     @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA not available or machine does not support SM89",
     )
     def test_dynamic_fp8_mlp(self, compile_backend, dtype):
         original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -109,8 +109,8 @@ def test_dynamic_fp8_mlp(self, compile_backend, dtype):
     @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA not available or machine does not support SM89",
     )
     def test_static_fp8_mlp(self, compile_backend, dtype):
         original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -150,8 +150,8 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
     @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA not available or machine does not support SM89",
     )
     def test_weight_only_fp8_mlp(self, compile_backend, dtype):
         original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -205,8 +205,8 @@ def train(self, model: nn.Module, dtype: torch.dtype):
 
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA not available or machine does not support SM89",
     )
     def test_fp8_save_and_load(self, dtype: torch.dtype):
         # Initialize FP8 model

test/float8/test_numerics_integration.py

Lines changed: 2 additions & 3 deletions
@@ -27,8 +27,7 @@
 )
 from torchao.float8.float8_utils import compute_error, IS_ROCM
 
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 torch.manual_seed(0)
 
@@ -89,7 +88,7 @@ class TestFloat8NumericsIntegrationTest:
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC],
     )
-    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw(
         self,
