|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +import unittest |
| 10 | +from typing import Tuple |
| 11 | + |
| 12 | +import torch |
| 13 | + |
| 14 | +if torch.cuda.is_available(): |
| 15 | + from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row |
| 16 | + from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( |
| 17 | + grouped_gemm, |
| 18 | + grouped_gemm_fp8_rowwise, |
| 19 | + ) |
| 20 | + from fbgemm_gpu.experimental.gemm.triton_gemm.utils import HAS_TMA_DESC |
| 21 | + |
| 22 | + |
| 23 | +@unittest.skipIf( |
| 24 | + not torch.cuda.is_available() |
| 25 | + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9 |
| 26 | + or not HAS_TMA_DESC, |
| 27 | + "Skip when H100 or TMA is not available", |
| 28 | +) |
| 29 | +class TestGroupedGEMM(unittest.TestCase): |
| 30 | + def setUp(self) -> None: |
| 31 | + torch.manual_seed(0) |
| 32 | + |
| 33 | + def test_grouped_gemm_fp8_rowwise(self) -> None: |
| 34 | + def _test_grouped_gemm_fp8_rowwise( |
| 35 | + shape: Tuple[int, int, int, int], |
| 36 | + device: torch.device, |
| 37 | + ) -> None: |
| 38 | + G, M, N, K = shape |
| 39 | + a = torch.randn(M, K, dtype=torch.bfloat16, device=device) |
| 40 | + b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) |
| 41 | + m_offsets, _ = torch.sort( |
| 42 | + torch.randint(low=0, high=M, size=[G], device=device, dtype=torch.int32) |
| 43 | + ) |
| 44 | + m_offsets[G - 1] = M |
| 45 | + |
| 46 | + a_fp8, a_scale = quantize_fp8_row(a) |
| 47 | + b_fp8, b_scale = quantize_fp8_row(b) |
| 48 | + |
| 49 | + result = grouped_gemm_fp8_rowwise( |
| 50 | + a_fp8, |
| 51 | + b_fp8, |
| 52 | + m_offsets, |
| 53 | + a_scale, |
| 54 | + b_scale, |
| 55 | + ) |
| 56 | + self.assertTrue(result.shape == (M, N)) |
| 57 | + |
| 58 | + expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) |
| 59 | + # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation. |
| 60 | + for g in range(G): |
| 61 | + m_start = 0 if g == 0 else m_offsets[g - 1] |
| 62 | + m_end = m_offsets[g] |
| 63 | + n_start = g * N |
| 64 | + n_end = (g + 1) * N |
| 65 | + |
| 66 | + expected_result[m_start:m_end, :] = ( |
| 67 | + a_fp8[m_start:m_end, :].to(torch.float32) |
| 68 | + @ b_fp8[n_start:n_end, :].to(torch.float32).T |
| 69 | + * a_scale[m_start:m_end][:, None] |
| 70 | + * b_scale[n_start:n_end][None, :] |
| 71 | + ).to(torch.bfloat16) |
| 72 | + |
| 73 | + torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2) |
| 74 | + |
| 75 | + _test_grouped_gemm_fp8_rowwise((16, 512, 256, 256), torch.device("cuda")) |
| 76 | + _test_grouped_gemm_fp8_rowwise((8, 512, 256, 256), torch.device("cuda")) |
| 77 | + _test_grouped_gemm_fp8_rowwise((4, 512, 256, 256), torch.device("cuda")) |
| 78 | + _test_grouped_gemm_fp8_rowwise((2, 512, 256, 256), torch.device("cuda")) |
| 79 | + # TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug. |
| 80 | + # _test_grouped_gemm_fp8_rowwise((1, 512, 256, 256), torch.device("cuda")) |
| 81 | + |
| 82 | + def test_grouped_gemm_bf16(self) -> None: |
| 83 | + def _test_grouped_gemm_bf16( |
| 84 | + shape: Tuple[int, int, int, int], |
| 85 | + device: torch.device, |
| 86 | + ) -> None: |
| 87 | + G, M, N, K = shape |
| 88 | + a = torch.randn(M, K, dtype=torch.bfloat16, device=device) |
| 89 | + b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) |
| 90 | + m_offsets, _ = torch.sort( |
| 91 | + torch.randint(low=0, high=M, size=[G], device=device, dtype=torch.int32) |
| 92 | + ) |
| 93 | + m_offsets[G - 1] = M |
| 94 | + |
| 95 | + result = grouped_gemm( |
| 96 | + a, |
| 97 | + b, |
| 98 | + m_offsets, |
| 99 | + ) |
| 100 | + self.assertTrue(result.shape == (M, N)) |
| 101 | + |
| 102 | + expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) |
| 103 | + for g in range(G): |
| 104 | + m_start = 0 if g == 0 else m_offsets[g - 1] |
| 105 | + m_end = m_offsets[g] |
| 106 | + expected_result[m_start:m_end, :] = ( |
| 107 | + a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T |
| 108 | + ) |
| 109 | + |
| 110 | + torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) |
| 111 | + |
| 112 | + _test_grouped_gemm_bf16((16, 512, 256, 256), torch.device("cuda")) |
| 113 | + _test_grouped_gemm_bf16((8, 512, 256, 256), torch.device("cuda")) |
| 114 | + _test_grouped_gemm_bf16((4, 512, 256, 256), torch.device("cuda")) |
| 115 | + _test_grouped_gemm_bf16((2, 512, 256, 256), torch.device("cuda")) |
| 116 | + # TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug. |
| 117 | + # _test_grouped_gemm_bf16((1, 512, 256, 256), torch.device("cuda")) |
0 commit comments