 
 #!/usr/bin/env python3
 
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torchrec.sparse.jagged_tensor import (
-    _all_keys_used_once,
     _desugar_keyed_tensors,
-    _remap_to_groups,
+    _kt_regroup_arguments,
     KeyedTensor,
 )
 
 
 @torch.fx.wrap
-def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
-    return torch.cat([kt.values() for kt in kts], dim=dim)
+def _get_kts_values(kts: List[KeyedTensor]) -> List[torch.Tensor]:
+    return [kt.values() for kt in kts]
 
 
 @torch.fx.wrap
@@ -36,11 +35,34 @@ def _permuted_values(
 
 @torch.fx.wrap
 def _build_dict(
-    keys: List[str], values: torch.Tensor, splits: List[int], dim: int
+    keys: List[str],
+    values: Union[torch.Tensor, List[torch.Tensor]],
+    splits: List[int],
+    dim: int,
 ) -> Dict[str, torch.Tensor]:
-    return {
-        key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
-    }
+    if isinstance(values, torch.Tensor):
+        return dict(zip(keys, torch.split(values, splits, dim=dim)))
+    else:
+        return dict(zip(keys, values))
+
+
+@torch.fx.wrap
+def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None:
+    assert len(keyed_tensors) > 0, "Empty list provided"
+    assert all(
+        kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
+    ), "All inputs should be on the same device."
+    module.device = keyed_tensors[0].device()
+    assert all(
+        kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
+    ), "All inputs should have the same key_dim"
+    module._dim = keyed_tensors[0].key_dim()
+
+    if module._dim == 1:
+        module._init_fbgemm_regroup(keyed_tensors)
+    else:
+        module._init_regroup(keyed_tensors)
+    module._is_inited = True
 
 
 class KTRegroupAsDict(torch.nn.Module):
@@ -76,27 +98,26 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
 
         # cached values populated on first forward call
         self.device: Optional[torch.device] = None
-        self._concat_dim: int = 1
+        self._dim: int = 1
         self._use_fbgemm_regroup: bool = False
         self._splits: List[int] = []
         self._idx_key_pairs: List[Tuple[int, str]] = []
-        self._permute_tensor: Optional[torch.Tensor] = None
-        self._inv_permute_tensor: Optional[torch.Tensor] = None
-        self._offsets_tensor: Optional[torch.Tensor] = None
-        self._inv_offsets_tensor: Optional[torch.Tensor] = None
+        self.register_buffer("_permutes", torch.empty(2), persistent=True)
+        self.register_buffer("_in_shapes", torch.empty(2), persistent=True)
+        self.register_buffer("_out_shapes", torch.empty(2), persistent=True)
+        self._out_lengths: Optional[List[int]] = None
 
     def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
         self._use_fbgemm_regroup = True
         keys, lengths, values = _desugar_keyed_tensors(kts)
-        permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
-            keys, lengths, self._groups
+        self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = (
+            _kt_regroup_arguments(
+                values[0],
+                keys,
+                lengths,
+                self._groups,
+            )
         )
-        # no need to pin_memory() or to(..., non_blocking=True) since occurs only once
-        self._permute_tensor = permute.to(self.device)
-        self._inv_permute_tensor = inv_permute.to(self.device)
-        self._offsets_tensor = offsets.to(self.device)
-        self._inv_offsets_tensor = inv_offsets.to(self.device)
-        self._splits = splits
 
     def _init_regroup(self, kts: List[KeyedTensor]) -> None:
         lengths = [kt.length_per_key() for kt in kts]
@@ -127,34 +148,19 @@ def _init_regroup(self, kts: List[KeyedTensor]) -> None:
 
     def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
         if not self._is_inited:
-            assert len(keyed_tensors) > 0, "Empty list provided"
-            assert all(
-                kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
-            ), "All inputs should be on the same device."
-            self.device = keyed_tensors[0].device()
-            assert all(
-                kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
-            ), "All inputs should have the same key_dim"
-            self._dim = keyed_tensors[0].key_dim()
-
-            if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
-                self._init_fbgemm_regroup(keyed_tensors)
-            else:
-                self._init_regroup(keyed_tensors)
-            self._is_inited = True
+            module_init(self, keyed_tensors)
 
         if self._use_fbgemm_regroup:
-            values = _concat_values(keyed_tensors, self._dim)
-            permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+            values = _get_kts_values(keyed_tensors)
+            permuted_values = torch.ops.fbgemm.permute_multi_embedding(
                 values,
-                self._offsets_tensor,
-                self._permute_tensor,
-                self._inv_offsets_tensor,
-                self._inv_permute_tensor,
+                self._permutes,
+                self._in_shapes,
+                self._out_shapes,
+                self._out_lengths,
             )
         else:
             permuted_values = _permuted_values(
                 keyed_tensors, self._idx_key_pairs, self._dim
             )
-
         return _build_dict(self._keys, permuted_values, self._splits, self._dim)
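
For readers outside the codebase, here is a minimal usage sketch of the module this diff touches: two pooled-embedding KeyedTensors are regrouped into named output tensors, the path that now dispatches to torch.ops.fbgemm.permute_multi_embedding when key_dim == 1. The import path for KTRegroupAsDict, the feature names, and the dimensions below are illustrative assumptions, not taken from this diff.

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.modules.regroup import KTRegroupAsDict  # assumed module path

# Two KeyedTensors holding pooled embeddings; values are [batch, sum(embedding dims)].
kt1 = KeyedTensor(keys=["f1", "f2"], length_per_key=[4, 4], values=torch.randn(2, 8))
kt2 = KeyedTensor(keys=["f3"], length_per_key=[8], values=torch.randn(2, 8))

# Regroup features into two named outputs; buffers are lazily initialized on the first call.
regroup = KTRegroupAsDict(groups=[["f1", "f3"], ["f2"]], keys=["dense", "sparse"])
out = regroup([kt1, kt2])

# out["dense"] has shape [2, 12] (f1 and f3 concatenated); out["sparse"] has shape [2, 4].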