Skip to content

Commit 7337cb2

Browse files
Huanyu Hefacebook-github-bot
authored andcommitted
benchmark of fbgemm op - regroup_keyed_tensor
Differential Revision: D58907223
1 parent 1634b07 commit 7337cb2

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,19 @@ def permute_multi_embedding(
188188
return permuted_values
189189

190190

191+
@torch.fx.wrap
192+
def keyed_tensor_regroup(
193+
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
194+
) -> List[torch.Tensor]:
195+
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
196+
return torch.ops.fbgemm.regroup_keyed_tensor(
197+
values,
198+
keys,
199+
lengths,
200+
groups,
201+
)
202+
203+
191204
@torch.fx.wrap
192205
def _fbgemm_permute_pooled_embs(
193206
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -2708,11 +2721,7 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
27082721
def regroup(
27092722
keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
27102723
) -> List[torch.Tensor]:
2711-
# Fast path, one-to-one correspondence between keyed_tensors and groups
2712-
if _all_keys_used_once(keyed_tensors, groups) is True:
2713-
return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
2714-
else: # Fallback to slow path otherwise
2715-
return _regroup_keyed_tensors(keyed_tensors, groups)
2724+
return permute_multi_embedding(keyed_tensors, groups)
27162725

27172726
@staticmethod
27182727
def regroup_as_dict(

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
1919
from torchrec.modules.regroup import KTRegroupAsDict
2020
from torchrec.sparse.jagged_tensor import (
21+
_fbgemm_permute_pooled_embs,
2122
_regroup_keyed_tensors,
23+
keyed_tensor_regroup,
2224
KeyedJaggedTensor,
2325
KeyedTensor,
2426
permute_multi_embedding,
@@ -213,7 +215,7 @@ def main(
213215
).float()
214216
groups = build_groups(kts, n_groups, duplicates=duplicates)
215217
bench(
216-
"_regroup_keyed_tenors" + dup,
218+
"python_native" + dup,
217219
labels,
218220
batch_size,
219221
n_dense + n_sparse,
@@ -224,7 +226,7 @@ def main(
224226
profile,
225227
)
226228
bench(
227-
"KeyedTensor.regroup" + dup,
229+
"[Prod] KeyedTensor.regroup" + dup,
228230
labels,
229231
batch_size,
230232
n_dense + n_sparse,
@@ -235,7 +237,7 @@ def main(
235237
profile,
236238
)
237239
bench(
238-
"KTRegroupAsDict" + dup,
240+
"[Module] KTRegroupAsDict" + dup,
239241
labels,
240242
batch_size,
241243
n_dense + n_sparse,
@@ -248,7 +250,7 @@ def main(
248250
profile,
249251
)
250252
bench(
251-
"permute_multi_embs" + dup,
253+
"[2 Ops] permute_multi_embs" + dup,
252254
labels,
253255
batch_size,
254256
n_dense + n_sparse,
@@ -258,6 +260,29 @@ def main(
258260
{"keyed_tensors": kts, "groups": groups},
259261
profile,
260262
)
263+
bench(
264+
"[1 Op] KT_regroup" + dup,
265+
labels,
266+
batch_size,
267+
n_dense + n_sparse,
268+
device_type,
269+
run_backward,
270+
keyed_tensor_regroup,
271+
{"keyed_tensors": kts, "groups": groups},
272+
profile,
273+
)
274+
if not duplicates:
275+
bench(
276+
"[Old Prod] permute_pooled_embs" + dup,
277+
labels,
278+
batch_size,
279+
n_dense + n_sparse,
280+
device_type,
281+
run_backward,
282+
_fbgemm_permute_pooled_embs,
283+
{"keyed_tensors": kts, "groups": groups},
284+
profile,
285+
)
261286

262287

263288
if __name__ == "__main__":

0 commit comments

Comments
 (0)