
Commit 0c2b7a1

levendlee authored and facebook-github-bot committed
Adds Triton based GroupedGEMM implementation. (#3674)
Summary: Pull Request resolved: #3674 X-link: facebookresearch/FBGEMM#751 Adds a Triton-based grouped GEMM implementation with on-device shape information. - This implementation requires N and K to be consistent across all groups, while M can differ per group. - Shape annotation: A: [M, K]; B: [N * G, K]; C: [M, N], where G denotes the number of groups. This moves experimental code from D69334734 to production. In addition: - It is added as a standalone module to avoid breaking other production modules. - It does not have a benchmark attached, to avoid leaking confidential information to OSS. Reviewed By: jwfromm Differential Revision: D69364390 fbshipit-source-id: 76c4620f71cbf6cae324af774db29f57ea3d216c
1 parent fc718cf commit 0c2b7a1
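For orientation, here is a minimal pure-PyTorch sketch of the semantics described in the summary: each group g owns a contiguous slice of A's rows, delimited by the cumulative m_offsets, and multiplies it against its own [N, K] block of B. The function name grouped_gemm_reference is illustrative only and not part of the kernel API; it mirrors the reference loop used in the test file in this diff.

import torch

def grouped_gemm_reference(
    a: torch.Tensor,          # [M, K]
    b: torch.Tensor,          # [G * N, K]
    m_offsets: torch.Tensor,  # [G] cumulative row boundaries; m_offsets[G - 1] == M
) -> torch.Tensor:
    """Plain-PyTorch reference for the grouped GEMM described above (illustrative)."""
    G = m_offsets.numel()
    M = a.shape[0]
    N = b.shape[0] // G
    c = torch.zeros(M, N, dtype=a.dtype, device=a.device)
    for g in range(G):
        # Rows [m_start, m_end) of A belong to group g.
        m_start = 0 if g == 0 else int(m_offsets[g - 1])
        m_end = int(m_offsets[g])
        # Group g multiplies its row slice of A with its own [N, K] block of B.
        c[m_start:m_end, :] = a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
    return c

The Triton kernel grouped_gemm(a, b, m_offsets) added in this commit is expected to match this loop, which is exactly what the bf16 test below asserts.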

File tree

3 files changed: +491 −0 lines changed

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Tuple

import torch

if torch.cuda.is_available():
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
    from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
        grouped_gemm,
        grouped_gemm_fp8_rowwise,
    )
    from fbgemm_gpu.experimental.gemm.triton_gemm.utils import HAS_TMA_DESC


@unittest.skipIf(
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9
    or not HAS_TMA_DESC,
    "Skip when H100 or TMA is not available",
)
class TestGroupedGEMM(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(0)

    def test_grouped_gemm_fp8_rowwise(self) -> None:
        def _test_grouped_gemm_fp8_rowwise(
            shape: Tuple[int, int, int, int],
            device: torch.device,
        ) -> None:
            G, M, N, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
            m_offsets, _ = torch.sort(
                torch.randint(low=0, high=M, size=[G], device=device, dtype=torch.int32)
            )
            m_offsets[G - 1] = M

            a_fp8, a_scale = quantize_fp8_row(a)
            b_fp8, b_scale = quantize_fp8_row(b)

            result = grouped_gemm_fp8_rowwise(
                a_fp8,
                b_fp8,
                m_offsets,
                a_scale,
                b_scale,
            )
            self.assertTrue(result.shape == (M, N))

            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
            # Run the baseline on the quantized inputs to exclude quantization error
            # from the test, since it has nothing to do with the correctness of the
            # kernel implementation.
            for g in range(G):
                m_start = 0 if g == 0 else m_offsets[g - 1]
                m_end = m_offsets[g]
                n_start = g * N
                n_end = (g + 1) * N

                expected_result[m_start:m_end, :] = (
                    a_fp8[m_start:m_end, :].to(torch.float32)
                    @ b_fp8[n_start:n_end, :].to(torch.float32).T
                    * a_scale[m_start:m_end][:, None]
                    * b_scale[n_start:n_end][None, :]
                ).to(torch.bfloat16)

            torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)

        _test_grouped_gemm_fp8_rowwise((16, 512, 256, 256), torch.device("cuda"))
        _test_grouped_gemm_fp8_rowwise((8, 512, 256, 256), torch.device("cuda"))
        _test_grouped_gemm_fp8_rowwise((4, 512, 256, 256), torch.device("cuda"))
        _test_grouped_gemm_fp8_rowwise((2, 512, 256, 256), torch.device("cuda"))
        # TODO(shikaili): G=1 could produce NaN results with on-device TMA store. Need to debug.
        # _test_grouped_gemm_fp8_rowwise((1, 512, 256, 256), torch.device("cuda"))

    def test_grouped_gemm_bf16(self) -> None:
        def _test_grouped_gemm_bf16(
            shape: Tuple[int, int, int, int],
            device: torch.device,
        ) -> None:
            G, M, N, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
            m_offsets, _ = torch.sort(
                torch.randint(low=0, high=M, size=[G], device=device, dtype=torch.int32)
            )
            m_offsets[G - 1] = M

            result = grouped_gemm(
                a,
                b,
                m_offsets,
            )
            self.assertTrue(result.shape == (M, N))

            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
            for g in range(G):
                m_start = 0 if g == 0 else m_offsets[g - 1]
                m_end = m_offsets[g]
                expected_result[m_start:m_end, :] = (
                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
                )

            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        _test_grouped_gemm_bf16((16, 512, 256, 256), torch.device("cuda"))
        _test_grouped_gemm_bf16((8, 512, 256, 256), torch.device("cuda"))
        _test_grouped_gemm_bf16((4, 512, 256, 256), torch.device("cuda"))
        _test_grouped_gemm_bf16((2, 512, 256, 256), torch.device("cuda"))
        # TODO(shikaili): G=1 could produce NaN results with on-device TMA store. Need to debug.
        # _test_grouped_gemm_bf16((1, 512, 256, 256), torch.device("cuda"))
