
short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules #2393

Closed
23 changes: 20 additions & 3 deletions torchrec/ir/tests/test_serializer.py
@@ -557,7 +557,7 @@ def test_key_order_with_ebc_and_regroup(self) -> None:
        ebc2.load_state_dict(ebc1.state_dict())
        regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

        class myModel(nn.Module):
        class mySparse(nn.Module):
            def __init__(self, ebc, regroup):
                super().__init__()
                self.ebc = ebc
@@ -569,6 +569,17 @@ def forward(
            ) -> Dict[str, torch.Tensor]:
                return self.regroup([self.ebc(features)])

        class myModel(nn.Module):
            def __init__(self, ebc, regroup):
                super().__init__()
                self.sparse = mySparse(ebc, regroup)

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> Dict[str, torch.Tensor]:
                return self.sparse(features)

        model = myModel(ebc1, regroup)
        eager_out = model(id_list_features)

@@ -582,11 +593,17 @@ def forward(
            preserve_module_call_signature=(tuple(sparse_fqns)),
        )
        unflatten_ep = torch.export.unflatten(ep)
        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
        deserialized_model = decapsulate_ir_modules(
            unflatten_ep,
            JsonSerializer,
            short_circuit_pytree_ebc_regroup=True,
            finalize_interpreter_modules=True,
        )

        # we export the model with ebc1 and unflatten the model,
        # and then swap in ebc2 (you can think of this as the sharding process
        # resulting in a ShardedEBC), so that we can mimic the key-order change
        deserialized_model.ebc = ebc2
        deserialized_model.sparse.ebc = ebc2

        deserialized_out = deserialized_model(id_list_features)
        for key in eager_out.keys():
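
For readers skimming the diff, a minimal usage sketch of the two new flags, mirroring the test above. The export-call arguments collapsed in the hunk (everything before preserve_module_call_signature) are assumptions, as are the names model, id_list_features, and sparse_fqns carried over from the test; this is illustrative, not part of the change itself.

# hypothetical usage sketch, not part of the PR diff
ep = torch.export.export(
    model,
    (id_list_features,),
    strict=False,  # assumption: non-strict export
    preserve_module_call_signature=tuple(sparse_fqns),
)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(
    unflatten_ep,
    JsonSerializer,
    short_circuit_pytree_ebc_regroup=True,  # prune the EBC -> KTRegroupAsDict pytree round trip
    finalize_interpreter_modules=True,  # must accompany the short circuit (enforced by an assert)
)
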
118 changes: 118 additions & 0 deletions torchrec/ir/utils.py
@@ -10,6 +10,7 @@
#!/usr/bin/env python3

import logging
import operator
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type

@@ -18,7 +19,12 @@
from torch import nn
from torch.export import Dim, ShapesCollection
from torch.export.dynamic_shapes import _Dim as DIM
from torch.export.unflatten import InterpreterModule
from torch.fx import Node
from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@@ -129,6 +135,8 @@ def decapsulate_ir_modules(
    module: nn.Module,
    serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
    device: Optional[torch.device] = None,
    finalize_interpreter_modules: bool = False,
    short_circuit_pytree_ebc_regroup: bool = False,
) -> nn.Module:
    """
    Takes a module and decapsulates its embedding modules by retrieving the buffer.
@@ -147,6 +155,16 @@
# we use "ir_metadata" as a convention to identify the deserializable module
if "ir_metadata" in dict(module.named_buffers()):
module = serializer.decapsulate_module(module, device)

if short_circuit_pytree_ebc_regroup:
module = _short_circuit_pytree_ebc_regroup(module)
assert finalize_interpreter_modules, "need finalize_interpreter_modules=True"

if finalize_interpreter_modules:
for mod in module.modules():
if isinstance(mod, InterpreterModule):
mod.finalize()

return module


@@ -233,3 +251,103 @@ def move_to_copy_nodes_to_device(
            nodes.kwargs = new_kwargs

    return unflattened_module


def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module:
    """
    Bypass the pytree flatten and unflatten functions between EBC and KTRegroupAsDict to avoid the key-order issue.
    https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/
    EBC ==> (out-going) pytree.flatten ==> tensors and specs ==> (in-coming) pytree.unflatten ==> KTRegroupAsDict
    """
    ebc_fqns: List[str] = []
    regroup_fqns: List[str] = []
    for fqn, m in module.named_modules():
        if isinstance(m, FeatureProcessedEmbeddingBagCollection):
            ebc_fqns.append(fqn)
        elif isinstance(m, EmbeddingBagCollection):
            if len(ebc_fqns) > 0 and fqn.startswith(ebc_fqns[-1]):
                continue
            ebc_fqns.append(fqn)
        elif isinstance(m, KTRegroupAsDict):
            regroup_fqns.append(fqn)
    if len(ebc_fqns) == len(regroup_fqns) == 0:
        # nothing happens if there is no EBC or KTRegroupAsDict (e.g., the PEA case)
        return module
    elif len(regroup_fqns) == 0:
        # model only contains EBCs; pytree.flatten of the KT (from EBC) has a performance impact
        logger.warning(
            "Expect perf impact if KTRegroupAsDict is not used together with EBCs."
        )
        return module
    elif len(ebc_fqns) == 0:
        # model only contains KTRegroupAsDict; its inputs are not from an EBC, so be careful
        logger.warning("KTRegroupAsDict is not from EBC, need to be careful.")
        return module
    else:
        return prune_pytree_flatten_unflatten(
            module, in_fqns=regroup_fqns, out_fqns=ebc_fqns
        )


def prune_pytree_flatten_unflatten(
    module: nn.Module, in_fqns: List[str], out_fqns: List[str]
) -> nn.Module:
    """
    Remove the pytree flatten and unflatten functions between the given in_fqns and out_fqns.
    "preserved module" ==> (out-going) pytree.flatten ==> [tensors and specs]
    [tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
    """

    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
        for node in mod.graph.nodes:
            if node.op == "call_module" and node.target == fqn:
                return mod, node
        assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
        curr, fqn = fqn.split(".", maxsplit=1)
        mod = getattr(mod, curr)
        return _get_graph_node(mod, fqn)

    # remove tree_unflatten from the in_fqns (in-coming nodes)
    for fqn in in_fqns:
        submodule, node = _get_graph_node(module, fqn)
        assert len(node.args) == 1
        getitem_getitem: Node = node.args[0]  # pyre-ignore[9]
        assert (
            getitem_getitem.op == "call_function"
            and getitem_getitem.target == operator.getitem
        )
        tree_unflatten_getitem = node.args[0].args[0]  # pyre-ignore[16]
        assert (
            tree_unflatten_getitem.op == "call_function"
            and tree_unflatten_getitem.target == operator.getitem
        )
        tree_unflatten = tree_unflatten_getitem.args[0]
        assert (
            tree_unflatten.op == "call_function"
            and tree_unflatten.target == torch.utils._pytree.tree_unflatten
        )
        logger.info(f"Removing tree_unflatten from {fqn}")
        input_nodes = tree_unflatten.args[0]
        node.args = (input_nodes,)
        submodule.graph.eliminate_dead_code()

    # remove tree_flatten_spec from the out_fqns (out-going nodes)
    for fqn in out_fqns:
        submodule, node = _get_graph_node(module, fqn)
        users = list(node.users.keys())
        assert (
            len(users) == 1
            and users[0].op == "call_function"
            and users[0].target == torch.fx._pytree.tree_flatten_spec
        )
        tree_flatten_users = list(users[0].users.keys())
        assert (
            len(tree_flatten_users) == 1
            and tree_flatten_users[0].op == "call_function"
            and tree_flatten_users[0].target == operator.getitem
        )
        logger.info(f"Removing tree_flatten_spec from {fqn}")
        getitem_node = tree_flatten_users[0]
        getitem_node.replace_all_uses_with(node)
        submodule.graph.eliminate_dead_code()
    return module
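
To make the docstrings above concrete, here is a toy round trip (not from the PR) showing what _short_circuit_pytree_ebc_regroup removes: the preserved EBC output is flattened into leaves plus a spec at the out-going boundary, and the in-coming boundary of KTRegroupAsDict immediately unflattens them with a spec recorded at export time. A plain dict of tensors stands in for the real KeyedTensor; that substitution is an assumption made purely for illustration.

import torch
import torch.utils._pytree as pytree

ebc_out = {"f1": torch.randn(2, 4), "f2": torch.randn(2, 4)}  # stand-in for an EBC output
leaves, spec = pytree.tree_flatten(ebc_out)  # out-going boundary: leaves + spec
rebuilt = pytree.tree_unflatten(leaves, spec)  # in-coming boundary: container rebuilt from the spec
# the pair is an identity while the spec matches the producer; once a sharded EBC with a
# different key order is swapped in, the recorded spec goes stale, hence the short circuit
assert all(torch.equal(ebc_out[k], rebuilt[k]) for k in ebc_out)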
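
Separately, a toy sketch of the FX rewiring technique that prune_pytree_flatten_unflatten relies on: point the consumer node at the producer's original input, then let eliminate_dead_code() drop the orphaned detour. The add/relu module below is a made-up stand-in for the flatten/unflatten chain, not TorchRec code.

import operator

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        y = x + 1  # stand-in for the pytree flatten/unflatten detour
        return torch.relu(y)

gm = torch.fx.symbolic_trace(Toy())
add_node = next(n for n in gm.graph.nodes if n.target == operator.add)
relu_node = next(n for n in gm.graph.nodes if n.target == torch.relu)
relu_node.args = (add_node.args[0],)  # rewire relu to consume the raw input directly
gm.graph.eliminate_dead_code()  # the bypassed add node has no users left and is removed
gm.recompile()
assert torch.equal(gm(torch.tensor([-2.0])), torch.tensor([0.0]))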