diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py index f0fbf706e..0560812bd 100644 --- a/torchrec/ir/schema.py +++ b/torchrec/ir/schema.py @@ -8,7 +8,7 @@ # pyre-strict from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from torchrec.modules.embedding_configs import DataType, PoolingType @@ -32,3 +32,19 @@ class EBCMetadata: tables: List[EmbeddingBagConfigMetadata] is_weighted: bool device: Optional[str] + + +@dataclass +class FPEBCMetadata: + is_fp_collection: bool + features: List[str] + + +@dataclass +class PositionWeightedModuleMetadata: + max_feature_length: int + + +@dataclass +class PositionWeightedModuleCollectionMetadata: + max_feature_lengths: List[Tuple[str, int]] diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py index ffc1fe69a..24bd77954 100644 --- a/torchrec/ir/serializer.py +++ b/torchrec/ir/serializer.py @@ -8,17 +8,32 @@ # pyre-strict import json -import logging from typing import Any, Dict, List, Optional, Tuple, Type import torch from torch import nn -from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata +from torchrec.ir.schema import ( + EBCMetadata, + EmbeddingBagConfigMetadata, + FPEBCMetadata, + PositionWeightedModuleCollectionMetadata, + PositionWeightedModuleMetadata, +) from torchrec.ir.types import SerializerInterface +from torchrec.ir.utils import logging from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + FeatureProcessor, + FeatureProcessorsCollection, + PositionWeightedModule, + PositionWeightedModuleCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + logger: logging.Logger = logging.getLogger(__name__) @@ -69,6 +84,26 @@ def get_deserialized_device( return device +def ebc_meta_forward( + ebc: EmbeddingBagCollection, + features: KeyedJaggedTensor, +) -> KeyedTensor: + batch_size = features.stride() + dim = sum(ebc._lengths_per_embedding) + arg_list = [ + features.values(), + features.weights_or_none(), + features.lengths_or_none(), + features.offsets_or_none(), + ] # if we want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]` + output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim) + return KeyedTensor( + keys=ebc._embedding_names, + values=output, + length_per_key=ebc._lengths_per_embedding, + ) + + class JsonSerializer(SerializerInterface): """ Serializer for torch.export IR using json.
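The hunk above swaps each EmbeddingBagCollection forward for ebc_meta_forward, which emits only torch.ops.torchrec.ir_custom_op so that torch.export records an opaque call with a known output shape instead of tracing the embedding lookup itself. Below is a minimal sketch of the same pattern with hypothetical demo:: names (not part of this diff), assuming PyTorch >= 2.4 for torch.library.custom_op / register_fake, the same APIs registered in torchrec/ir/utils.py further down:

# Sketch only: hypothetical "demo" namespace, not part of this diff.
from typing import List, Optional

import torch
from torch import nn


@torch.library.custom_op("demo::shape_only_op", mutates_args={})
def shape_only_op(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
) -> torch.Tensor:
    # Eager kernel: the concrete values are irrelevant for the IR, only the shape matters.
    return torch.empty(batch_size, dim)


@torch.library.register_fake("demo::shape_only_op")
def shape_only_op_fake(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
) -> torch.Tensor:
    # Fake kernel lets torch.export trace through the call without real computation.
    return torch.empty(batch_size, dim)


def meta_forward(self: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Stand-in forward that emits only the custom op, mirroring ebc_meta_forward above.
    return torch.ops.demo.shape_only_op([x], x.shape[0], 8)


module = nn.Linear(4, 8)
# Bind the stand-in forward to the instance, as EBCJsonSerializer.swap_meta_forward does.
module.forward = meta_forward.__get__(module, nn.Linear)
ep = torch.export.export(module, (torch.randn(2, 4),), strict=False)
# ep.graph now contains demo::shape_only_op instead of the module's real computation.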
@@ -150,10 +185,70 @@ def deserialize( ) return module + @classmethod + def swap_meta_forward(cls, module: nn.Module) -> None: + pass + + @classmethod + def encapsulate_module(cls, module: nn.Module) -> List[str]: + typename = type(module).__name__ + serializer = cls.module_to_serializer_cls.get(typename) + if serializer is None: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + assert issubclass(serializer, JsonSerializer) + assert serializer._module_cls is not None + if not isinstance(module, serializer._module_cls): + raise ValueError( + f"Expected module to be of type {serializer._module_cls.__name__}, " + f"got {type(module)}" + ) + metadata_dict = serializer.serialize_to_dict(module) + raw_dict = {"typename": typename, "metadata_dict": metadata_dict} + ir_metadata_tensor = torch.frombuffer( + json.dumps(raw_dict).encode(), dtype=torch.uint8 + ) + module.register_buffer("ir_metadata", ir_metadata_tensor, persistent=False) + serializer.swap_meta_forward(module) + return serializer.children(module) + + @classmethod + def decapsulate_module( + cls, module: nn.Module, device: Optional[torch.device] = None + ) -> nn.Module: + raw_bytes = module.get_buffer("ir_metadata").numpy().tobytes() + raw_dict = json.loads(raw_bytes.decode()) + typename = raw_dict["typename"] + metadata_dict = raw_dict["metadata_dict"] + if typename not in cls.module_to_serializer_cls: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + serializer = cls.module_to_serializer_cls[typename] + assert issubclass(serializer, JsonSerializer) + module = serializer.deserialize_from_dict(metadata_dict, device, module) + + if serializer._module_cls is None: + raise ValueError( + "Must assign a nn.Module to class static variable _module_cls" + ) + if not isinstance(module, serializer._module_cls): + raise ValueError( + f"Expected module to be of type {serializer._module_cls.__name__}, got {type(module)}" + ) + return module + class EBCJsonSerializer(JsonSerializer): _module_cls = EmbeddingBagCollection + @classmethod + def swap_meta_forward(cls, module: nn.Module) -> None: + assert isinstance(module, cls._module_cls) + # pyre-ignore + module.forward = ebc_meta_forward.__get__(module, cls._module_cls) + @classmethod def serialize_to_dict( cls, @@ -196,3 +291,110 @@ def deserialize_from_dict( JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer + + +class PWMJsonSerializer(JsonSerializer): + _module_cls = PositionWeightedModule + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + metadata = PositionWeightedModuleMetadata( + max_feature_length=module.position_weight.shape[0], + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = PositionWeightedModuleMetadata(**metadata_dict) + return PositionWeightedModule(metadata.max_feature_length, device) + + +JsonSerializer.module_to_serializer_cls["PositionWeightedModule"] = PWMJsonSerializer + + +class PWMCJsonSerializer(JsonSerializer): + _module_cls = PositionWeightedModuleCollection + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + metadata = PositionWeightedModuleCollectionMetadata( + max_feature_lengths=[ # convert to list of tuples to preserve the order + (feature, len) 
for feature, len in module.max_feature_lengths.items() + ], + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = PositionWeightedModuleCollectionMetadata(**metadata_dict) + max_feature_lengths = { + feature: len for feature, len in metadata.max_feature_lengths + } + return PositionWeightedModuleCollection(max_feature_lengths, device) + + +JsonSerializer.module_to_serializer_cls["PositionWeightedModuleCollection"] = ( + PWMCJsonSerializer +) + + +class FPEBCJsonSerializer(JsonSerializer): + _module_cls = FeatureProcessedEmbeddingBagCollection + _children = ["_feature_processors", "_embedding_bag_collection"] + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + if isinstance(module._feature_processors, FeatureProcessorsCollection): + metadata = FPEBCMetadata( + is_fp_collection=True, + features=[], + ) + else: + metadata = FPEBCMetadata( + is_fp_collection=False, + features=list(module._feature_processors.keys()), + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = FPEBCMetadata(**metadata_dict) + assert unflatten_ep is not None + if metadata.is_fp_collection: + feature_processors = unflatten_ep._feature_processors + assert isinstance(feature_processors, FeatureProcessorsCollection) + else: + feature_processors: dict[str, FeatureProcessor] = {} + for feature in metadata.features: + fp = getattr(unflatten_ep._feature_processors, feature) + assert isinstance(fp, FeatureProcessor) + feature_processors[feature] = fp + ebc = unflatten_ep._embedding_bag_collection + assert isinstance(ebc, EmbeddingBagCollection) + return FeatureProcessedEmbeddingBagCollection( + ebc, + feature_processors, + ) + + +JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = ( + FPEBCJsonSerializer +) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 19d75b28b..88f0f69f7 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -18,16 +18,18 @@ from torchrec.ir.serializer import JsonSerializer from torchrec.ir.utils import ( - deserialize_embedding_modules, + decapsulate_ir_modules, + encapsulate_ir_modules, mark_dynamic_kjt, - serialize_embedding_modules, ) from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection +from torchrec.modules.feature_processor_ import ( + PositionWeightedModule, + PositionWeightedModuleCollection, +) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection -from torchrec.modules.utils import operator_registry_state from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -90,16 +92,18 @@ def deserialize_from_dict( class TestJsonSerializer(unittest.TestCase): + # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict def generate_model(self) -> nn.Module: class Model(nn.Module): - def __init__(self, ebc, fpebc): + def __init__(self, ebc, fpebc1, fpebc2): super().__init__() self.ebc1 = ebc self.ebc2 = copy.deepcopy(ebc) self.ebc3 = 
copy.deepcopy(ebc) self.ebc4 = copy.deepcopy(ebc) self.ebc5 = copy.deepcopy(ebc) - self.fpebc = fpebc + self.fpebc1 = fpebc1 + self.fpebc2 = fpebc2 def forward( self, @@ -111,22 +115,16 @@ def forward( kt4 = self.ebc4(features) kt5 = self.ebc5(features) - fpebc_res = self.fpebc(features) - ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]] - sparse_arch_vals = sum(ebc_kt_vals) - sparse_arch_res = KeyedTensor( - keys=kt1.keys(), - values=sparse_arch_vals, - length_per_key=kt1.length_per_key(), - ) - - return KeyedTensor.regroup( - [sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]] - ) + fpebc1_res = self.fpebc1(features) + fpebc2_res = self.fpebc2(features) + res: List[torch.Tensor] = [] + for kt in [kt1, kt2, kt3, kt4, kt5, fpebc1_res, fpebc2_res]: + res.extend(KeyedTensor.regroup([kt], [[key] for key in kt.keys()])) + return res tb1_config = EmbeddingBagConfig( name="t1", - embedding_dim=4, + embedding_dim=3, num_embeddings=10, feature_names=["f1"], ) @@ -138,7 +136,7 @@ def forward( ) tb3_config = EmbeddingBagConfig( name="t3", - embedding_dim=4, + embedding_dim=5, num_embeddings=10, feature_names=["f3"], ) @@ -149,7 +147,7 @@ def forward( ) max_feature_lengths = {"f1": 100, "f2": 100} - fpebc = FeatureProcessedEmbeddingBagCollection( + fpebc1 = FeatureProcessedEmbeddingBagCollection( EmbeddingBagCollection( tables=[tb1_config, tb2_config], is_weighted=True, @@ -158,8 +156,18 @@ def forward( max_feature_lengths=max_feature_lengths, ), ) + fpebc2 = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb1_config, tb3_config], + is_weighted=True, + ), + { + "f1": PositionWeightedModule(max_feature_length=10), + "f3": PositionWeightedModule(max_feature_length=20), + }, + ) - model = Model(ebc, fpebc) + model = Model(ebc, fpebc1, fpebc2) return model @@ -174,7 +182,7 @@ def test_serialize_deserialize_ebc(self) -> None: eager_out = model(id_list_features) # Serialize EBC - model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) ep = torch.export.export( model, (id_list_features,), @@ -190,49 +198,55 @@ def test_serialize_deserialize_ebc(self) -> None: for i, tensor in enumerate(ep_output): self.assertEqual(eager_out[i].shape, tensor.shape) - # Only 2 custom op registered, as dimensions of ebc are same - self.assertEqual(len(operator_registry_state.op_registry_schema), 2) - - total_dim_ebc = sum(model.ebc1._lengths_per_embedding) - total_dim_fpebc = sum( - model.fpebc._embedding_bag_collection._lengths_per_embedding - ) - # Check if custom op is registered with the correct name - # EmbeddingBagCollection type and total dim - self.assertTrue( - f"EmbeddingBagCollection_{total_dim_ebc}" - in operator_registry_state.op_registry_schema - ) - self.assertTrue( - f"EmbeddingBagCollection_{total_dim_fpebc}" - in operator_registry_state.op_registry_schema - ) - # Deserialize EBC - deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + # check EBC config for i in range(5): ebc_name = f"ebc{i + 1}" - assert isinstance( + self.assertIsInstance( getattr(deserialized_model, ebc_name), EmbeddingBagCollection ) - for deserialized_config, org_config in zip( + for deserialized, original in zip( getattr(deserialized_model, ebc_name).embedding_bag_configs(), getattr(model, ebc_name).embedding_bag_configs(), ): - assert deserialized_config.name ==
org_config.name - assert deserialized_config.embedding_dim == org_config.embedding_dim - assert deserialized_config.num_embeddings, org_config.num_embeddings - assert deserialized_config.feature_names, org_config.feature_names + self.assertEqual(deserialized.name, original.name) + self.assertEqual(deserialized.embedding_dim, original.embedding_dim) + self.assertEqual(deserialized.num_embeddings, original.num_embeddings) + self.assertEqual(deserialized.feature_names, original.feature_names) + + # check FPEBC config + for i in range(2): + fpebc_name = f"fpebc{i + 1}" + assert isinstance( + getattr(deserialized_model, fpebc_name), + FeatureProcessedEmbeddingBagCollection, + ) + for deserialized, original in zip( + getattr( + deserialized_model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + getattr( + model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, original.name) + self.assertEqual(deserialized.embedding_dim, original.embedding_dim) + self.assertEqual(deserialized.num_embeddings, original.num_embeddings) + self.assertEqual(deserialized.feature_names, original.feature_names) + + # Run forward on deserialized model and compare the output deserialized_model.load_state_dict(model.state_dict()) - # Run forward on deserialized model deserialized_out = deserialized_model(id_list_features) - for i, tensor in enumerate(deserialized_out): - assert eager_out[i].shape == tensor.shape - assert torch.allclose(eager_out[i], tensor) + self.assertEqual(len(deserialized_out), len(eager_out)) + for deserialized, original in zip(deserialized_out, eager_out): + self.assertEqual(deserialized.shape, original.shape) + self.assertTrue(torch.allclose(deserialized, original)) def test_dynamic_shape_ebc(self) -> None: model = self.generate_model() @@ -251,7 +265,7 @@ def test_dynamic_shape_ebc(self) -> None: # Serialize EBC collection = mark_dynamic_kjt(feature1) - model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) ep = torch.export.export( model, (feature1,), @@ -259,7 +273,7 @@ def test_dynamic_shape_ebc(self) -> None: dynamic_shapes=collection.dynamic_shapes(model, (feature1,)), strict=False, # Allows KJT to not be unflattened and run a forward on unflattened EP - preserve_module_call_signature=(tuple(sparse_fqns)), + preserve_module_call_signature=tuple(sparse_fqns), ) # Run forward on ExportedProgram @@ -270,9 +284,10 @@ def test_dynamic_shape_ebc(self) -> None: self.assertEqual(eager_out[i].shape, tensor.shape) # Deserialize EBC - deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) - + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) deserialized_model.load_state_dict(model.state_dict()) + # Run forward on deserialized model deserialized_out = deserialized_model(feature2) @@ -289,7 +304,7 @@ def test_deserialized_device(self) -> None: ) # Serialize EBC - model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) ep = torch.export.export( model, (id_list_features,), @@ -304,8 +319,9 @@ def test_deserialized_device(self) -> None: if device == "cuda" and not torch.cuda.is_available(): continue device = torch.device(device) - deserialized_model = deserialize_embedding_modules( - ep, JsonSerializer, device + unflatten_ep = torch.export.unflatten(ep) + deserialized_model =
decapsulate_ir_modules( + unflatten_ep, JsonSerializer, device ) for name, m in deserialized_model.named_modules(): if hasattr(m, "device"): @@ -367,7 +383,7 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: CompoundModuleSerializer ) # Serialize - model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) ep = torch.export.export( model, (id_list_features,), @@ -383,8 +399,8 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: self.assertEqual(x.shape, y.shape) # Deserialize - deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) - + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) # Check if Compound Module is deserialized correctly self.assertIsInstance(deserialized_model.comp, CompoundModule) self.assertIsInstance(deserialized_model.comp.comp, CompoundModule) diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py index b17130548..2766dd4ab 100644 --- a/torchrec/ir/types.py +++ b/torchrec/ir/types.py @@ -47,3 +47,18 @@ def deserialize( ) -> nn.Module: # Take the bytes in the buffer and regenerate the eager embedding module raise NotImplementedError + + @classmethod + @abc.abstractmethod + def encapsulate_module(cls, module: nn.Module) -> List[str]: + # Take the eager embedding module and encapsulate it, including serialization + # and meta_forward-swapping, then return a list of children (fqns) that need further encapsulation + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def decapsulate_module( + cls, module: nn.Module, device: Optional[torch.device] = None + ) -> nn.Module: + # Take the encapsulated module and decapsulate it by reversing the serialization and meta_forward-swapping + raise NotImplementedError diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index 9f81cf973..84c12676c 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -9,6 +9,7 @@ #!/usr/bin/env python3 +import logging from collections import defaultdict from typing import Dict, List, Optional, Tuple, Type, Union @@ -24,55 +25,74 @@ # TODO: Replace the default interface with the python dataclass interface DEFAULT_SERIALIZER_CLS = SerializerInterface DYNAMIC_DIMS: Dict[str, int] = defaultdict(int) +logger: logging.Logger = logging.getLogger(__name__) -def serialize_embedding_modules( +@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={}) +def ir_custom_op_impl( + tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int ) -> torch.Tensor: + device = None + for t in tensors: + if t is not None: + device = t.device + break + logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim})") + return torch.empty(batch_size, dim, device=device) + + +@torch.library.register_fake("torchrec::ir_custom_op") +def ir_custom_op_fake( + tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int ) -> torch.Tensor: + logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim})") + return torch.empty(batch_size, dim) + + +def encapsulate_ir_modules( module: nn.Module, - serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, - fqn: str = "", # current module's fqn for recursion purpose + serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, + fqn: str = "", ) -> Tuple[nn.Module, List[str]]: """ - Takes all the modules that are of type `serializer_cls` and serializes them - in the given format with a registered buffer to the
module. - - Returns the modified module and the list of fqns that had the buffer added, - which is needed for torch.export + Takes a module, encapsulates its embedding modules, and serializes them to the module buffer. + Returns the modified module and a list of fqns that had the buffer added, which is needed for torch.export. + The encapsulation is done by using the meta_forward function provided by the serializer + to replace the module's original forward function. """ preserve_fqns: List[str] = [] # fqns of the serialized modules children: List[str] = [] # fqns of the children that need further serialization # handle current module, and find the children which need further serialization - if type(module).__name__ in serializer_cls.module_to_serializer_cls: - serialized_tensor, children = serializer_cls.serialize(module) - module.register_buffer("ir_metadata", serialized_tensor, persistent=False) + if type(module).__name__ in serializer.module_to_serializer_cls: + children = serializer.encapsulate_module(module) preserve_fqns.append(fqn) else: - # if the module is not of type serializer_cls, then we check all its children + # if the module type is not registered with the serializer, then we check all its children children = [child for child, _ in module.named_children()] # handle child modules recursively for child in children: submodule = module.get_submodule(child) child_fqn = f"{fqn}.{child}" if len(fqn) > 0 else child - _, fqns = serialize_embedding_modules(submodule, serializer_cls, child_fqn) + _, fqns = encapsulate_ir_modules(submodule, serializer, child_fqn) preserve_fqns.extend(fqns) return module, preserve_fqns -def _deserialize_embedding_modules( +def decapsulate_ir_modules( module: nn.Module, - serializer_cls: Type[SerializerInterface], + serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, device: Optional[torch.device] = None, ) -> nn.Module: """ - Takes the unflattened ExportedProgram module, and looks for ir_metadata buffer. - If found, deserializes the buffer and replaces the module with the deserialized module. + Takes a module and decapsulates its embedding modules by retrieving the buffer. + Returns the module with restored embedding (sub) modules. """ - for child_fqn, child in module.named_children(): # perform deserialization on the children first, so that we can replace the child module with # the deserialized module, and then replace it in the parent - child = _deserialize_embedding_modules( - module=child, serializer_cls=serializer_cls, device=device + child = decapsulate_ir_modules( + module=child, serializer=serializer, device=device ) # replace the child module with deserialized one if applicable setattr(module, child_fqn, child) @@ -80,27 +100,10 @@ def _deserialize_embedding_modules( # only deserialize if the module has ir_metadata buffer, otherwise return as is # we use "ir_metadata" as a convention to identify the deserializable module if "ir_metadata" in dict(module.named_buffers()): - ir_metadata_tensor = module.get_buffer("ir_metadata") - module = serializer_cls.deserialize(ir_metadata_tensor, device, module) + module = serializer.decapsulate_module(module, device) return module -def deserialize_embedding_modules( - ep: ExportedProgram, - serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, - device: Optional[torch.device] = None, -) -> nn.Module: - """ - Takes ExportedProgram (IR) and looks for ir_metadata buffer. - If found, deserializes the buffer and replaces the module with the deserialized - module.
- - Returns the unflattened ExportedProgram with the deserialized modules. - """ - model = torch.export.unflatten(ep) - return _deserialize_embedding_modules(model, serializer_cls, device) - - def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM: if isinstance(x, DIM): return x diff --git a/torchrec/models/tests/test_dlrm.py b/torchrec/models/tests/test_dlrm.py index 194fc9517..e01976404 100644 --- a/torchrec/models/tests/test_dlrm.py +++ b/torchrec/models/tests/test_dlrm.py @@ -15,7 +15,7 @@ from torchrec.datasets.utils import Batch from torchrec.fx import symbolic_trace from torchrec.ir.serializer import JsonSerializer -from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules +from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules from torchrec.models.dlrm import ( choose, DenseArch, @@ -1263,7 +1263,7 @@ def test_export_serialization(self) -> None: self.assertEqual(logits.size(), (B, 1)) - model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) ep = torch.export.export( model, @@ -1278,7 +1278,8 @@ def test_export_serialization(self) -> None: ep_output = ep.module()(features, sparse_features) self.assertEqual(ep_output.size(), (B, 1)) - deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + unflatten_model = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_model, JsonSerializer) deserialized_logits = deserialized_model(features, sparse_features) self.assertEqual(deserialized_logits.size(), (B, 1)) diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py index 934bdb229..6ae2b3ec8 100644 --- a/torchrec/modules/tests/test_embedding_modules.py +++ b/torchrec/modules/tests/test_embedding_modules.py @@ -227,7 +227,7 @@ def test_device(self) -> None: self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device) self.assertEqual(torch.device("cpu"), ebc.device) - def test_exporting(self) -> None: + def test_ir_export(self) -> None: class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -296,6 +296,13 @@ def forward( "Shoulde be exact 2 EmbeddingBagCollection nodes in the exported graph", ) + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual( + output.size(), exported.size(), "Output should match exported output" + ) + class EmbeddingCollectionTest(unittest.TestCase): def test_forward(self) -> None: diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py index 688c56be5..ff03eb3c2 100644 --- a/torchrec/modules/tests/test_fp_embedding_modules.py +++ b/torchrec/modules/tests/test_fp_embedding_modules.py @@ -21,22 +21,12 @@ PositionWeightedModuleCollection, ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor class PositionWeightedModuleEmbeddingBagCollectionTest(unittest.TestCase): - def test_position_weighted_module_ebc(self) -> None: - # 0 1 2 <-- batch - # 0 [0,1] None [2] - # 1 [3] [4] [5,6,7] - # ^ - # feature - features = KeyedJaggedTensor.from_offsets_sync( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 
5, 8]), - ) + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: ebc = EmbeddingBagCollection( tables=[ EmbeddingBagConfig( @@ -52,8 +42,21 @@ def test_position_weighted_module_ebc(self) -> None: "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)), "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)), } + return FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) - fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) + def test_position_weighted_module_ebc(self) -> None: + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + fp_ebc = self.generate_fp_ebc() pooled_embeddings = fp_ebc(features) self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) @@ -86,6 +89,53 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None: offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), ) + fp_ebc = self.generate_fp_ebc() + + pooled_embeddings = fp_ebc(features) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + + def test_ir_export(self) -> None: + class MyModule(torch.nn.Module): + def __init__(self, fp_ebc) -> None: + super().__init__() + self._fp_ebc = fp_ebc + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + return self._fp_ebc(features) + + m = MyModule(self.generate_fp_ebc()) + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + ep = torch.export.export( + m, + (features,), + {}, + strict=False, + ) + self.assertEqual( + sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes), + 0, + ) + self.assertEqual( + sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes), + 1, + "Should be exactly 1 EBC node in the exported graph", + ) + + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual(output.keys(), exported.keys()) + self.assertEqual(output.values().size(), exported.values().size()) + + +class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: ebc = EmbeddingBagCollection( tables=[ EmbeddingBagConfig( @@ -97,20 +147,11 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None: ], is_weighted=True, ) - feature_processors = { - "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)), - "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)), - } - - fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) - - pooled_embeddings = fp_ebc(features) - self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) - self.assertEqual(pooled_embeddings.values().size(), (3, 16)) - self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + return FeatureProcessedEmbeddingBagCollection( + ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10}) + ) -class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): def test_position_weighted_collection_module_ebc(self) -> None: # 0 1 2 <-- batch # 0
[0,1] None [2] @@ -123,21 +164,7 @@ def test_position_weighted_collection_module_ebc(self) -> None: offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), ) - ebc = EmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"] - ), - EmbeddingBagConfig( - name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"] - ), - ], - is_weighted=True, - ) - - fp_ebc = FeatureProcessedEmbeddingBagCollection( - ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10}) - ) + fp_ebc = self.generate_fp_ebc() pooled_embeddings = fp_ebc(features) self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) @@ -155,3 +182,40 @@ def test_position_weighted_collection_module_ebc(self) -> None: pooled_embeddings_gm_script.offset_per_key(), pooled_embeddings.offset_per_key(), ) + + def test_ir_export(self) -> None: + class MyModule(torch.nn.Module): + def __init__(self, fp_ebc) -> None: + super().__init__() + self._fp_ebc = fp_ebc + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + return self._fp_ebc(features) + + m = MyModule(self.generate_fp_ebc()) + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + ep = torch.export.export( + m, + (features,), + {}, + strict=False, + ) + self.assertEqual( + sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes), + 0, + ) + self.assertEqual( + sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes), + 1, + "Should be exactly 1 EBC node in the exported graph", + ) + + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual(output.keys(), exported.keys()) + self.assertEqual(output.values().size(), exported.values().size())
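For reference, a minimal sketch of the encapsulate / export / unflatten / decapsulate round trip that these tests exercise (mirroring test_serializer.py; the single-table model and KJT below are illustrative and not taken from the diff):

import torch
from torch import nn
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
                )
            ]
        )

    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
        return self.ebc(features)


model = Model()
features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    offsets=torch.tensor([0, 2, 2, 3]),
)

# Swap in meta forwards and stash serialized configs in "ir_metadata" buffers.
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
    model,
    (features,),
    {},
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# Rebuild the module hierarchy from the IR, then restore the eager embedding modules.
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model.load_state_dict(model.state_dict())
out = deserialized_model(features)  # KeyedTensor with key "f1"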