9
9
10
10
#!/usr/bin/env python3
11
11
12
+ import copy
12
13
import unittest
13
14
14
15
import torch
15
16
from torch import nn
16
17
from torchrec .ir .serializer import JsonSerializer
17
18
18
19
from torchrec .ir .utils import deserialize_embedding_modules , serialize_embedding_modules
20
+ from torchrec .modules import utils as module_utils
19
21
20
22
from torchrec .modules .embedding_configs import EmbeddingBagConfig
21
23
from torchrec .modules .embedding_modules import EmbeddingBagCollection
24
+ from torchrec .modules .utils import (
25
+ operator_registry_state ,
26
+ register_custom_op ,
27
+ register_custom_ops_for_nodes ,
28
+ )
22
29
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor , KeyedTensor
23
30
24
31
@@ -27,13 +34,29 @@ def generate_model(self) -> nn.Module:
27
34
class Model (nn .Module ):
28
35
def __init__ (self , ebc ):
29
36
super ().__init__ ()
30
- self .sparse_arch = ebc
37
+ self .ebc1 = ebc
38
+ self .ebc2 = copy .deepcopy (ebc )
39
+ self .ebc3 = copy .deepcopy (ebc )
40
+ self .ebc4 = copy .deepcopy (ebc )
41
+ self .ebc5 = copy .deepcopy (ebc )
31
42
32
43
def forward (
33
44
self ,
34
45
features : KeyedJaggedTensor ,
35
- ) -> KeyedTensor :
36
- return self .sparse_arch (features )
46
+ ) -> torch .Tensor :
47
+ kt1 = self .ebc1 (features )
48
+ kt2 = self .ebc2 (features )
49
+ kt3 = self .ebc3 (features )
50
+ kt4 = self .ebc4 (features )
51
+ kt5 = self .ebc5 (features )
52
+
53
+ return (
54
+ kt1 .values ()
55
+ + kt2 .values ()
56
+ + kt3 .values ()
57
+ + kt4 .values ()
58
+ + kt5 .values ()
59
+ )
37
60
38
61
tb1_config = EmbeddingBagConfig (
39
62
name = "t1" ,
@@ -65,7 +88,7 @@ def test_serialize_deserialize_ebc(self) -> None:
65
88
offsets = torch .tensor ([0 , 2 , 2 , 3 , 4 ]),
66
89
)
67
90
68
- eager_kt = model (id_list_features )
91
+ eager_out = model (id_list_features )
69
92
70
93
# Serialize PEA
71
94
model , sparse_fqns = serialize_embedding_modules (model , JsonSerializer )
@@ -78,37 +101,66 @@ def test_serialize_deserialize_ebc(self) -> None:
78
101
preserve_module_call_signature = (tuple (sparse_fqns )),
79
102
)
80
103
81
- # Run forward on ExportedProgram
82
- ep_output = ep .module ()(id_list_features )
104
+ total_dim = sum (model .ebc1 ._lengths_per_embedding )
105
+ with operator_registry_state .op_registry_lock :
106
+ # Run forward on ExportedProgram
107
+ ep_output = ep .module ()(id_list_features )
83
108
84
- self .assertTrue (isinstance (ep_output , KeyedTensor ))
85
- self .assertEqual (eager_kt .keys (), ep_output .keys ())
86
- self .assertEqual (eager_kt .values ().shape , ep_output .values ().shape )
109
+ self .assertEqual (eager_out .shape , ep_output .shape )
87
110
88
- # Deserialize EBC
89
- deserialized_model = deserialize_embedding_modules (ep , JsonSerializer )
90
-
91
- self .assertTrue (
92
- isinstance (deserialized_model .sparse_arch , EmbeddingBagCollection )
93
- )
111
+ # Only 1 custom op registered, as dimensions of ebc are same
112
+ self .assertEqual (len (operator_registry_state .op_registry_schema ), 1 )
94
113
95
- for deserialized_config , org_config in zip (
96
- deserialized_model .sparse_arch .embedding_bag_configs (),
97
- model .sparse_arch .embedding_bag_configs (),
98
- ):
99
- self .assertEqual (deserialized_config .name , org_config .name )
100
- self .assertEqual (
101
- deserialized_config .embedding_dim , org_config .embedding_dim
114
+ # Check if custom op is registered with the correct name
115
+ # EmbeddingBagCollection type and total dim
116
+ self .assertTrue (
117
+ f"EmbeddingBagCollection_{ total_dim } "
118
+ in operator_registry_state .op_registry_schema
102
119
)
103
- self .assertEqual (
104
- deserialized_config .num_embeddings , org_config .num_embeddings
120
+
121
+ # Reset the op registry
122
+ operator_registry_state .op_registry_schema = {}
123
+
124
+ # Reset lib
125
+ module_utils .lib = torch .library .Library ("custom" , "FRAGMENT" )
126
+
127
+ # Ensure custom op is reregistered
128
+ register_custom_ops_for_nodes (list (ep .graph_module .graph .nodes ))
129
+
130
+ with operator_registry_state .op_registry_lock :
131
+ self .assertTrue (
132
+ f"EmbeddingBagCollection_{ total_dim } "
133
+ in operator_registry_state .op_registry_schema
105
134
)
106
- self .assertEqual (
107
- deserialized_config .feature_names , org_config .feature_names
135
+
136
+ ep .module ()(id_list_features )
137
+ # Deserialize EBC
138
+ deserialized_model = deserialize_embedding_modules (ep , JsonSerializer )
139
+
140
+ for i in range (5 ):
141
+ ebc_name = f"ebc{ i + 1 } "
142
+ self .assertTrue (
143
+ isinstance (
144
+ getattr (deserialized_model , ebc_name ), EmbeddingBagCollection
145
+ )
108
146
)
109
147
148
+ for deserialized_config , org_config in zip (
149
+ getattr (deserialized_model , ebc_name ).embedding_bag_configs (),
150
+ getattr (model , ebc_name ).embedding_bag_configs (),
151
+ ):
152
+ self .assertEqual (deserialized_config .name , org_config .name )
153
+ self .assertEqual (
154
+ deserialized_config .embedding_dim , org_config .embedding_dim
155
+ )
156
+ self .assertEqual (
157
+ deserialized_config .num_embeddings , org_config .num_embeddings
158
+ )
159
+ self .assertEqual (
160
+ deserialized_config .feature_names , org_config .feature_names
161
+ )
162
+
110
163
# Run forward on deserialized model
111
- deserialized_kt = deserialized_model (id_list_features )
164
+ deserialized_out = deserialized_model (id_list_features )
112
165
113
- self .assertEqual (eager_kt .keys (), deserialized_kt .keys ())
114
- self .assertEqual (eager_kt .values ().shape , deserialized_kt .values ().shape )
166
+ self .assertEqual (eager_out .shape , deserialized_out .shape )
0 commit comments