Commit 78be7d3

TroyGarden authored and facebook-github-bot committed
use new op in KTRegroupAsDict module (pytorch#2210)
Summary:
Pull Request resolved: pytorch#2210

# context
* adding PackedTensorAccessor for passing the index tensor to kernel
* GPU trace reading slows down from 2.20ms to 2.26ms

# traces
* previous ~4.90s {F1747994738}
* after ~2.00ms {F1747994032}

Differential Revision: D53590566
1 parent 7337cb2 commit 78be7d3
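For readers unfamiliar with the module, KTRegroupAsDict takes a list of pooled-embedding KeyedTensors and regroups their features into a dict of tensors, one per output group. The pure-PyTorch sketch below shows only those regroup semantics; it is illustrative, the helper name regroup_reference and the sample features are invented here, and the actual module dispatches to the fused fbgemm kernels shown in the diff below instead of splitting and concatenating tensor by tensor.

import torch
from typing import Dict, List

def regroup_reference(
    keys_per_kt: List[List[str]],       # feature names in each input KeyedTensor
    lengths_per_kt: List[List[int]],    # embedding dim of each feature
    values_per_kt: List[torch.Tensor],  # [batch, sum(dims)] values per KeyedTensor
    groups: List[List[str]],            # requested feature grouping
    group_keys: List[str],              # output dict key for each group
) -> Dict[str, torch.Tensor]:
    # Split every input back into per-feature tensors along the key dim (dim=1).
    features: Dict[str, torch.Tensor] = {}
    for keys, lengths, values in zip(keys_per_kt, lengths_per_kt, values_per_kt):
        for key, tensor in zip(keys, torch.split(values, lengths, dim=1)):
            features[key] = tensor
    # Re-concatenate the features according to the requested grouping.
    return {
        out_key: torch.cat([features[k] for k in group], dim=1)
        for out_key, group in zip(group_keys, groups)
    }

# Made-up example: features f1/f2 live in one KeyedTensor, f3 in another.
values_a = torch.randn(4, 3 + 2)  # f1 (dim 3) and f2 (dim 2), batch of 4
values_b = torch.randn(4, 4)      # f3 (dim 4)
out = regroup_reference(
    keys_per_kt=[["f1", "f2"], ["f3"]],
    lengths_per_kt=[[3, 2], [4]],
    values_per_kt=[values_a, values_b],
    groups=[["f1", "f3"], ["f2"]],
    group_keys=["user", "item"],
)
assert out["user"].shape == (4, 7) and out["item"].shape == (4, 2)

The commit replaces the old two-step fbgemm path (concatenate all values, then permute_pooled_embs_auto_grad) with a pair of ops that compute the permute metadata once and then perform the regroup on the list of value tensors in a single fused kernel per forward call.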

File tree

1 file changed: +31 −33 lines changed


torchrec/modules/regroup.py

Lines changed: 31 additions & 33 deletions
@@ -9,20 +9,15 @@
 
 #!/usr/bin/env python3
 
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from torchrec.sparse.jagged_tensor import (
-    _all_keys_used_once,
-    _desugar_keyed_tensors,
-    _remap_to_groups,
-    KeyedTensor,
-)
+from torchrec.sparse.jagged_tensor import _desugar_keyed_tensors, KeyedTensor
 
 
 @torch.fx.wrap
-def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
-    return torch.cat([kt.values() for kt in kts], dim=dim)
+def _get_kts_values(kts: List[KeyedTensor]) -> List[torch.Tensor]:
+    return [kt.values() for kt in kts]
 
 
 @torch.fx.wrap
@@ -36,11 +31,15 @@ def _permuted_values(
 
 @torch.fx.wrap
 def _build_dict(
-    keys: List[str], values: torch.Tensor, splits: List[int], dim: int
+    keys: List[str],
+    values: Union[torch.Tensor, List[torch.Tensor]],
+    splits: List[int],
+    dim: int,
 ) -> Dict[str, torch.Tensor]:
-    return {
-        key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
-    }
+    if isinstance(values, torch.Tensor):
+        return {key: st for key, st in zip(keys, torch.split(values, splits, dim=dim))}
+    else:
+        return {key: tensor for key, tensor in zip(keys, values)}
 
 
 class KTRegroupAsDict(torch.nn.Module):
@@ -80,23 +79,22 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
         self._use_fbgemm_regroup: bool = False
         self._splits: List[int] = []
         self._idx_key_pairs: List[Tuple[int, str]] = []
-        self._permute_tensor: Optional[torch.Tensor] = None
-        self._inv_permute_tensor: Optional[torch.Tensor] = None
-        self._offsets_tensor: Optional[torch.Tensor] = None
-        self._inv_offsets_tensor: Optional[torch.Tensor] = None
+        self._permutes: Optional[torch.Tensor] = None
+        self._in_shapes: Optional[torch.Tensor] = None
+        self._out_shapes: Optional[torch.Tensor] = None
+        self._out_lengths: Optional[List[int]] = None
 
     def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
         self._use_fbgemm_regroup = True
         keys, lengths, values = _desugar_keyed_tensors(kts)
-        permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
-            keys, lengths, self._groups
+        self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = (
+            torch.ops.fbgemm.kt_regroup_permutes(
+                values[0],
+                keys,
+                lengths,
+                self._groups,
+            )
         )
-        # no need to pin_memory() or to(..., non_blocking=True) since occurs only once
-        self._permute_tensor = permute.to(self.device)
-        self._inv_permute_tensor = inv_permute.to(self.device)
-        self._offsets_tensor = offsets.to(self.device)
-        self._inv_offsets_tensor = inv_offsets.to(self.device)
-        self._splits = splits
 
     def _init_regroup(self, kts: List[KeyedTensor]) -> None:
         lengths = [kt.length_per_key() for kt in kts]
@@ -137,24 +135,24 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
             ), "All inputs should have the same key_dim"
             self._dim = keyed_tensors[0].key_dim()
 
-            if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
+            if self._dim == 1:
                 self._init_fbgemm_regroup(keyed_tensors)
             else:
                 self._init_regroup(keyed_tensors)
            self._is_inited = True
 
         if self._use_fbgemm_regroup:
-            values = _concat_values(keyed_tensors, self._dim)
-            permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+            values = _get_kts_values(keyed_tensors)
+            permuted_values = torch.ops.fbgemm.permute_multi_embedding(
                 values,
-                self._offsets_tensor,
-                self._permute_tensor,
-                self._inv_offsets_tensor,
-                self._inv_permute_tensor,
+                self._permutes,
+                self._in_shapes,
+                self._out_shapes,
+                self._out_lengths,
             )
+            # return {key: tensor for key, tensor in zip(self._keys, permuted_values)}
         else:
             permuted_values = _permuted_values(
                 keyed_tensors, self._idx_key_pairs, self._dim
             )
-
         return _build_dict(self._keys, permuted_values, self._splits, self._dim)
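Below is a minimal, hypothetical usage sketch of the updated module. It is not part of the commit: the feature names and dimensions are invented, the KeyedTensor constructor arguments (keys, length_per_key, values, key_dim) follow the public torchrec API, and running it assumes fbgemm_gpu is installed, since after this change any input with key_dim == 1 takes the fbgemm path.

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.modules.regroup import KTRegroupAsDict

# Hypothetical inputs: two pooled-embedding KeyedTensors with three features total.
batch = 4
kt_a = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[3, 2],          # embedding dims of f1 and f2
    values=torch.randn(batch, 5),   # f1 and f2 concatenated along dim 1
    key_dim=1,
)
kt_b = KeyedTensor(
    keys=["f3"],
    length_per_key=[4],
    values=torch.randn(batch, 4),
    key_dim=1,
)

# Regroup f1/f3 into "user" and f2 into "item".
regroup = KTRegroupAsDict(groups=[["f1", "f3"], ["f2"]], keys=["user", "item"])
out = regroup([kt_a, kt_b])
print(out["user"].shape, out["item"].shape)  # torch.Size([4, 7]) torch.Size([4, 2])

Per the diff above, the first forward call lazily computes the permute metadata with torch.ops.fbgemm.kt_regroup_permutes; subsequent calls reuse it and go straight to torch.ops.fbgemm.permute_multi_embedding on the list of value tensors.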
