
[TorchRec][IR] Add IR serializer for KTRegroupAsDict Module #1900

Closed
wants to merge 1 commit
6 changes: 6 additions & 0 deletions torchrec/ir/schema.py
@@ -48,3 +48,9 @@ class PositionWeightedModuleMetadata:
@dataclass
class PositionWeightedModuleCollectionMetadata:
    max_feature_lengths: List[Tuple[str, int]]


@dataclass
class KTRegroupAsDictMetadata:
    groups: List[List[str]]
    keys: List[str]
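
The new metadata dataclass mirrors KTRegroupAsDict's constructor arguments, so a module instance can be flattened to JSON and rebuilt without touching its buffers. A minimal standalone sketch of that round trip (not part of this diff; the serializer below does the equivalent via metadata.__dict__ and **metadata_dict):

import json
from dataclasses import dataclass
from typing import List

@dataclass
class KTRegroupAsDictMetadata:
    groups: List[List[str]]
    keys: List[str]

meta = KTRegroupAsDictMetadata(
    groups=[["f1", "f3", "f5"], ["f2", "f4"]], keys=["odd", "even"]
)
payload = json.dumps(meta.__dict__)  # '{"groups": [...], "keys": ["odd", "even"]}'
restored = KTRegroupAsDictMetadata(**json.loads(payload))
assert restored == meta  # dataclass equality confirms a lossless round trip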
60 changes: 58 additions & 2 deletions torchrec/ir/serializer.py
@@ -8,15 +8,15 @@
# pyre-strict

import json
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Type

import torch

from torch import nn
from torchrec.ir.schema import (
    EBCMetadata,
    EmbeddingBagConfigMetadata,
    FPEBCMetadata,
    KTRegroupAsDictMetadata,
    PositionWeightedModuleCollectionMetadata,
    PositionWeightedModuleMetadata,
)
@@ -32,6 +32,7 @@
    PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


@@ -125,6 +126,22 @@ def fpebc_meta_forward(
    )


def kt_regroup_meta_forward(
    op_module: KTRegroupAsDict, keyed_tensors: List[KeyedTensor]
) -> Dict[str, torch.Tensor]:
    lengths_dict: Dict[str, int] = {}
    batch_size = keyed_tensors[0].values().size(0)
    for kt in keyed_tensors:
        for key, length in zip(kt.keys(), kt.length_per_key()):
            lengths_dict[key] = length
    out_lengths: List[int] = [0] * len(op_module._groups)
    for i, group in enumerate(op_module._groups):
        out_lengths[i] = sum(lengths_dict[key] for key in group)
    arg_list = [kt.values() for kt in keyed_tensors]
    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, out_lengths)
    return dict(zip(op_module._keys, outputs))
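
# Annotation (not part of the diff): the meta forward above stands in for
# KTRegroupAsDict.forward during export. It reads each key's width from
# length_per_key(), sums the widths per group, and passes the flat values,
# batch size, and per-group widths to torch.ops.torchrec.ir_custom_op, so
# tracing never initializes the real regroup kernels (the test below asserts
# that _is_inited stays False). A worked instance of the width bookkeeping,
# using the embedding dims from the test (f1/f2 -> 3, f3/f4 -> 4, f5 -> 5):
lengths_dict = {"f1": 3, "f2": 3, "f3": 4, "f4": 4, "f5": 5}
groups = [["f1", "f3", "f5"], ["f2", "f4"]]
out_lengths = [sum(lengths_dict[key] for key in group) for group in groups]
assert out_lengths == [12, 7]  # "odd" spans f1+f3+f5, "even" spans f2+f4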


class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using json.
@@ -364,3 +381,42 @@ def deserialize_from_dict(
JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = (
    FPEBCJsonSerializer
)


class KTRegroupAsDictJsonSerializer(JsonSerializer):
    _module_cls = KTRegroupAsDict

    @classmethod
    def swap_meta_forward(cls, module: nn.Module) -> None:
        assert isinstance(module, cls._module_cls)
        # pyre-ignore
        module.forward = kt_regroup_meta_forward.__get__(module, cls._module_cls)

    @classmethod
    def serialize_to_dict(
        cls,
        module: nn.Module,
    ) -> Dict[str, Any]:
        metadata = KTRegroupAsDictMetadata(
            keys=module._keys,
            groups=module._groups,
        )
        return metadata.__dict__

    @classmethod
    def deserialize_from_dict(
        cls,
        metadata_dict: Dict[str, Any],
        device: Optional[torch.device] = None,
        unflatten_ep: Optional[nn.Module] = None,
    ) -> nn.Module:
        metadata = KTRegroupAsDictMetadata(**metadata_dict)
        return KTRegroupAsDict(
            keys=metadata.keys,
            groups=metadata.groups,
        )


JsonSerializer.module_to_serializer_cls["KTRegroupAsDict"] = (
    KTRegroupAsDictJsonSerializer
)
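
Registering the serializer keys it by class name, which is how encapsulate_ir_modules locates a serializer for each sparse submodule. A hedged sketch of that lookup and of the serialize/deserialize round trip, using a stand-in module instance (illustration only; TorchRec's internal code path also swaps in the meta forward):

module = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"])
serializer_cls = JsonSerializer.module_to_serializer_cls[type(module).__name__]
state = serializer_cls.serialize_to_dict(module)  # {"groups": [...], "keys": [...]}
rebuilt = serializer_cls.deserialize_from_dict(state)
assert rebuilt._keys == module._keys and rebuilt._groups == module._groups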
88 changes: 88 additions & 0 deletions torchrec/ir/tests/test_serializer.py
@@ -30,6 +30,7 @@
    PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


@@ -433,3 +434,90 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
        self.assertEqual(len(deserialized_out), len(eager_out))
        for x, y in zip(deserialized_out, eager_out):
            self.assertTrue(torch.allclose(x, y))

    def test_regroup_as_dict_module(self) -> None:
        class Model(nn.Module):
            def __init__(self, ebc, fpebc, regroup):
                super().__init__()
                self.ebc = ebc
                self.fpebc = fpebc
                self.regroup = regroup

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> Dict[str, torch.Tensor]:
                kt1 = self.ebc(features)
                kt2 = self.fpebc(features)
                return self.regroup([kt1, kt2])

        tb1_config = EmbeddingBagConfig(
            name="t1",
            embedding_dim=3,
            num_embeddings=10,
            feature_names=["f1", "f2"],
        )
        tb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f3", "f4"],
        )
        tb3_config = EmbeddingBagConfig(
            name="t3",
            embedding_dim=5,
            num_embeddings=10,
            feature_names=["f5"],
        )

        ebc = EmbeddingBagCollection(
            tables=[tb1_config, tb3_config],
            is_weighted=False,
        )
        max_feature_lengths = {"f3": 100, "f4": 100}
        fpebc = FeatureProcessedEmbeddingBagCollection(
            EmbeddingBagCollection(
                tables=[tb2_config],
                is_weighted=True,
            ),
            PositionWeightedModuleCollection(
                max_feature_lengths=max_feature_lengths,
            ),
        )
        regroup = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"])
        model = Model(ebc, fpebc, regroup)

        id_list_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f2", "f3", "f4", "f5"],
            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
        )
        self.assertFalse(model.regroup._is_inited)

        # Serialize the sparse modules (EBC, FPEBC, and regroup)
        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
        ep = torch.export.export(
            model,
            (id_list_features,),
            {},
            strict=False,
            # Allows the KJT to stay flattened so forward runs on the unflattened EP
            preserve_module_call_signature=(tuple(sparse_fqns)),
        )

        self.assertFalse(model.regroup._is_inited)
        eager_out = model(id_list_features)
        self.assertFalse(model.regroup._is_inited)

        # Run forward on the ExportedProgram
        ep_output = ep.module()(id_list_features)
        for key in eager_out.keys():
            self.assertEqual(ep_output[key].shape, eager_out[key].shape)

        # Deserialize the sparse modules
        unflatten_ep = torch.export.unflatten(ep)
        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
        self.assertFalse(deserialized_model.regroup._is_inited)
        deserialized_out = deserialized_model(id_list_features)
        self.assertTrue(deserialized_model.regroup._is_inited)
        for key in eager_out.keys():
            self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)
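
For reference, the shapes the test asserts can be reproduced eagerly with just the public KeyedTensor and KTRegroupAsDict APIs. A self-contained sketch (assuming KeyedTensor's standard keys/length_per_key/values constructor), with batch size 3 as implied by the KJT above:

import torch
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

kt1 = KeyedTensor(  # mimics the EBC output: f1/f2 are 3-dim, f5 is 5-dim
    keys=["f1", "f2", "f5"], length_per_key=[3, 3, 5], values=torch.randn(3, 11)
)
kt2 = KeyedTensor(  # mimics the FPEBC output: f3/f4 are 4-dim
    keys=["f3", "f4"], length_per_key=[4, 4], values=torch.randn(3, 8)
)
regroup = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"])
out = regroup([kt1, kt2])
assert out["odd"].shape == (3, 12) and out["even"].shape == (3, 7)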