
Commit 80360b8

jwfromm authored and facebook-github-bot committed
Enable DeepGEMM in quantize bench (#3745)
Summary:
Pull Request resolved: #3745
X-link: facebookresearch/FBGEMM#826

Add deepgemm grouped_gemm to quantize bench.

Reviewed By: jiawenliu64
Differential Revision: D70223879
fbshipit-source-id: fbfa136e10a9dcfa460dd5d95f47b5b5ed683824
1 parent deec0cd commit 80360b8

File tree

2 files changed: +117 −2 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 6 additions & 2 deletions
@@ -169,9 +169,13 @@ def benchmark_grouped(
         # Compute the output given quantized values.
         output = quantize_op.compute(*quantized_vals)
         # Some kernels may pad output, just take the first m values of each row.
-        output = [o[: m[i]] for i, o in enumerate(output)]
+        if isinstance(output, torch.Tensor) and output.ndim == 2:
+            # Output is stacked and needs to be split.
+            output = torch.split(output, m, dim=0)
+        else:
+            # Otherwise output may be padded or require unbinding.
+            output = [o[: m[i]] for i, o in enumerate(output)]
         # Compare the quantize op output to reference as a sanity check.
-
         for i in range(num_groups):
             metrics.sim += float(
                 torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
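The new branch lets benchmark_grouped handle kernels that return one stacked [sum(m), n] tensor (such as the stacked grouped ops added in quantize_ops.py below) alongside kernels that return a per-group, possibly padded, list. A minimal standalone sketch of the two layouts; the group sizes and column count here are made up for illustration:

import torch

m = [3, 5, 2]   # rows produced by each group (example values)
n = 4           # output columns (example value)

# Stacked case: the kernel returns a single [sum(m), n] tensor that must be split.
stacked = torch.randn(sum(m), n)
per_group = torch.split(stacked, m, dim=0)        # tuple of [m[i], n] tensors

# Padded/list case: the kernel returns per-group tensors, possibly padded in rows.
padded = [torch.randn(8, n) for _ in m]           # e.g. padded to 8 rows each
trimmed = [o[: m[i]] for i, o in enumerate(padded)]

assert all(o.shape == (m[i], n) for i, o in enumerate(per_group))
assert all(o.shape == (m[i], n) for i, o in enumerate(trimmed))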

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 111 additions & 0 deletions
@@ -21,6 +21,9 @@
     scale_fp8_row,
     triton_quantize_fp8_row,
 )
+from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    grouped_gemm_fp8_rowwise,
+)
 from tinygemm.utils import group_quantize_tensor
 
 if torch.cuda.is_available() and torch.version.cuda:
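grouped_gemm_fp8_rowwise backs the triton_grouped_stacked op added later in this file. A rough usage sketch of that op, not part of the commit: the module path is inferred from this file's location, the shapes are arbitrary, and it assumes a CUDA or ROCm device with fbgemm_gpu's triton kernels available.

import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import FP8TritonStackedGroupedGemm

# Arbitrary example sizes for three groups.
group_ms, K, N = (16, 32, 8), 256, 512
x = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in group_ms]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in group_ms]

op = FP8TritonStackedGroupedGemm()
# preprocess stacks the inputs, row-quantizes the weights, and records group sizes.
x_cat, wq, w_scale, m_sizes = op.preprocess(x, w)
print(x_cat.shape, wq.shape, m_sizes)  # [56, 256], [1536, 256], tensor([16, 32, 8], ...)
out = op.quantize_and_compute(x_cat, wq, w_scale, m_sizes)
print(out.shape)  # torch.Size([56, 512]) == [sum(group_ms), N]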
@@ -35,6 +38,14 @@
 except ImportError:
     MARLIN_ENABLED = False
 
+try:
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
+
+    DEEPGEMM_ENABLED = True
+except ImportError:
+    DEEPGEMM_ENABLED = False
+
+
 # Machete is also only supported internally at Meta for now.
 try:
     from machete.machete import machete_gemm
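The guarded import keeps quantize_ops importable when deep_gemm is not installed; DEEPGEMM_ENABLED then gates the cuda property of the deepgemm_stacked op added in the next hunk. That kernel consumes a contiguous ("stacked") layout in which every activation row is tagged with its group index, as DeepGemmStacked.preprocess below constructs. A small illustration of that layout, with arbitrary group sizes:

import torch

m_values = [2, 3, 1]  # example group sizes
m_indices = torch.arange(len(m_values)).repeat_interleave(torch.tensor(m_values))
print(m_indices)  # tensor([0, 0, 1, 1, 1, 2]); row i of the stacked input belongs to group m_indices[i]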
@@ -712,6 +723,106 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class FP8TritonStackedGroupedGemm(QuantizeOpBase):
+    """
+    FP8 grouped matmul with rowwise scaling and stacked inputs implemented with triton.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+        # Quantize weights.
+        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
+        # Group weights as single tensor.
+        wq = torch.concat(wq, dim=0).contiguous()
+        w_scale = torch.concat(w_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, w_scale, m_sizes
+
+    def quantize(self, x, wq, w_scale, m_sizes):
+        B = x.shape[0]
+        xq, x_scale = triton_quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        return xq, wq, x_scale, w_scale, m_sizes
+
+    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
+        return grouped_gemm_fp8_rowwise(xq, wq, m_sizes, x_scale, w_scale)
+
+    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
+        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
+        return self.compute(xq, wq, x_scale, w_scale, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "triton_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class DeepGemmStacked(QuantizeOpBase):
+    """
+    FP8 grouped matmul with blockwise scaling implemented with DeepGemm.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        indices = torch.arange(len(m_values))
+        m_indices = indices.repeat_interleave(torch.tensor(m_values)).to(
+            device=x[0].device, dtype=torch.int
+        )
+        # Quantize weights.
+        wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, w_scale, m_indices
+
+    def quantize(self, x, wq, w_scale, m_indices):
+        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
+        return xq, wq, x_scale, w_scale, m_indices
+
+    def compute(self, xq, wq, x_scale, w_scale, m_indices):
+        # Preallocate output.
+        out = torch.empty(
+            [xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (xq, x_scale), (wq, w_scale), out, m_indices
+        )
+        return out
+
+    def quantize_and_compute(self, x, wq, w_scale, m_indices):
+        xq, wq, x_scale, w_scale, m_indices = self.quantize(x, wq, w_scale, m_indices)
+        return self.compute(xq, wq, x_scale, w_scale, m_indices)
+
+    @property
+    def name(self) -> str:
+        return "deepgemm_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return DEEPGEMM_ENABLED
+
+
 @register_quantize_op
 class BF16GroupedGemm(QuantizeOpBase):
     """
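Taken together, the new DeepGemmStacked op can be exercised outside the benchmark harness roughly as follows. This is a sketch, not part of the commit: the module path is inferred from the file location above, the shapes are arbitrary multiples of 128 chosen to satisfy DeepGEMM's contiguous-layout alignment, and it assumes a CUDA device with fbgemm_gpu and deep_gemm installed.

import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import DeepGemmStacked

# Arbitrary example sizes; per-group M, plus K and N, kept at multiples of 128.
group_ms, K, N = (128, 256, 128), 512, 1024
x = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in group_ms]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in group_ms]

op = DeepGemmStacked()
# preprocess stacks x, block-quantizes w, and builds per-row group indices;
# quantize_and_compute block-quantizes x and calls the DeepGEMM kernel.
x_cat, wq, w_scale, m_indices = op.preprocess(x, w)
out = op.quantize_and_compute(x_cat, wq, w_scale, m_indices)
print(out.shape)  # torch.Size([512, 1024]) == [sum(group_ms), N]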