Skip to content

Commit c7dbf19

Browse files
levendlee authored and facebook-github-bot committed
Adds shape information to enable torch.compile. (pytorch#3724)
Summary: X-link: facebookresearch/FBGEMM#807 Adds shape information to enable custom ops in torch.compile. Differential Revision: D69993984
1 parent 221c2aa commit c7dbf19

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ void scatter_add_along_first_dim(
2020
at::Tensor src,
2121
at::Tensor index);
2222

23+
// Meta ("fake tensor") implementation of gather_along_first_dim.
// Computes only the output shape/dtype so torch.compile can trace the
// custom op without dispatching the real CUDA kernel.
//
// data:  tensor gathered along dim 0; assumed 2-D [M, K] here — TODO
//        confirm against the CUDA kernel's contract.
// index: 1-D tensor of N row indices into data.
// Returns an uninitialized [N, K] tensor with data's options.
at::Tensor gather_along_first_dim_meta(
    const at::Tensor& data,
    const at::Tensor& index) {
  // Tensor sizes are int64_t; narrowing them to int would silently
  // overflow for dimensions >= 2^31, so keep the full width.
  const int64_t K = data.size(1);
  const int64_t N = index.size(0);
  return at::empty({N, K}, data.options());
}
31+
32+
void scatter_add_along_first_dim_meta(
33+
const at::Tensor& /*dst*/,
34+
const at::Tensor& /*src*/,
35+
const at::Tensor& /*index*/) {
36+
return;
37+
}
38+
2339
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
2440
m.set_python_module("fbgemm_gpu.experimental.gen_ai.gather_scatter");
2541
m.def("gather_along_first_dim(Tensor Data, Tensor Index) -> Tensor");
@@ -32,6 +48,10 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
3248
m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim);
3349
}
3450

51+
// Register the meta kernels for the Meta dispatch key so the ops are
// traceable with FakeTensors (required for torch.compile shape
// propagation).
TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
  m.impl("gather_along_first_dim", gather_along_first_dim_meta);
  m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim_meta);
}
3555
#endif
3656

3757
} // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/test/gather_scatter/gather_scatter_test.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ class GatherScatterTests(unittest.TestCase):
2727
"""Test Gathers."""
2828

2929
def test_gather_along_first_dim(self) -> None:
30-
def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
30+
def _test_gather_along_first_dim(
31+
M: int, N: int, K: int, compile: bool = False
32+
) -> None:
3133
logger.info(f"Running test_gather_along_first_dim: {M=}, {N=}, {K=}")
3234
src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
3335
if M == N:
@@ -36,7 +38,10 @@ def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
3638
indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int32)
3739

3840
def fn():
39-
return torch.ops.fbgemm.gather_along_first_dim(src, indices)
41+
op = torch.ops.fbgemm.gather_along_first_dim
42+
if compile:
43+
op = torch.compile(op, backend="inductor", fullgraph=True)
44+
return op(src, indices)
4045

4146
def ref_fn():
4247
return torch.index_select(src, 0, indices)
@@ -71,38 +76,41 @@ def ref_fn():
7176
_test_gather_along_first_dim(255, 129, 2049)
7277
_test_gather_along_first_dim(255, 129, 2048)
7378
_test_gather_along_first_dim(1024, 1024, 1024)
79+
_test_gather_along_first_dim(1024, 1024, 1024, compile=True)
7480

7581
def test_scatter_add_along_first_dim(self) -> None:
76-
def _test_scatter_add_along_first_dim(M: int, N: int, K: int) -> None:
82+
def _test_scatter_add_along_first_dim(
83+
M: int, N: int, K: int, compile: bool = False
84+
) -> None:
7785
logger.info(f"Running test_scatter_add_along_first_dim: {M=}, {N=}, {K=}")
7886
src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
7987
dst = torch.randn([N, K], device="cuda", dtype=torch.bfloat16).abs()
8088
if M == N:
81-
indices = torch.randperm(N, device="cuda", dtype=torch.int32)
89+
indices_1d = torch.randperm(N, device="cuda", dtype=torch.int64)
8290
else:
83-
indices = torch.randint(0, N, [M], device="cuda", dtype=torch.int32)
91+
indices_1d = torch.randint(0, N, [M], device="cuda", dtype=torch.int64)
8492

85-
indices_int32 = indices.to(torch.int32)
86-
indices_int64 = indices.to(torch.int64).unsqueeze(1).expand(-1, K)
93+
indices_2d = indices_1d.to(torch.int64).unsqueeze(1).expand(-1, K)
8794

8895
test_dst = dst.clone()
8996
ref_dst = dst.clone()
9097

9198
logger.info("Running FBGMM")
92-
torch.ops.fbgemm.scatter_add_along_first_dim(test_dst, src, indices_int32)
99+
torch.ops.fbgemm.scatter_add_along_first_dim(test_dst, src, indices_1d)
93100

94101
logger.info("Running PyTorch")
95-
ref_dst.scatter_add_(0, indices_int64, src)
102+
ref_dst.scatter_add_(0, indices_2d, src)
96103

97104
torch.testing.assert_close(test_dst, ref_dst, atol=1e-3, rtol=2e-2)
98105

99106
def fn():
100-
torch.ops.fbgemm.scatter_add_along_first_dim(
101-
test_dst, src, indices_int32
102-
)
107+
op = torch.ops.fbgemm.scatter_add_along_first_dim
108+
if compile:
109+
op = torch.compile(op, backend="inductor", fullgraph=True)
110+
op(test_dst, src, indices_1d)
103111

104112
def ref_fn():
105-
ref_dst.scatter_add_(0, indices_int64, src)
113+
ref_dst.scatter_add_(0, indices_2d, src)
106114

107115
# Load src, load dst, store dst. x3.
108116
data_size_in_terabytes = N * K * 2 * 3 / 1e12
@@ -127,6 +135,7 @@ def ref_fn():
127135
_test_scatter_add_along_first_dim(255, 129, 2049)
128136
_test_scatter_add_along_first_dim(255, 129, 2048)
129137
_test_scatter_add_along_first_dim(1024, 1024, 1024)
138+
_test_scatter_add_along_first_dim(1024, 1024, 1024, compile=True)
130139

131140

132141
if __name__ == "__main__":

0 commit comments

Comments (0)