9
9
10
10
#!/usr/bin/env python3
11
11
12
- from typing import Dict , List , Optional , Tuple
12
+ from typing import Dict , List , Optional , Tuple , Union
13
13
14
14
import torch
15
- from torchrec .sparse .jagged_tensor import (
16
- _all_keys_used_once ,
17
- _desugar_keyed_tensors ,
18
- _remap_to_groups ,
19
- KeyedTensor ,
20
- )
15
+ from torchrec .sparse .jagged_tensor import _desugar_keyed_tensors , KeyedTensor
21
16
22
17
23
18
@torch .fx .wrap
24
- def _concat_values (kts : List [KeyedTensor ], dim : int ) -> torch .Tensor :
25
- return torch . cat ( [kt .values () for kt in kts ], dim = dim )
19
+ def _get_kts_values (kts : List [KeyedTensor ]) -> List [ torch .Tensor ] :
20
+ return [kt .values () for kt in kts ]
26
21
27
22
28
23
@torch .fx .wrap
@@ -36,11 +31,15 @@ def _permuted_values(
36
31
37
32
@torch .fx .wrap
38
33
def _build_dict (
39
- keys : List [str ], values : torch .Tensor , splits : List [int ], dim : int
34
+ keys : List [str ],
35
+ values : Union [torch .Tensor , List [torch .Tensor ]],
36
+ splits : List [int ],
37
+ dim : int ,
40
38
) -> Dict [str , torch .Tensor ]:
41
- return {
42
- key : tensor for key , tensor in zip (keys , torch .split (values , splits , dim = dim ))
43
- }
39
+ if isinstance (values , torch .Tensor ):
40
+ return {key : st for key , st in zip (keys , torch .split (values , splits , dim = dim ))}
41
+ else :
42
+ return {key : tensor for key , tensor in zip (keys , values )}
44
43
45
44
46
45
class KTRegroupAsDict (torch .nn .Module ):
@@ -80,23 +79,22 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
80
79
self ._use_fbgemm_regroup : bool = False
81
80
self ._splits : List [int ] = []
82
81
self ._idx_key_pairs : List [Tuple [int , str ]] = []
83
- self ._permute_tensor : Optional [torch .Tensor ] = None
84
- self ._inv_permute_tensor : Optional [torch .Tensor ] = None
85
- self ._offsets_tensor : Optional [torch .Tensor ] = None
86
- self ._inv_offsets_tensor : Optional [torch . Tensor ] = None
82
+ self ._permutes : Optional [torch .Tensor ] = None
83
+ self ._in_shapes : Optional [torch .Tensor ] = None
84
+ self ._out_shapes : Optional [torch .Tensor ] = None
85
+ self ._out_lengths : Optional [List [ int ] ] = None
87
86
88
87
def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
    """Precompute the fbgemm regroup metadata for this batch of KeyedTensors.

    Desugars the inputs into flat keys/lengths/values and asks
    ``torch.ops.fbgemm.kt_regroup_permutes`` for the permute plan
    (permutes, in/out shapes, output lengths) matching ``self._groups``.
    The results are cached on ``self`` for use by ``forward``; the flag
    marks that the fbgemm path should be taken.

    Args:
        kts: The KeyedTensors whose layout the permute plan is built from.
    """
    self._use_fbgemm_regroup = True
    keys, lengths, values = _desugar_keyed_tensors(kts)
    # values[0] is passed only to carry device/dtype for the generated
    # metadata tensors — presumably; TODO confirm against the fbgemm op docs.
    (
        self._permutes,
        self._in_shapes,
        self._out_shapes,
        self._out_lengths,
    ) = torch.ops.fbgemm.kt_regroup_permutes(
        values[0],
        keys,
        lengths,
        self._groups,
    )
100
98
101
99
def _init_regroup (self , kts : List [KeyedTensor ]) -> None :
102
100
lengths = [kt .length_per_key () for kt in kts ]
@@ -137,24 +135,24 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
137
135
), "All inputs should have the same key_dim"
138
136
self ._dim = keyed_tensors [0 ].key_dim ()
139
137
140
- if _all_keys_used_once ( keyed_tensors , self . _groups ) and self ._dim == 1 :
138
+ if self ._dim == 1 :
141
139
self ._init_fbgemm_regroup (keyed_tensors )
142
140
else :
143
141
self ._init_regroup (keyed_tensors )
144
142
self ._is_inited = True
145
143
146
144
if self ._use_fbgemm_regroup :
147
- values = _concat_values (keyed_tensors , self . _dim )
148
- permuted_values = torch .ops .fbgemm .permute_pooled_embs_auto_grad (
145
+ values = _get_kts_values (keyed_tensors )
146
+ permuted_values = torch .ops .fbgemm .permute_multi_embedding (
149
147
values ,
150
- self ._offsets_tensor ,
151
- self ._permute_tensor ,
152
- self ._inv_offsets_tensor ,
153
- self ._inv_permute_tensor ,
148
+ self ._permutes ,
149
+ self ._in_shapes ,
150
+ self ._out_shapes ,
151
+ self ._out_lengths ,
154
152
)
153
+ # return {key: tensor for key, tensor in zip(self._keys, permuted_values)}
155
154
else :
156
155
permuted_values = _permuted_values (
157
156
keyed_tensors , self ._idx_key_pairs , self ._dim
158
157
)
159
-
160
158
return _build_dict (self ._keys , permuted_values , self ._splits , self ._dim )
0 commit comments