21 | 21 |     scale_fp8_row,
22 | 22 |     triton_quantize_fp8_row,
23 | 23 | )
| 24 | +from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( |
| 25 | +    grouped_gemm_fp8_rowwise, |
| 26 | +) |
24 | 27 | from tinygemm.utils import group_quantize_tensor
25 | 28 |
26 | 29 | if torch.cuda.is_available() and torch.version.cuda:
35 | 38 | except ImportError:
36 | 39 |     MARLIN_ENABLED = False
37 | 40 |
| 41 | +try: |
| 42 | +    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous |
| 43 | + |
| 44 | +    DEEPGEMM_ENABLED = True |
| 45 | +except ImportError: |
| 46 | +    DEEPGEMM_ENABLED = False |
| 47 | + |
| 48 | + |
38 | 49 | # Machete is also only supported internally at Meta for now.
39 | 50 | try:
40 | 51 |     from machete.machete import machete_gemm
@@ -712,6 +723,106 @@ def cuda(self) -> bool:
712 | 723 |         return True
713 | 724 |
714 | 725 |
| 726 | +@register_quantize_op |
| 727 | +class FP8TritonStackedGroupedGemm(QuantizeOpBase): |
| 728 | +    """ |
| 729 | +    FP8 grouped matmul with rowwise scaling and stacked inputs implemented with triton. |
| 730 | +    """ |
| 731 | + |
| 732 | +    def preprocess(self, x, w): |
| 733 | +        m_values = [i.shape[0] for i in x] |
| 734 | +        # Convert m_values into a tensor of per-group row counts. |
| 735 | +        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device) |
| 736 | +        # Quantize weights. |
| 737 | +        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w]) |
| 738 | +        # Group weights as single tensor. |
| 739 | +        wq = torch.concat(wq, dim=0).contiguous() |
| 740 | +        w_scale = torch.concat(w_scale, dim=0).contiguous() |
| 741 | +        # Also view input as flattened. |
| 742 | +        x = torch.concat(x, dim=0).contiguous() |
| 743 | +        # Return processed tensors. |
| 744 | +        return x, wq, w_scale, m_sizes |
| 745 | + |
| 746 | +    def quantize(self, x, wq, w_scale, m_sizes): |
| 747 | +        B = x.shape[0] |
| 748 | +        xq, x_scale = triton_quantize_fp8_row(x) |
| 749 | +        x_scale = x_scale.view(B, -1) |
| 750 | +        return xq, wq, x_scale, w_scale, m_sizes |
| 751 | + |
| 752 | +    def compute(self, xq, wq, x_scale, w_scale, m_sizes): |
| 753 | +        return grouped_gemm_fp8_rowwise(xq, wq, m_sizes, x_scale, w_scale) |
| 754 | + |
| 755 | +    def quantize_and_compute(self, x, wq, w_scale, m_sizes): |
| 756 | +        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes) |
| 757 | +        return self.compute(xq, wq, x_scale, w_scale, m_sizes) |
| 758 | + |
| 759 | +    @property |
| 760 | +    def name(self) -> str: |
| 761 | +        return "triton_grouped_stacked" |
| 762 | + |
| 763 | +    @property |
| 764 | +    def hip(self) -> bool: |
| 765 | +        return True |
| 766 | + |
| 767 | +    @property |
| 768 | +    def cuda(self) -> bool: |
| 769 | +        return True |
| 770 | + |
| 771 | + |
| 772 | +@register_quantize_op |
| 773 | +class DeepGemmStacked(QuantizeOpBase): |
| 774 | +    """ |
| 775 | +    FP8 grouped matmul with blockwise scaling implemented with DeepGemm. |
| 776 | +    """ |
| 777 | + |
| 778 | +    def preprocess(self, x, w): |
| 779 | +        m_values = [i.shape[0] for i in x] |
| 780 | +        # Map each row of the stacked input to its group index. |
| 781 | +        indices = torch.arange(len(m_values)) |
| 782 | +        m_indices = indices.repeat_interleave(torch.tensor(m_values)).to( |
| 783 | +            device=x[0].device, dtype=torch.int |
| 784 | +        ) |
| 785 | +        # Quantize weights. |
| 786 | +        wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w]) |
| 787 | +        # Group weights as single tensor. |
| 788 | +        wq = torch.stack(wq, dim=0).contiguous() |
| 789 | +        w_scale = torch.stack(w_scale, dim=0).contiguous() |
| 790 | +        # Also view input as flattened. |
| 791 | +        x = torch.concat(x, dim=0).contiguous() |
| 792 | +        # Return processed tensors. |
| 793 | +        return x, wq, w_scale, m_indices |
| 794 | + |
| 795 | +    def quantize(self, x, wq, w_scale, m_indices): |
| 796 | +        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128) |
| 797 | +        return xq, wq, x_scale, w_scale, m_indices |
| 798 | + |
| 799 | +    def compute(self, xq, wq, x_scale, w_scale, m_indices): |
| 800 | +        # Preallocate output. |
| 801 | +        out = torch.empty( |
| 802 | +            [xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16 |
| 803 | +        ) |
| 804 | +        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( |
| 805 | +            (xq, x_scale), (wq, w_scale), out, m_indices |
| 806 | +        ) |
| 807 | +        return out |
| 808 | + |
| 809 | +    def quantize_and_compute(self, x, wq, w_scale, m_indices): |
| 810 | +        xq, wq, x_scale, w_scale, m_indices = self.quantize(x, wq, w_scale, m_indices) |
| 811 | +        return self.compute(xq, wq, x_scale, w_scale, m_indices) |
| 812 | + |
| 813 | +    @property |
| 814 | +    def name(self) -> str: |
| 815 | +        return "deepgemm_stacked" |
| 816 | + |
| 817 | +    @property |
| 818 | +    def hip(self) -> bool: |
| 819 | +        return False |
| 820 | + |
| 821 | +    @property |
| 822 | +    def cuda(self) -> bool: |
| 823 | +        return DEEPGEMM_ENABLED |
| 824 | + |
| 825 | + |
715 | 826 | @register_quantize_op
716 | 827 | class BF16GroupedGemm(QuantizeOpBase):
717 | 828 |     """
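Note: the `FP8TritonStackedGroupedGemm` op above splits its logic across `preprocess`, `quantize`, and `compute` to fit the benchmark harness. Written out flat, the flow it measures looks roughly like the sketch below. This is illustrative only: the `fp8_gemm` import path for the quantize helpers, the bf16 input dtype, and the per-group shapes (each `x[i]` of shape `[M_i, K]`, each `w[i]` of shape `[N, K]`) are assumptions rather than something this diff pins down.

```python
import torch

# Assumed import path for the rowwise quantize helpers; the grouped_gemm
# import matches the one added in this diff.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,
    triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm_fp8_rowwise,
)

# Hypothetical group shapes: shared N and K, per-group M_i.
N, K = 512, 256
x = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in (16, 32, 64, 128)]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(4)]

# Per-group row counts; the kernel uses these to partition the stacked input.
m_sizes = torch.tensor([xi.shape[0] for xi in x], dtype=torch.int32, device=x[0].device)

# Rowwise-quantize each weight, then stack all groups along dim 0.
wq, w_scale = zip(*[quantize_fp8_row(wi) for wi in w])
wq = torch.concat(wq, dim=0).contiguous()
w_scale = torch.concat(w_scale, dim=0).contiguous()

# Stack activations along dim 0 and rowwise-quantize the stacked tensor.
x_stacked = torch.concat(x, dim=0).contiguous()
xq, x_scale = triton_quantize_fp8_row(x_stacked)
x_scale = x_scale.view(x_stacked.shape[0], -1)

# Single kernel launch covering all groups.
out = grouped_gemm_fp8_rowwise(xq, wq, m_sizes, x_scale, w_scale)
```

Compared with launching one FP8 GEMM per group, stacking lets a single rowwise-scaled kernel cover every group, which is what the `triton_grouped_stacked` benchmark entry measures.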
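The `DeepGemmStacked` op follows the same pattern but keys the grouped kernel by per-row group indices and blockwise scales (128x128 for weights, 1x128 for activations). A comparable sketch, again with an assumed import path for `quantize_fp8_block`, assumed bf16 inputs, and group sizes chosen as multiples of 128 because DeepGEMM's contiguous grouped kernel generally expects each group's M to be aligned to its block size:

```python
import torch

from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous

# Assumed import path for the blockwise quantize helper used by the benchmark.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

# Hypothetical groups; sizes kept at multiples of 128 for the contiguous layout.
m_values = [128, 256, 384]
N, K = 512, 256
x = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in m_values]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in m_values]

# m_indices assigns every row of the stacked activation to its group,
# e.g. [0] * 128 + [1] * 256 + [2] * 384 for the sizes above.
indices = torch.arange(len(m_values))
m_indices = indices.repeat_interleave(torch.tensor(m_values)).to(
    device=x[0].device, dtype=torch.int
)

# 128x128 blockwise weight quantization, stacked into [num_groups, N, K].
wq, w_scale = zip(*[quantize_fp8_block(wi, block_m=128, block_k=128) for wi in w])
wq = torch.stack(wq, dim=0).contiguous()
w_scale = torch.stack(w_scale, dim=0).contiguous()

# 1x128 blockwise activation quantization on the stacked input.
x_stacked = torch.concat(x, dim=0).contiguous()
xq, x_scale = quantize_fp8_block(x_stacked, block_m=1, block_k=128)

# DeepGEMM writes into a preallocated bf16 output of shape [sum(M_i), N].
out = torch.empty([xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16)
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((xq, x_scale), (wq, w_scale), out, m_indices)
```

Since the op's `cuda` property returns `DEEPGEMM_ENABLED`, the benchmark entry simply drops out when `deep_gemm` is not installed, mirroring the Marlin and Machete guards earlier in the file.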