|
6 | 6 |
|
7 | 7 | # pyre-unsafe
|
8 | 8 | import logging
|
9 |
| -import sys |
10 | 9 | from typing import List, Optional, Tuple, Union
|
11 | 10 |
|
12 | 11 | import torch
|
|
18 | 17 | early_config_prune,
|
19 | 18 | estimate_matmul_time,
|
20 | 19 | )
|
| 20 | +from fbgemm_gpu.experimental.gemm.triton_gemm.utils import ( |
| 21 | + map_dtype_to_triton, |
| 22 | + TmaAutoTuneHelper, |
| 23 | +) |
21 | 24 | from torch._tensor import Tensor
|
22 | 25 |
|
23 | 26 | from triton import Config # @manual
|
@@ -59,28 +62,6 @@ def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper
|
59 | 62 | return tl_reinterpret(tensor, dtype=dtype)
|
60 | 63 |
|
61 | 64 |
|
62 |
| -def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: |
63 |
| - """ |
64 |
| - Maps torch dtype to triton dtype. |
65 |
| -
|
66 |
| - Args: |
67 |
| - dtype (torch.dtype): input dtype. |
68 |
| -
|
69 |
| - Returns: |
70 |
| - tl.dtype: triton dtype. |
71 |
| - """ |
72 |
| - if dtype == torch.float16: |
73 |
| - return tl.float16 |
74 |
| - elif dtype == torch.bfloat16: |
75 |
| - return tl.bfloat16 |
76 |
| - elif dtype == torch.float32: |
77 |
| - return tl.float32 |
78 |
| - elif dtype == torch.int32: |
79 |
| - return tl.int32 |
80 |
| - else: |
81 |
| - raise ValueError(f"Unsupported dtype {dtype}") |
82 |
| - |
83 |
| - |
84 | 65 | def init_to_zero(name):
|
85 | 66 | return lambda nargs: nargs[name].zero_()
|
86 | 67 |
|
@@ -1125,99 +1106,6 @@ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
|
1125 | 1106 | tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
|
1126 | 1107 |
|
1127 | 1108 |
|
1128 |
| -# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). |
1129 |
| -HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) |
1130 |
| - |
1131 |
| -if HAS_TMA_DESC: |
1132 |
| - print( |
1133 |
| - "TMA benchmarks will be running with experimental grid constant TMA descriptor.", |
1134 |
| - file=sys.stderr, |
1135 |
| - ) |
1136 |
| -else: |
1137 |
| - print( |
1138 |
| - "TMA benchmarks will be running without grid constant TMA descriptor.", |
1139 |
| - file=sys.stderr, |
1140 |
| - ) |
1141 |
| - |
1142 |
| - |
1143 |
| -class TmaAutoTuneHelper: |
1144 |
| - |
1145 |
| - # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 |
1146 |
| - class KernelParamWrapper: |
1147 |
| - def __init__(self, desc): |
1148 |
| - self.desc = desc |
1149 |
| - |
1150 |
| - def tma_desc_cpu_ptr(self): |
1151 |
| - return self.desc.data_ptr() |
1152 |
| - |
1153 |
| - TMA_SIZE = 128 |
1154 |
| - |
1155 |
| - def __init__(self): |
1156 |
| - self.fill_1d_tma_descriptor_inner = ( |
1157 |
| - triton.runtime.driver.active.utils.fill_1d_tma_descriptor |
1158 |
| - ) |
1159 |
| - self.fill_2d_tma_descriptor_inner = ( |
1160 |
| - triton.runtime.driver.active.utils.fill_2d_tma_descriptor |
1161 |
| - ) |
1162 |
| - if HAS_TMA_DESC: |
1163 |
| - self.descriptors = {} |
1164 |
| - else: |
1165 |
| - self.cuda_descriptors = {} |
1166 |
| - |
1167 |
| - # Call this method outside of the lambda function for grid size |
1168 |
| - def init_tma_descriptor(self, name): |
1169 |
| - if HAS_TMA_DESC: |
1170 |
| - self.descriptors[name] = torch.empty( |
1171 |
| - TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 |
1172 |
| - ) |
1173 |
| - else: |
1174 |
| - self.cuda_descriptors[name] = torch.empty( |
1175 |
| - TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 |
1176 |
| - ) |
1177 |
| - |
1178 |
| - # Call this method inside the lambda function for grid size |
1179 |
| - def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): |
1180 |
| - if HAS_TMA_DESC: |
1181 |
| - desc_x = self.descriptors[name] |
1182 |
| - assert desc_x.data_ptr() % 64 == 0 |
1183 |
| - self.fill_1d_tma_descriptor_inner( |
1184 |
| - ptr, dim, block_dim, element_size, desc_x.data_ptr() |
1185 |
| - ) |
1186 |
| - else: |
1187 |
| - desc_x = self.cuda_descriptors[name] |
1188 |
| - buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) |
1189 |
| - self.fill_1d_tma_descriptor_inner( |
1190 |
| - ptr, dim, block_dim, element_size, buf_x.data_ptr() |
1191 |
| - ) |
1192 |
| - desc_x.copy_(buf_x, non_blocking=True) |
1193 |
| - |
1194 |
| - # Call this method inside the lambda function for grid size |
1195 |
| - def fill_2d_tma_descriptor( |
1196 |
| - self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size |
1197 |
| - ): |
1198 |
| - if HAS_TMA_DESC: |
1199 |
| - desc_x = self.descriptors[name] |
1200 |
| - assert desc_x.data_ptr() % 64 == 0 |
1201 |
| - self.fill_2d_tma_descriptor_inner( |
1202 |
| - ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() |
1203 |
| - ) |
1204 |
| - else: |
1205 |
| - desc_x = self.cuda_descriptors[name] |
1206 |
| - buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) |
1207 |
| - self.fill_2d_tma_descriptor_inner( |
1208 |
| - ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() |
1209 |
| - ) |
1210 |
| - desc_x.copy_(buf_x, non_blocking=True) |
1211 |
| - |
1212 |
| - def get_tma_descriptor_kernel_param(self, name): |
1213 |
| - if HAS_TMA_DESC: |
1214 |
| - assert self.descriptors[name] is not None |
1215 |
| - return self.KernelParamWrapper(self.descriptors[name]) |
1216 |
| - else: |
1217 |
| - assert self.cuda_descriptors[name] is not None |
1218 |
| - return self.cuda_descriptors[name] |
1219 |
| - |
1220 |
| - |
1221 | 1109 | @torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
|
1222 | 1110 | def matmul_fp8_row(
|
1223 | 1111 | a: torch.Tensor,
|
|
0 commit comments