
Commit cc271f1

PaulZhang12 authored and facebook-github-bot committed
Simplify custom op naming for meta functionalization of sparse modules
Summary: Registering custom ops for meta functionalization under id-based names can lead to hash collisions, which in turn produce wrong dimensions for a sparse module. This diff names each custom op after the module type plus the dimensions it returns, ensuring the right dimensions are always returned and significantly simplifying the custom op naming logic.

Differential Revision: D57108438
1 parent 35a7f93 commit cc271f1
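For context, a minimal sketch of the naming change described above (helper names here are illustrative, not torchrec's API): the old op name keyed on id(module), which CPython can reuse once a module is garbage collected, while the new name depends only on the module type and its output dimensions, so it is deterministic.

from typing import List

def old_op_name(module: object) -> str:
    # Old scheme: id(module) is a memory address that can be reused after
    # garbage collection, so two distinct modules may end up colliding.
    return f"{type(module).__name__}_{id(module)}"

def new_op_name(module_name: str, dims: List[int]) -> str:
    # New scheme: deterministic name from module type + returned dims;
    # modules with identical output dims deliberately share one op.
    return f"{module_name}_" + "_".join(str(d) for d in dims)

print(new_op_name("EmbeddingBagCollection", [100]))  # EmbeddingBagCollection_100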

File tree

3 files changed: +146 -78 lines changed


torchrec/ir/tests/test_serializer.py

Lines changed: 81 additions & 29 deletions
@@ -9,16 +9,23 @@
 #!/usr/bin/env python3
 
+import copy
 import unittest
 
 import torch
 from torch import nn
 from torchrec.ir.serializer import JsonSerializer
 
 from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
+from torchrec.modules import utils as module_utils
 
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.utils import (
+    operator_registry_state,
+    register_custom_op,
+    register_custom_ops_for_nodes,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
 
@@ -27,13 +34,29 @@ def generate_model(self) -> nn.Module:
         class Model(nn.Module):
             def __init__(self, ebc):
                 super().__init__()
-                self.sparse_arch = ebc
+                self.ebc1 = ebc
+                self.ebc2 = copy.deepcopy(ebc)
+                self.ebc3 = copy.deepcopy(ebc)
+                self.ebc4 = copy.deepcopy(ebc)
+                self.ebc5 = copy.deepcopy(ebc)
 
             def forward(
                 self,
                 features: KeyedJaggedTensor,
-            ) -> KeyedTensor:
-                return self.sparse_arch(features)
+            ) -> torch.Tensor:
+                kt1 = self.ebc1(features)
+                kt2 = self.ebc2(features)
+                kt3 = self.ebc3(features)
+                kt4 = self.ebc4(features)
+                kt5 = self.ebc5(features)
+
+                return (
+                    kt1.values()
+                    + kt2.values()
+                    + kt3.values()
+                    + kt4.values()
+                    + kt5.values()
+                )
 
         tb1_config = EmbeddingBagConfig(
             name="t1",
@@ -65,7 +88,7 @@ def test_serialize_deserialize_ebc(self) -> None:
             offsets=torch.tensor([0, 2, 2, 3, 4]),
         )
 
-        eager_kt = model(id_list_features)
+        eager_out = model(id_list_features)
 
         # Serialize PEA
         model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
@@ -78,37 +101,66 @@ def test_serialize_deserialize_ebc(self) -> None:
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )
 
-        # Run forward on ExportedProgram
-        ep_output = ep.module()(id_list_features)
-
-        self.assertTrue(isinstance(ep_output, KeyedTensor))
-        self.assertEqual(eager_kt.keys(), ep_output.keys())
-        self.assertEqual(eager_kt.values().shape, ep_output.values().shape)
-
-        # Deserialize EBC
-        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
-
-        self.assertTrue(
-            isinstance(deserialized_model.sparse_arch, EmbeddingBagCollection)
-        )
-
-        for deserialized_config, org_config in zip(
-            deserialized_model.sparse_arch.embedding_bag_configs(),
-            model.sparse_arch.embedding_bag_configs(),
-        ):
-            self.assertEqual(deserialized_config.name, org_config.name)
-            self.assertEqual(
-                deserialized_config.embedding_dim, org_config.embedding_dim
-            )
-            self.assertEqual(
-                deserialized_config.num_embeddings, org_config.num_embeddings
-            )
-            self.assertEqual(
-                deserialized_config.feature_names, org_config.feature_names
-            )
+        total_dim = sum(model.ebc1._lengths_per_embedding)
+        with operator_registry_state.op_registry_lock:
+            # Run forward on ExportedProgram
+            ep_output = ep.module()(id_list_features)
+
+            self.assertEqual(eager_out.shape, ep_output.shape)
+
+            # Only 1 custom op registered, as dimensions of ebc are same
+            self.assertEqual(len(operator_registry_state.op_registry_schema), 1)
+
+            # Check if custom op is registered with the correct name:
+            # EmbeddingBagCollection type and total dim
+            self.assertTrue(
+                f"EmbeddingBagCollection_{total_dim}"
+                in operator_registry_state.op_registry_schema
+            )
+
+        # Reset the op registry
+        operator_registry_state.op_registry_schema = {}
+
+        # Reset lib
+        module_utils.lib = torch.library.Library("custom", "FRAGMENT")
+
+        # Ensure custom op is reregistered
+        register_custom_ops_for_nodes(list(ep.graph_module.graph.nodes))
+
+        with operator_registry_state.op_registry_lock:
+            self.assertTrue(
+                f"EmbeddingBagCollection_{total_dim}"
+                in operator_registry_state.op_registry_schema
+            )
+
+        ep.module()(id_list_features)
+        # Deserialize EBC
+        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
+
+        for i in range(5):
+            ebc_name = f"ebc{i + 1}"
+            self.assertTrue(
+                isinstance(
+                    getattr(deserialized_model, ebc_name), EmbeddingBagCollection
+                )
+            )
+
+            for deserialized_config, org_config in zip(
+                getattr(deserialized_model, ebc_name).embedding_bag_configs(),
+                getattr(model, ebc_name).embedding_bag_configs(),
+            ):
+                self.assertEqual(deserialized_config.name, org_config.name)
+                self.assertEqual(
+                    deserialized_config.embedding_dim, org_config.embedding_dim
+                )
+                self.assertEqual(
+                    deserialized_config.num_embeddings, org_config.num_embeddings
+                )
+                self.assertEqual(
+                    deserialized_config.feature_names, org_config.feature_names
+                )
 
         # Run forward on deserialized model
-        deserialized_kt = deserialized_model(id_list_features)
+        deserialized_out = deserialized_model(id_list_features)
 
-        self.assertEqual(eager_kt.keys(), deserialized_kt.keys())
-        self.assertEqual(eager_kt.values().shape, deserialized_kt.values().shape)
+        self.assertEqual(eager_out.shape, deserialized_out.shape)
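The reset-and-reregister steps in the test above lean on torch.library; here is a self-contained sketch of the same registration pattern, using a toy op name rather than anything from torchrec:

import torch
from torch.library import Library

# A "FRAGMENT" library extends the custom namespace instead of owning it,
# so ops can be registered incrementally from multiple call sites.
lib = Library("custom", "FRAGMENT")

op_name = "ToyModule_8"  # hypothetical op returning one [batch_size, 8] tensor
lib.define(f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]")

def toy_impl(values, batch_size):
    device = next((v.device for v in values if v is not None), None)
    assert device is not None, "expects at least one input tensor"
    return [torch.empty(batch_size, 8, device=device)]

# Registering the same callable for Meta gives export a shape-only
# ("fake tensor") formula to trace through.
lib.impl(op_name, toy_impl, "CPU")
lib.impl(op_name, toy_impl, "Meta")

(out,) = torch.ops.custom.ToyModule_8([torch.ones(2, 3)], 4)
print(out.shape)  # torch.Size([4, 8])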

torchrec/modules/embedding_modules.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def _non_strict_exporting_forward(
             features.offsets_or_none(),
         ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
         dims = [sum(self._lengths_per_embedding)]
-        ebc_op = register_custom_op(self, dims)
+        ebc_op = register_custom_op(type(self).__name__, dims)
         outputs = ebc_op(arg_list, batch_size)
         return KeyedTensor(
             keys=self._embedding_names,
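To make the one-line change concrete: dims comes from the context line just above it, so an EBC whose tables have, say, embedding dims 3 and 4 registers a single op named after the summed output width. A sketch with made-up numbers:

# Hypothetical EBC with two tables of embedding dims 3 and 4.
lengths_per_embedding = [3, 4]

dims = [sum(lengths_per_embedding)]  # [7]: one concatenated output tensor
op_name = "EmbeddingBagCollection_" + "_".join(str(d) for d in dims)
print(op_name)  # EmbeddingBagCollection_7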

torchrec/modules/utils.py

Lines changed: 64 additions & 48 deletions
@@ -50,8 +50,6 @@ class OpRegistryState:
 
     # operator schema: {class}.{id} => op_name
     op_registry_schema: Dict[str, str] = {}
-    # operator counter: {class} => count
-    op_registry_counter: Dict[str, int] = defaultdict(int)
 
 
 operator_registry_state = OpRegistryState()
@@ -274,7 +272,8 @@ def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
 # a list of tensors as output. The operator is registered with the name of
 # {module_class_name}_{instance_count}
 def register_custom_op(
-    module: torch.nn.Module, dims: List[int]
+    module_name: str,
+    dims: List[int],
 ) -> Callable[[List[Optional[torch.Tensor]], int], List[torch.Tensor]]:
     """
     Register a customized operator.
@@ -286,51 +285,68 @@ def register_custom_op(
 
     global operator_registry_state
 
-    m_name: str = type(module).__name__
-    op_id: str = f"{m_name}_{id(module)}"
+    dims_str = "_".join([str(d) for d in dims])
     with operator_registry_state.op_registry_lock:
-        if op_id in operator_registry_state.op_registry_schema:
-            op_name: str = operator_registry_state.op_registry_schema[op_id]
-        else:
-            operator_registry_state.op_registry_counter[m_name] += 1
-            op_name: str = (
-                f"{m_name}_{operator_registry_state.op_registry_counter[m_name]}"
-            )
-            operator_registry_state.op_registry_schema[op_id] = op_name
-
-    def custom_op(
-        values: List[Optional[torch.Tensor]],
-        batch_size: int,
-    ) -> List[torch.Tensor]:
-        device = None
-        for v in values:
-            if v is not None:
-                device = v.device
-                break
-        else:
-            raise AssertionError(
-                f"Custom op {op_name} expects at least one input tensor"
-            )
-
-        return [
-            torch.empty(
-                batch_size,
-                dim,
-                device=device,
-            )
-            for dim in dims
-        ]
-
-    schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
-    operator_registry_state.op_registry_schema[op_name] = schema_string
-    # Register schema
-    lib.define(schema_string)
-
-    # Register implementation
-    lib.impl(op_name, custom_op, "CPU")
-    lib.impl(op_name, custom_op, "CUDA")
-
-    # Register meta formula
-    lib.impl(op_name, custom_op, "Meta")
+        op_name: str = f"{module_name}_{dims_str}"
+
+        if op_name in operator_registry_state.op_registry_schema:
+            return getattr(torch.ops.custom, op_name)
+
+        def custom_op(
+            values: List[Optional[torch.Tensor]],
+            batch_size: int,
+        ) -> List[torch.Tensor]:
+            device = None
+            for v in values:
+                if v is not None:
+                    device = v.device
+                    break
+            else:
+                raise AssertionError(
+                    f"Custom op {op_name} expects at least one input tensor"
+                )
+
+            return [
+                torch.empty(
+                    batch_size,
+                    dim,
+                    device=device,
+                )
+                for dim in dims
+            ]
+
+        schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
+        operator_registry_state.op_registry_schema[op_name] = schema_string
+        # Register schema
+        lib.define(schema_string)
+
+        # Register implementation
+        lib.impl(op_name, custom_op, "CPU")
+        lib.impl(op_name, custom_op, "CUDA")
+
+        # Register meta formula
+        lib.impl(op_name, custom_op, "Meta")
 
     return getattr(torch.ops.custom, op_name)
+
+
+def register_custom_ops_for_nodes(
+    nodes: List[torch.fx.Node],
+) -> None:
+    """
+    Given a list of nodes, register custom ops if they exist in the nodes.
+    Required for deserialization if in different runtime environments.
+
+    Args:
+        nodes: list of nodes
+    """
+
+    for node in nodes:
+        if "custom." in str(node.target):
+            # e.g. torch.ops.custom.EmbeddingBagCollection_100.default,
+            # where the trailing numbers represent output dimensions
+            op_name = str(node.target).split(".")[-2]
+            register_custom_op(
+                op_name.split("_")[0],
+                [int(dim) for dim in op_name.split("_")[1:]],
+            )
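The round trip that register_custom_ops_for_nodes performs on a node target, shown standalone (the target string is the example from the comment in the diff above):

# Node target as rendered in an exported graph; the trailing ".default" is
# the overload name, and the segment before it is the registered op name.
target = "custom.EmbeddingBagCollection_100.default"

op_name = target.split(".")[-2]                   # "EmbeddingBagCollection_100"
module_name = op_name.split("_")[0]               # "EmbeddingBagCollection"
dims = [int(d) for d in op_name.split("_")[1:]]   # [100]

assert (module_name, dims) == ("EmbeddingBagCollection", [100])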
