@@ -27,7 +27,9 @@ class GatherScatterTests(unittest.TestCase):
27
27
"""Test Gathers."""
28
28
29
29
def test_gather_along_first_dim (self ) -> None :
30
- def _test_gather_along_first_dim (M : int , N : int , K : int ) -> None :
30
+ def _test_gather_along_first_dim (
31
+ M : int , N : int , K : int , compile : bool = False
32
+ ) -> None :
31
33
logger .info (f"Running test_gather_along_first_dim: { M = } , { N = } , { K = } " )
32
34
src = torch .randn ([M , K ], device = "cuda" , dtype = torch .bfloat16 ).abs ()
33
35
if M == N :
@@ -36,7 +38,10 @@ def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
36
38
indices = torch .randint (0 , M , [N ], device = "cuda" , dtype = torch .int32 )
37
39
38
40
def fn ():
39
- return torch .ops .fbgemm .gather_along_first_dim (src , indices )
41
+ op = torch .ops .fbgemm .gather_along_first_dim
42
+ if compile :
43
+ op = torch .compile (op , backend = "inductor" , fullgraph = True )
44
+ return op (src , indices )
40
45
41
46
def ref_fn ():
42
47
return torch .index_select (src , 0 , indices )
@@ -71,38 +76,41 @@ def ref_fn():
71
76
_test_gather_along_first_dim (255 , 129 , 2049 )
72
77
_test_gather_along_first_dim (255 , 129 , 2048 )
73
78
_test_gather_along_first_dim (1024 , 1024 , 1024 )
79
+ _test_gather_along_first_dim (1024 , 1024 , 1024 , compile = True )
74
80
75
81
def test_scatter_add_along_first_dim (self ) -> None :
76
- def _test_scatter_add_along_first_dim (M : int , N : int , K : int ) -> None :
82
+ def _test_scatter_add_along_first_dim (
83
+ M : int , N : int , K : int , compile : bool = False
84
+ ) -> None :
77
85
logger .info (f"Running test_scatter_add_along_first_dim: { M = } , { N = } , { K = } " )
78
86
src = torch .randn ([M , K ], device = "cuda" , dtype = torch .bfloat16 ).abs ()
79
87
dst = torch .randn ([N , K ], device = "cuda" , dtype = torch .bfloat16 ).abs ()
80
88
if M == N :
81
- indices = torch .randperm (N , device = "cuda" , dtype = torch .int32 )
89
+ indices_1d = torch .randperm (N , device = "cuda" , dtype = torch .int64 )
82
90
else :
83
- indices = torch .randint (0 , N , [M ], device = "cuda" , dtype = torch .int32 )
91
+ indices_1d = torch .randint (0 , N , [M ], device = "cuda" , dtype = torch .int64 )
84
92
85
- indices_int32 = indices .to (torch .int32 )
86
- indices_int64 = indices .to (torch .int64 ).unsqueeze (1 ).expand (- 1 , K )
93
+ indices_2d = indices_1d .to (torch .int64 ).unsqueeze (1 ).expand (- 1 , K )
87
94
88
95
test_dst = dst .clone ()
89
96
ref_dst = dst .clone ()
90
97
91
98
logger .info ("Running FBGMM" )
92
- torch .ops .fbgemm .scatter_add_along_first_dim (test_dst , src , indices_int32 )
99
+ torch .ops .fbgemm .scatter_add_along_first_dim (test_dst , src , indices_1d )
93
100
94
101
logger .info ("Running PyTorch" )
95
- ref_dst .scatter_add_ (0 , indices_int64 , src )
102
+ ref_dst .scatter_add_ (0 , indices_2d , src )
96
103
97
104
torch .testing .assert_close (test_dst , ref_dst , atol = 1e-3 , rtol = 2e-2 )
98
105
99
106
def fn ():
100
- torch .ops .fbgemm .scatter_add_along_first_dim (
101
- test_dst , src , indices_int32
102
- )
107
+ op = torch .ops .fbgemm .scatter_add_along_first_dim
108
+ if compile :
109
+ op = torch .compile (op , backend = "inductor" , fullgraph = True )
110
+ op (test_dst , src , indices_1d )
103
111
104
112
def ref_fn ():
105
- ref_dst .scatter_add_ (0 , indices_int64 , src )
113
+ ref_dst .scatter_add_ (0 , indices_2d , src )
106
114
107
115
# Load src, load dst, store dst. x3.
108
116
data_size_in_terabytes = N * K * 2 * 3 / 1e12
@@ -127,6 +135,7 @@ def ref_fn():
127
135
_test_scatter_add_along_first_dim (255 , 129 , 2049 )
128
136
_test_scatter_add_along_first_dim (255 , 129 , 2048 )
129
137
_test_scatter_add_along_first_dim (1024 , 1024 , 1024 )
138
+ _test_scatter_add_along_first_dim (1024 , 1024 , 1024 , compile = True )
130
139
131
140
132
141
if __name__ == "__main__" :
0 commit comments