
Commit 8b7fef5

TroyGarden authored and PaulZhang12 committed
use new op in KTRegroupAsDict module (#2210)
Summary:
Pull Request resolved: #2210

# context
* the new op `permute_multi_embedding` outperforms the original op `permute_pooled_embs_auto_grad`
* this diff switches the `KTRegroupAsDict` module to the new op
* benchmark results: D58907223

# benchmark
* [traces](https://drive.google.com/drive/folders/1v_kD9n1jOkGUmYyix3-dUYiBDE_C3Hiv?usp=drive_link)
* previous prod {F1747994738}
* new prod {F1747994032}
* metrics

|Operator|GPU runtime|GPU memory|Notes|
|---|---|---|---|
|**[previous prod] permute_pooled_embs**|4.9 ms|1.5 K|GPU-bound; does **NOT** allow duplicates; PT2-incompatible `pin_and_move`|
|**[new prod] permute_multi_embedding**|2.0 ms|1.0 K|both CPU and GPU runtime/memory improved; **ALLOWS** duplicates; PT2-friendly|

Reviewed By: dstaay-fb

Differential Revision: D53590566

fbshipit-source-id: 220878f99111fabc3de8a0ba83d319b36ee519f6
1 parent 06beaff commit 8b7fef5
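For context on the change above, here is a minimal usage sketch of `KTRegroupAsDict` (not part of the commit): the module regroups pooled `KeyedTensor` values into named output groups, and with the new op a key may appear in more than one group. Feature names, sizes, and the single-input shape are illustrative assumptions; running the pooled (`key_dim == 1`) path requires an FBGEMM build that provides `permute_multi_embedding`.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.modules.regroup import KTRegroupAsDict

# One pooled KeyedTensor holding two features along key_dim=1 (illustrative sizes).
kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[4, 8],
    values=torch.randn(2, 12),  # batch_size=2, total pooled dim = 4 + 8
    key_dim=1,
)

# "f2" appears in both output groups, which the new op supports directly.
regroup = KTRegroupAsDict(groups=[["f1", "f2"], ["f2"]], keys=["user", "object"])
out = regroup([kt])
print(out["user"].shape, out["object"].shape)  # torch.Size([2, 12]) torch.Size([2, 8])
```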

File tree

2 files changed: +89 −44 lines changed


torchrec/modules/regroup.py

Lines changed: 50 additions & 44 deletions
```diff
@@ -9,20 +9,19 @@

 #!/usr/bin/env python3

-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torchrec.sparse.jagged_tensor import (
-    _all_keys_used_once,
     _desugar_keyed_tensors,
-    _remap_to_groups,
+    _kt_regroup_arguments,
     KeyedTensor,
 )


 @torch.fx.wrap
-def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
-    return torch.cat([kt.values() for kt in kts], dim=dim)
+def _get_kts_values(kts: List[KeyedTensor]) -> List[torch.Tensor]:
+    return [kt.values() for kt in kts]


 @torch.fx.wrap
@@ -36,11 +35,34 @@ def _permuted_values(

 @torch.fx.wrap
 def _build_dict(
-    keys: List[str], values: torch.Tensor, splits: List[int], dim: int
+    keys: List[str],
+    values: Union[torch.Tensor, List[torch.Tensor]],
+    splits: List[int],
+    dim: int,
 ) -> Dict[str, torch.Tensor]:
-    return {
-        key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
-    }
+    if isinstance(values, torch.Tensor):
+        return dict(zip(keys, torch.split(values, splits, dim=dim)))
+    else:
+        return dict(zip(keys, values))
+
+
+@torch.fx.wrap
+def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None:
+    assert len(keyed_tensors) > 0, "Empty list provided"
+    assert all(
+        kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
+    ), "All inputs should be on the same device."
+    module.device = keyed_tensors[0].device()
+    assert all(
+        kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
+    ), "All inputs should have the same key_dim"
+    module._dim = keyed_tensors[0].key_dim()
+
+    if module._dim == 1:
+        module._init_fbgemm_regroup(keyed_tensors)
+    else:
+        module._init_regroup(keyed_tensors)
+    module._is_inited = True


 class KTRegroupAsDict(torch.nn.Module):
@@ -76,27 +98,26 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:

         # cached values populated on first forward call
         self.device: Optional[torch.device] = None
-        self._concat_dim: int = 1
+        self._dim: int = 1
         self._use_fbgemm_regroup: bool = False
         self._splits: List[int] = []
         self._idx_key_pairs: List[Tuple[int, str]] = []
-        self._permute_tensor: Optional[torch.Tensor] = None
-        self._inv_permute_tensor: Optional[torch.Tensor] = None
-        self._offsets_tensor: Optional[torch.Tensor] = None
-        self._inv_offsets_tensor: Optional[torch.Tensor] = None
+        self.register_buffer("_permutes", torch.empty(0), persistent=False)
+        self.register_buffer("_in_shapes", torch.empty(0), persistent=False)
+        self.register_buffer("_out_shapes", torch.empty(0), persistent=False)
+        self._out_lengths: Optional[List[int]] = None

     def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
         self._use_fbgemm_regroup = True
         keys, lengths, values = _desugar_keyed_tensors(kts)
-        permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
-            keys, lengths, self._groups
+        self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = (
+            _kt_regroup_arguments(
+                values[0],
+                keys,
+                lengths,
+                self._groups,
+            )
         )
-        # no need to pin_memory() or to(..., non_blocking=True) since occurs only once
-        self._permute_tensor = permute.to(self.device)
-        self._inv_permute_tensor = inv_permute.to(self.device)
-        self._offsets_tensor = offsets.to(self.device)
-        self._inv_offsets_tensor = inv_offsets.to(self.device)
-        self._splits = splits

     def _init_regroup(self, kts: List[KeyedTensor]) -> None:
         lengths = [kt.length_per_key() for kt in kts]
@@ -127,34 +148,19 @@ def _init_regroup(self, kts: List[KeyedTensor]) -> None:

     def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
         if not self._is_inited:
-            assert len(keyed_tensors) > 0, "Empty list provided"
-            assert all(
-                kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
-            ), "All inputs should be on the same device."
-            self.device = keyed_tensors[0].device()
-            assert all(
-                kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
-            ), "All inputs should have the same key_dim"
-            self._dim = keyed_tensors[0].key_dim()
-
-            if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
-                self._init_fbgemm_regroup(keyed_tensors)
-            else:
-                self._init_regroup(keyed_tensors)
-            self._is_inited = True
+            module_init(self, keyed_tensors)

         if self._use_fbgemm_regroup:
-            values = _concat_values(keyed_tensors, self._dim)
-            permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+            values = _get_kts_values(keyed_tensors)
+            permuted_values = torch.ops.fbgemm.permute_multi_embedding(
                 values,
-                self._offsets_tensor,
-                self._permute_tensor,
-                self._inv_offsets_tensor,
-                self._inv_permute_tensor,
+                self._permutes,
+                self._in_shapes,
+                self._out_shapes,
+                self._out_lengths,
             )
         else:
             permuted_values = _permuted_values(
                 keyed_tensors, self._idx_key_pairs, self._dim
             )
-
        return _build_dict(self._keys, permuted_values, self._splits, self._dim)
```
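The split-vs-zip behavior of `_build_dict` in the diff above is the main shape-contract change: the old FBGEMM op returned one concatenated tensor that had to be split by per-group lengths, while `permute_multi_embedding` returns one tensor per output group. A standalone sketch of the two branches follows; `build_dict_sketch` is a hypothetical name and the keys and sizes are illustrative, not from the repo.

```python
import torch
from typing import Dict, List, Union


def build_dict_sketch(
    keys: List[str],
    values: Union[torch.Tensor, List[torch.Tensor]],
    splits: List[int],
    dim: int,
) -> Dict[str, torch.Tensor]:
    # Fallback / old-op path: a single concatenated tensor, split by per-group lengths.
    if isinstance(values, torch.Tensor):
        return dict(zip(keys, torch.split(values, splits, dim=dim)))
    # permute_multi_embedding path: already one tensor per group, zip directly.
    return dict(zip(keys, values))


# Fallback path: a (2, 12) tensor split into a (2, 4) "user" and a (2, 8) "object" group.
concatenated = torch.randn(2, 12)
fallback = build_dict_sketch(["user", "object"], concatenated, splits=[4, 8], dim=1)

# New-op path: the list of per-group tensors is used as-is; splits are ignored.
per_group = [torch.randn(2, 4), torch.randn(2, 8)]
new_op = build_dict_sketch(["user", "object"], per_group, splits=[], dim=1)
```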

torchrec/modules/tests/test_regroup.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -33,13 +33,52 @@ def setUp(self) -> None:
         self.keys = ["user", "object"]
         self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()

+    def new_kts(self) -> None:
+        self.kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=torch.device("cpu"),
+            run_backward=True,
+        )
+
     def test_regroup_backward_skips_and_duplicates(self) -> None:
         groups = build_groups(
             kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
         )
         assert _all_keys_used_once(self.kts, groups) is False

         regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+
+        # first run
+        tensor_groups = regroup_module(self.kts)
+        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
+        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        # clear grads so can reuse inputs
+        self.kts[0].values().grad = None
+        self.kts[1].values().grad = None
+
+        tensor_groups = KeyedTensor.regroup_as_dict(
+            keyed_tensors=self.kts, groups=groups, keys=self.keys
+        )
+        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
+        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        torch.allclose(pred0, pred1)
+        torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
+        torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
+
+        # second run
+        self.new_kts()
         tensor_groups = regroup_module(self.kts)
         pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
         loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
```
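For readers unfamiliar with the test helpers, the "skips and duplicates" case means some input keys are left out of every output group while another key is routed to more than one group. A hypothetical sketch of such inputs follows; the real `build_kts`/`build_groups` utilities construct them, and the feature names and sizes below are assumptions for illustration only.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

# Two pooled KeyedTensors, mirroring the dense/sparse pair the test builds (sizes assumed).
dense_kt = KeyedTensor(
    keys=["dense_0", "dense_1"],
    length_per_key=[64, 64],
    values=torch.randn(128, 128, requires_grad=True),
    key_dim=1,
)
sparse_kt = KeyedTensor(
    keys=["sparse_0", "sparse_1"],
    length_per_key=[128, 128],
    values=torch.randn(128, 256, requires_grad=True),
    key_dim=1,
)

# "skips": sparse_1 is never used; "duplicates": dense_0 lands in both groups.
groups = [["dense_0", "sparse_0"], ["dense_0", "dense_1"]]
keys = ["user", "object"]
```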
