
Commit b50f026

levendlee authored and facebook-github-bot committed

Moves utility functions into a standalone file. (pytorch#3671)

Summary:
Pull Request resolved: pytorch#3671
X-link: facebookresearch/FBGEMM#749

Moves functions to better modularize code.

Reviewed By: jianyuh

Differential Revision: D69377391
1 parent 3182ea5 commit b50f026

File tree

2 files changed: +131 -116 lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 4 additions & 116 deletions

@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 import logging
-import sys
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -18,6 +17,10 @@
     early_config_prune,
     estimate_matmul_time,
 )
+from fbgemm_gpu.experimental.gemm.triton_gemm.utils import (
+    map_dtype_to_triton,
+    TmaAutoTuneHelper,
+)
 from torch._tensor import Tensor
 
 from triton import Config  # @manual
@@ -59,28 +62,6 @@ def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper
     return tl_reinterpret(tensor, dtype=dtype)
 
 
-def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
-    """
-    Maps torch dtype to triton dtype.
-
-    Args:
-        dtype (torch.dtype): input dtype.
-
-    Returns:
-        tl.dtype: triton dtype.
-    """
-    if dtype == torch.float16:
-        return tl.float16
-    elif dtype == torch.bfloat16:
-        return tl.bfloat16
-    elif dtype == torch.float32:
-        return tl.float32
-    elif dtype == torch.int32:
-        return tl.int32
-    else:
-        raise ValueError(f"Unsupported dtype {dtype}")
-
-
 def init_to_zero(name):
     return lambda nargs: nargs[name].zero_()
 
@@ -1125,99 +1106,6 @@ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
     tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
 
 
-# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
-HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
-
-if HAS_TMA_DESC:
-    print(
-        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
-        file=sys.stderr,
-    )
-else:
-    print(
-        "TMA benchmarks will be running without grid constant TMA descriptor.",
-        file=sys.stderr,
-    )
-
-
-class TmaAutoTuneHelper:
-
-    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
-    class KernelParamWrapper:
-        def __init__(self, desc):
-            self.desc = desc
-
-        def tma_desc_cpu_ptr(self):
-            return self.desc.data_ptr()
-
-    TMA_SIZE = 128
-
-    def __init__(self):
-        self.fill_1d_tma_descriptor_inner = (
-            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
-        )
-        self.fill_2d_tma_descriptor_inner = (
-            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
-        )
-        if HAS_TMA_DESC:
-            self.descriptors = {}
-        else:
-            self.cuda_descriptors = {}
-
-    # Call this method outside of the lambda function for grid size
-    def init_tma_descriptor(self, name):
-        if HAS_TMA_DESC:
-            self.descriptors[name] = torch.empty(
-                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
-            )
-        else:
-            self.cuda_descriptors[name] = torch.empty(
-                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
-            )
-
-    # Call this method inside the lambda function for grid size
-    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_1d_tma_descriptor_inner(
-                ptr, dim, block_dim, element_size, desc_x.data_ptr()
-            )
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_1d_tma_descriptor_inner(
-                ptr, dim, block_dim, element_size, buf_x.data_ptr()
-            )
-            desc_x.copy_(buf_x, non_blocking=True)
-
-    # Call this method inside the lambda function for grid size
-    def fill_2d_tma_descriptor(
-        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
-    ):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_2d_tma_descriptor_inner(
-                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
-            )
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_2d_tma_descriptor_inner(
-                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
-            )
-            desc_x.copy_(buf_x, non_blocking=True)
-
-    def get_tma_descriptor_kernel_param(self, name):
-        if HAS_TMA_DESC:
-            assert self.descriptors[name] is not None
-            return self.KernelParamWrapper(self.descriptors[name])
-        else:
-            assert self.cuda_descriptors[name] is not None
-            return self.cuda_descriptors[name]
-
-
 @torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
 def matmul_fp8_row(
     a: torch.Tensor,
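
As a quick illustration of the change (not part of the diff; the tensor below is made up), downstream code now pulls the dtype helper from the shared utils module rather than from fp8_gemm.py:

```python
# Illustrative only: exercises the relocated helper via its new import path.
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.utils import map_dtype_to_triton

x = torch.randn(16, 16, dtype=torch.bfloat16)
acc_dtype = map_dtype_to_triton(x.dtype)  # tl.bfloat16, per the mapping above
```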

fbgemm_gpu/experimental/gemm/triton_gemm/utils.py (new file)

Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+import torch
+import triton  # @manual
+
+import triton.language as tl  # @manual
+
+
+def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
+    """
+    Maps torch dtype to triton dtype.
+
+    Args:
+        dtype (torch.dtype): input dtype.
+
+    Returns:
+        tl.dtype: triton dtype.
+    """
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.int32:
+        return tl.int32
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")
+
+
+# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
+HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
+
+if HAS_TMA_DESC:
+    print(
+        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
+        file=sys.stderr,
+    )
+else:
+    print(
+        "TMA benchmarks will be running without grid constant TMA descriptor.",
+        file=sys.stderr,
+    )
+
+
+class TmaAutoTuneHelper:
+
+    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
+    class KernelParamWrapper:
+        def __init__(self, desc):
+            self.desc = desc
+
+        def tma_desc_cpu_ptr(self):
+            return self.desc.data_ptr()
+
+    TMA_SIZE = 128
+
+    def __init__(self):
+        self.fill_1d_tma_descriptor_inner = (
+            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
+        )
+        self.fill_2d_tma_descriptor_inner = (
+            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
+        )
+        if HAS_TMA_DESC:
+            self.descriptors = {}
+        else:
+            self.cuda_descriptors = {}
+
+    # Call this method outside of the lambda function for grid size
+    def init_tma_descriptor(self, name):
+        if HAS_TMA_DESC:
+            self.descriptors[name] = torch.empty(
+                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
+            )
+        else:
+            self.cuda_descriptors[name] = torch.empty(
+                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
+            )
+
+    # Call this method inside the lambda function for grid size
+    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
+        if HAS_TMA_DESC:
+            desc_x = self.descriptors[name]
+            assert desc_x.data_ptr() % 64 == 0
+            self.fill_1d_tma_descriptor_inner(
+                ptr, dim, block_dim, element_size, desc_x.data_ptr()
+            )
+        else:
+            desc_x = self.cuda_descriptors[name]
+            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
+            self.fill_1d_tma_descriptor_inner(
+                ptr, dim, block_dim, element_size, buf_x.data_ptr()
+            )
+            desc_x.copy_(buf_x, non_blocking=True)
+
+    # Call this method inside the lambda function for grid size
+    def fill_2d_tma_descriptor(
+        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
+    ):
+        if HAS_TMA_DESC:
+            desc_x = self.descriptors[name]
+            assert desc_x.data_ptr() % 64 == 0
+            self.fill_2d_tma_descriptor_inner(
+                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
+            )
+        else:
+            desc_x = self.cuda_descriptors[name]
+            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
+            self.fill_2d_tma_descriptor_inner(
+                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
+            )
+            desc_x.copy_(buf_x, non_blocking=True)
+
+    def get_tma_descriptor_kernel_param(self, name):
+        if HAS_TMA_DESC:
+            assert self.descriptors[name] is not None
+            return self.KernelParamWrapper(self.descriptors[name])
+        else:
+            assert self.cuda_descriptors[name] is not None
+            return self.cuda_descriptors[name]
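
The comments in TmaAutoTuneHelper spell out the intended call pattern: allocate each descriptor once with init_tma_descriptor outside the grid lambda, refill it inside the lambda (where the autotuner's chosen block sizes are available via META), and pass the object returned by get_tma_descriptor_kernel_param to the kernel. A minimal sketch of that pattern follows; the kernel my_tma_kernel, the META keys, and the tensor shape are hypothetical, only the helper calls come from this file:

```python
# Sketch of the TmaAutoTuneHelper call pattern. `my_tma_kernel` and the
# META["BLOCK_M"] / META["BLOCK_K"] keys are hypothetical; the helper API
# (init / fill / get) is the one defined above.
import torch
import triton

from fbgemm_gpu.experimental.gemm.triton_gemm.utils import TmaAutoTuneHelper

a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
M, K = a.shape

desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("a")  # allocate the 128-byte descriptor buffer once


def grid(META):
    # Refill the descriptor here so it encodes the block sizes the autotuner
    # picked for the configuration currently being launched.
    desc_helper.fill_2d_tma_descriptor(
        "a", a.data_ptr(), M, K, META["BLOCK_M"], META["BLOCK_K"], a.element_size()
    )
    return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(K, META["BLOCK_K"]),)


desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
# my_tma_kernel[grid](desc_a, M, K)  # hypothetical launch of a TMA-enabled kernel
```

Refilling inside the grid callable matters because a TMA descriptor bakes in the block shape; when the autotuner tries a different configuration, the descriptor has to be rebuilt before that configuration's kernel runs, which is exactly what the "inside the lambda" comments above are pointing at.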
