
Commit 419ce8c

q10 authored and facebook-github-bot committed
Re-organize SLL ops, pt 6 (#3647)
Summary:
Pull Request resolved: #3647
X-link: facebookresearch/FBGEMM#722

- Re-organize `dense_jagged_cat_jagged_out`

Reviewed By: brad-mengchi

Differential Revision: D68936183
1 parent 79fcd5b commit 419ce8c

File tree

7 files changed: +651 −620 lines changed


fbgemm_gpu/fbgemm_gpu/sll/__init__.py

Lines changed: 0 additions & 15 deletions
@@ -34,12 +34,8 @@
 )
 
 from fbgemm_gpu.sll.triton_sll import (  # noqa F401
-    array_jagged_bmm_jagged_out,
-    dense_jagged_cat_jagged_out,
     jagged2_to_padded_dense,
-    # jagged_dense_bmm,
     jagged_dense_elementwise_mul_jagged_out,
-    jagged_jagged_bmm_jagged_out,
     triton_jagged_self_substraction_jagged_out,
 )
 
@@ -269,9 +265,6 @@
 
 # pyre-ignore[5]
 sll_gpu_registrations = {
-    "sll_dense_jagged_cat_jagged_out": {
-        "CUDA": dense_jagged_cat_jagged_out,
-    },
     "sll_jagged_self_substraction_jagged_out": {
         "CUDA": triton_jagged_self_substraction_jagged_out,
     },
@@ -283,14 +276,6 @@
         "CUDA": jagged_dense_elementwise_mul_jagged_out,
         "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
     },
-    "sll_array_jagged_bmm_jagged_out": {
-        "CUDA": array_jagged_bmm_jagged_out,
-        "AutogradCUDA": array_jagged_bmm_jagged_out,
-    },
-    "sll_jagged_jagged_bmm_jagged_out": {
-        "CUDA": jagged_jagged_bmm_jagged_out,
-        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
-    },
 }
 
 for op_name, dispatches in sll_cpu_registrations.items():
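
The trailing loop applies each entry in these registration maps to the PyTorch dispatcher. As context, here is a minimal sketch of that pattern, assuming the `sll_*` ops are already declared under an `fbgemm` operator namespace; the helper name, namespace, and loop details are assumptions for illustration, not FBGEMM's actual registration code.

# Minimal sketch of applying a {op_name: {dispatch_key: fn}} map with torch.library.
# Assumes each op (e.g. "fbgemm::sll_jagged_softmax") was previously declared via
# torch.library.define; the "fbgemm" namespace and helper name are illustrative only.
import torch


def apply_registrations(registrations: dict, namespace: str = "fbgemm") -> None:
    for op_name, dispatches in registrations.items():
        for dispatch_key, func in dispatches.items():
            # Register func for this dispatch key, e.g. "CUDA" or "AutogradCUDA".
            torch.library.impl(f"{namespace}::{op_name}", dispatch_key, func)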

fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -7,13 +7,26 @@
 
 # pyre-strict
 
+from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
+    dense_jagged_cat_jagged_out,
+)
+
 from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
     jagged_dense_bmm,
     jagged_jagged_bmm,
     JaggedDenseBmm,  # noqa F401
     JaggedJaggedBmm,  # noqa F401
 )
 
+from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
+    array_jagged_bmm_jagged_out,
+    ArrayJaggedBmmNopadding,  # noqa F401
+    jagged_jagged_bmm_jagged_out,
+    JaggedJaggedBmmNoPadding,  # noqa F401
+    triton_array_jagged_bmm_jagged_out,  # noqa F401
+    triton_jagged_jagged_bmm_jagged_out,  # noqa F401
+)
+
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
     jagged_dense_elementwise_add,
     JaggedDenseAdd,  # noqa F401
@@ -43,6 +56,9 @@
 
 # pyre-ignore[5]
 op_registrations = {
+    "sll_dense_jagged_cat_jagged_out": {
+        "CUDA": dense_jagged_cat_jagged_out,
+    },
     "sll_jagged_dense_bmm": {
         "CUDA": jagged_dense_bmm,
         "AutogradCUDA": jagged_dense_bmm,
@@ -51,6 +67,14 @@
         "CUDA": jagged_jagged_bmm,
         "AutogradCUDA": jagged_jagged_bmm,
     },
+    "sll_array_jagged_bmm_jagged_out": {
+        "CUDA": array_jagged_bmm_jagged_out,
+        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_jagged_bmm_jagged_out": {
+        "CUDA": jagged_jagged_bmm_jagged_out,
+        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    },
     "sll_jagged_softmax": {
         "CUDA": jagged_softmax,
         "AutogradCUDA": jagged_softmax,
fbgemm_gpu/fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def dense_jagged_cat_jagged_out_kernel(
+    a_ptr,  # dense
+    b_ptr,  # jagged
+    c_ptr,  # jagged
+    b_offsets_ptr,
+    c_offsets_ptr,
+    max_seq_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_batch = tl.program_id(0)
+    b_start = tl.load(b_offsets_ptr + pid_batch)
+    b_end = tl.load(b_offsets_ptr + pid_batch + 1)
+    c_start = b_start + pid_batch
+    N = b_end - b_start
+    N = tl.minimum(N, max_seq_len)
+
+    a = tl.load(a_ptr + pid_batch)
+    tl.store(c_ptr + c_start, a)
+
+    offs_k = tl.arange(0, BLOCK_SIZE)
+    for k in range(0, N, BLOCK_SIZE):
+        b_offset = k + offs_k
+        b_ptrs = b_ptr + b_start + b_offset
+        b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
+        tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
+    tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)
+
+
+def dense_jagged_cat_jagged_out(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    b_offsets: torch.Tensor,
+    max_seq_len: int,
+):
+    assert a.is_contiguous()
+    assert b.is_contiguous()
+    assert b_offsets.is_contiguous()
+    B = a.size(0)
+    BLOCK_SIZE = 128
+    c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
+    c_offsets = torch.empty(
+        b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
+    )  # B + 1
+
+    dense_jagged_cat_jagged_out_kernel[(B,)](
+        a,
+        b,
+        c,
+        b_offsets,
+        c_offsets,
+        max_seq_len,
+        # pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
+        BLOCK_SIZE,
+    )
+
+    c_offsets[-1] = b_offsets[-1] + B
+
+    return c, c_offsets
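
For orientation, a hypothetical usage sketch of the moved op (not part of the commit): it prepends one dense value per batch to the front of the corresponding jagged sequence and returns the new values plus updated offsets. The tensor values and `max_seq_len` below are made up for illustration; a CUDA device is required for the Triton kernel, and `fbgemm_gpu` is assumed to be installed.

# Hypothetical example values; requires a CUDA device and fbgemm_gpu with Triton support.
import torch

from fbgemm_gpu.sll.triton import dense_jagged_cat_jagged_out

device = "cuda"
a = torch.tensor([10.0, 20.0], device=device)  # dense: one value per batch (B = 2)
b = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device=device)  # flattened jagged values
b_offsets = torch.tensor([0, 3, 5], device=device)  # batch 0 -> b[0:3], batch 1 -> b[3:5]

c, c_offsets = dense_jagged_cat_jagged_out(a, b, b_offsets, max_seq_len=8)
# Expected layout: c = [10, 1, 2, 3, 20, 4, 5] and c_offsets = [0, 4, 7],
# i.e. each output sequence is the dense value followed by its jagged sequence.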
