
Commit 591e076

TroyGardenPaulZhang12 authored and committed
Add IR serializer for KTRegroupAsDict Module (#1900)
Summary:
Pull Request resolved: #1900

# context
* Previously `KTRegroupAsDict` couldn't really be supported by torch.export (IR), because the module has an initialization step that runs on the first batch.
* During export, the `KTRegroupAsDict` module would be initialized with a fake tensor, which is wrong.
* If we initialize the module before torch.export, the device becomes an issue.
* Another issue is that torch.export currently [can't support conditional logic in training](https://pytorch.org/docs/stable/cond.html), while this initialization step is exactly a run-once conditional:
> torch.cond is a prototype feature in PyTorch. It has limited support for input and output types and doesn’t support training currently. Please look forward to a more stable implementation in a future version of PyTorch.

NOTE: this is more of a workaround; a real solution needs conditional-logic support from pytorch compile.

# details
* We treat `KTRegroupAsDict` as another sparse_arch and do the module swap before and after torch.export, as sketched below.
* more context: D59019375

Reviewed By: PaulZhang12

Differential Revision: D56282744

fbshipit-source-id: b86f6eafa3d453735df6c9d00b33b16f70279dea
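For orientation, a minimal sketch of the swap-before/after-export flow described above. It assumes `model` is an `nn.Module` containing a `KTRegroupAsDict` and `features` is a `KeyedJaggedTensor` (both constructed as in the new test below), and that the `encapsulate_ir_modules`/`decapsulate_ir_modules` helpers used by that test are importable from `torchrec.ir.utils`:

```python
import torch
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules

# 1) Swap sparse-arch modules (EBC, FPEBC, KTRegroupAsDict) to meta forwards,
#    so export never runs the one-time init on a fake tensor.
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)

# 2) Export; the KTRegroupAsDict forward is now the custom-op based meta forward.
ep = torch.export.export(
    model,
    (features,),
    {},
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# 3) Unflatten and swap the real modules back in; the regroup module stays
#    uninitialized until it sees its first real batch after deserialization.
deserialized = decapsulate_ir_modules(torch.export.unflatten(ep), JsonSerializer)
out = deserialized(features)
```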
1 parent 8b7fef5 commit 591e076

File tree

4 files changed: +185 -17 lines

torchrec/ir/schema.py
torchrec/ir/serializer.py
torchrec/ir/tests/test_serializer.py
torchrec/ir/utils.py


torchrec/ir/schema.py

Lines changed: 6 additions & 0 deletions
@@ -48,3 +48,9 @@ class PositionWeightedModuleMetadata:
 @dataclass
 class PositionWeightedModuleCollectionMetadata:
     max_feature_lengths: List[Tuple[str, int]]
+
+
+@dataclass
+class KTRegroupAsDictMetadata:
+    groups: List[List[str]]
+    keys: List[str]
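Not part of the diff itself, but a quick sketch of how this dataclass is meant to round-trip: the serializer below stores `metadata.__dict__` and rebuilds the metadata from that dict (the exact json plumbing lives in `JsonSerializer`, so the dump/load here is only illustrative):

```python
import json

from torchrec.ir.schema import KTRegroupAsDictMetadata

meta = KTRegroupAsDictMetadata(groups=[["f1", "f3"], ["f2"]], keys=["odd", "even"])
payload = json.dumps(meta.__dict__)  # roughly what gets embedded alongside the exported IR
restored = KTRegroupAsDictMetadata(**json.loads(payload))
assert restored.groups == meta.groups and restored.keys == meta.keys
```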

torchrec/ir/serializer.py

Lines changed: 60 additions & 4 deletions
@@ -8,15 +8,15 @@
 # pyre-strict
 
 import json
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Type
 
 import torch
-
 from torch import nn
 from torchrec.ir.schema import (
     EBCMetadata,
     EmbeddingBagConfigMetadata,
     FPEBCMetadata,
+    KTRegroupAsDictMetadata,
     PositionWeightedModuleCollectionMetadata,
     PositionWeightedModuleMetadata,
 )
@@ -32,6 +32,7 @@
     PositionWeightedModuleCollection,
 )
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
 
@@ -96,7 +97,7 @@ def ebc_meta_forward(
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
+    outputs = torch.ops.torchrec.ir_emb_lookup(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
         values=torch.cat(outputs, dim=1),
@@ -117,14 +118,30 @@ def fpebc_meta_forward(
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
+    outputs = torch.ops.torchrec.ir_emb_lookup(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
         values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )
 
 
+def kt_regroup_meta_forward(
+    op_module: KTRegroupAsDict, keyed_tensors: List[KeyedTensor]
+) -> Dict[str, torch.Tensor]:
+    lengths_dict: Dict[str, int] = {}
+    batch_size = keyed_tensors[0].values().size(0)
+    for kt in keyed_tensors:
+        for key, length in zip(kt.keys(), kt.length_per_key()):
+            lengths_dict[key] = length
+    out_lengths: List[int] = [0] * len(op_module._groups)
+    for i, group in enumerate(op_module._groups):
+        out_lengths[i] = sum(lengths_dict[key] for key in group)
+    arg_list = [kt.values() for kt in keyed_tensors]
+    outputs = torch.ops.torchrec.ir_kt_regroup(arg_list, batch_size, out_lengths)
+    return dict(zip(op_module._keys, outputs))
+
+
 class JsonSerializer(SerializerInterface):
     """
     Serializer for torch.export IR using json.
@@ -364,3 +381,42 @@ def deserialize_from_dict(
 JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = (
     FPEBCJsonSerializer
 )
+
+
+class KTRegroupAsDictJsonSerializer(JsonSerializer):
+    _module_cls = KTRegroupAsDict
+
+    @classmethod
+    def swap_meta_forward(cls, module: nn.Module) -> None:
+        assert isinstance(module, cls._module_cls)
+        # pyre-ignore
+        module.forward = kt_regroup_meta_forward.__get__(module, cls._module_cls)
+
+    @classmethod
+    def serialize_to_dict(
+        cls,
+        module: nn.Module,
+    ) -> Dict[str, Any]:
+        metadata = KTRegroupAsDictMetadata(
+            keys=module._keys,
+            groups=module._groups,
+        )
+        return metadata.__dict__
+
+    @classmethod
+    def deserialize_from_dict(
+        cls,
+        metadata_dict: Dict[str, Any],
+        device: Optional[torch.device] = None,
+        unflatten_ep: Optional[nn.Module] = None,
+    ) -> nn.Module:
+        metadata = KTRegroupAsDictMetadata(**metadata_dict)
+        return KTRegroupAsDict(
+            keys=metadata.keys,
+            groups=metadata.groups,
+        )
+
+
+JsonSerializer.module_to_serializer_cls["KTRegroupAsDict"] = (
+    KTRegroupAsDictJsonSerializer
+)
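The `swap_meta_forward` override above relies on a small Python trick: `function.__get__(instance, cls)` produces a bound method, so `forward` can be replaced on a single module instance without touching the class. A minimal, self-contained illustration (a toy module, not torchrec code):

```python
import torch
from torch import nn


def meta_forward(self: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Shape-only stand-in used during export; no real computation.
    return torch.empty(x.size(0), 4)


class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.repeat(1, 4)


m = Toy()
m.forward = meta_forward.__get__(m, Toy)  # bind to this instance only
print(m(torch.ones(2, 1)).shape)  # torch.Size([2, 4]); other Toy instances keep the real forward
```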

torchrec/ir/tests/test_serializer.py

Lines changed: 89 additions & 1 deletion
@@ -30,6 +30,7 @@
     PositionWeightedModuleCollection,
 )
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
 
@@ -295,7 +296,7 @@ def test_dynamic_shape_ebc(self) -> None:
             self.assertEqual(eager_out[i].shape, tensor.shape)
             assert torch.allclose(eager_out[i], tensor)
 
-    def test_ir_custom_op_device(self) -> None:
+    def test_ir_emb_lookup_device(self) -> None:
         model = self.generate_model()
         model.fpebc1 = copy.deepcopy(model.ebc1)
         model.fpebc2 = copy.deepcopy(model.ebc1)
@@ -433,3 +434,90 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
         self.assertEqual(len(deserialized_out), len(eager_out))
         for x, y in zip(deserialized_out, eager_out):
             self.assertTrue(torch.allclose(x, y))
+
+    def test_regroup_as_dict_module(self) -> None:
+        class Model(nn.Module):
+            def __init__(self, ebc, fpebc, regroup):
+                super().__init__()
+                self.ebc = ebc
+                self.fpebc = fpebc
+                self.regroup = regroup
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                kt1 = self.ebc(features)
+                kt2 = self.fpebc(features)
+                return self.regroup([kt1, kt2])
+
+        tb1_config = EmbeddingBagConfig(
+            name="t1",
+            embedding_dim=3,
+            num_embeddings=10,
+            feature_names=["f1", "f2"],
+        )
+        tb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f3", "f4"],
+        )
+        tb3_config = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f5"],
+        )
+
+        ebc = EmbeddingBagCollection(
+            tables=[tb1_config, tb3_config],
+            is_weighted=False,
+        )
+        max_feature_lengths = {"f3": 100, "f4": 100}
+        fpebc = FeatureProcessedEmbeddingBagCollection(
+            EmbeddingBagCollection(
+                tables=[tb2_config],
+                is_weighted=True,
+            ),
+            PositionWeightedModuleCollection(
+                max_feature_lengths=max_feature_lengths,
+            ),
+        )
+        regroup = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"])
+        model = Model(ebc, fpebc, regroup)
+
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3", "f4", "f5"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
+        )
+        self.assertFalse(model.regroup._is_inited)
+
+        # Serialize EBC
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+
+        self.assertFalse(model.regroup._is_inited)
+        eager_out = model(id_list_features)
+        self.assertFalse(model.regroup._is_inited)
+
+        # Run forward on ExportedProgram
+        ep_output = ep.module()(id_list_features)
+        for key in eager_out.keys():
+            self.assertEqual(ep_output[key].shape, eager_out[key].shape)
+        # Deserialize EBC
+        unflatten_ep = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        self.assertFalse(deserialized_model.regroup._is_inited)
+        deserialized_out = deserialized_model(id_list_features)
+        self.assertTrue(deserialized_model.regroup._is_inited)
+        for key in eager_out.keys():
+            self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)
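For reference, the shapes this test compares follow directly from the configs: the KJT has 15 lengths across 5 keys, so the batch size is 3, and each group's width is the sum of its features' embedding dims. A small back-of-the-envelope sketch:

```python
# Per-feature embedding dims come from the table configs in the test:
# t1 -> f1, f2 (dim 3); t2 -> f3, f4 (dim 4); t3 -> f5 (dim 5).
dims = {"f1": 3, "f2": 3, "f3": 4, "f4": 4, "f5": 5}
groups = {"odd": ["f1", "f3", "f5"], "even": ["f2", "f4"]}
batch_size = 3  # 15 lengths in the KJT offsets / 5 keys

expected = {
    name: (batch_size, sum(dims[f] for f in feats)) for name, feats in groups.items()
}
print(expected)  # {'odd': (3, 12), 'even': (3, 7)}
```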

torchrec/ir/utils.py

Lines changed: 30 additions & 12 deletions
@@ -38,42 +38,60 @@ def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]:
     return None
 
 
-@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
-def ir_custom_op_impl(
+@torch.library.custom_op("torchrec::ir_emb_lookup", mutates_args={})
+def ir_emb_lookup_impl(
     tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
 ) -> List[torch.Tensor]:
     device = get_device(tensors)
-    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dims}) {device}")
+    logger.info(f"torch.ops.torchrec.ir_emb_lookup -> ({batch_size}, {dims}) {device}")
     return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
-@torch.library.register_fake("torchrec::ir_custom_op")
-def ir_custom_op_fake(
+@torch.library.register_fake("torchrec::ir_emb_lookup")
+def ir_emb_lookup_fake(
     tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
 ) -> List[torch.Tensor]:
     device = get_device(tensors)
-    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dims}) {device}")
+    logger.info(f"ir_emb_lookup_fake -> ({batch_size}, {dims}) {device}")
     return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
-@torch.library.custom_op("torchrec::ir_dynamic_batch_op", mutates_args={})
-def ir_dynamic_batch_op_impl(
+@torch.library.custom_op("torchrec::ir_kt_regroup", mutates_args={})
+def ir_kt_regroup_impl(
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"torch.ops.torchrec.ir_kt_regroup -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]
+
+
+@torch.library.register_fake("torchrec::ir_kt_regroup")
+def ir_kt_regroup_fake(
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"ir_kt_regroup_fake -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]
+
+
+@torch.library.custom_op("torchrec::ir_dynamic_batch_emb_lookup", mutates_args={})
+def ir_dynamic_batch_emb_lookup_impl(
     tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
 ) -> List[torch.Tensor]:
     device = get_device(tensors)
     logger.info(
-        f"torch.ops.torchrec.ir_dynamic_batch_op -> ({batch_size}, {dims}) {device}"
+        f"torch.ops.torchrec.ir_dynamic_batch_emb_lookup -> ({batch_size}, {dims}) {device}"
     )
     return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
-@torch.library.register_fake("torchrec::ir_dynamic_batch_op")
-def ir_dynamic_batch_op_fake(
+@torch.library.register_fake("torchrec::ir_dynamic_batch_emb_lookup")
+def ir_dynamic_batch_emb_lookup_fake(
     tensors: List[Optional[torch.Tensor]], batch_dize: int, dims: List[int]
 ) -> List[torch.Tensor]:
     device = get_device(tensors)
     batch_size = torch.library.get_ctx().new_dynamic_size()
-    logger.info(f"ir_dynamic_batch_op_fake -> ({batch_size}, {dims}) {device}")
+    logger.info(f"ir_dynamic_batch_emb_lookup_fake -> ({batch_size}, {dims}) {device}")
     return [torch.empty(batch_size, dim, device=device) for dim in dims]
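The pattern above (a real kernel plus a registered fake kernel, both returning correctly shaped empty tensors) is what lets torch.export trace through the sparse modules without running real embedding lookups or regrouping. A minimal, self-contained sketch of the same pattern under a hypothetical `demo::` namespace (assumes a PyTorch build that ships `torch.library.custom_op`, roughly 2.4+):

```python
from typing import List, Optional

import torch


@torch.library.custom_op("demo::ir_regroup_like", mutates_args={})
def ir_regroup_like(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
) -> List[torch.Tensor]:
    # Eager kernel: shape-only placeholder outputs, no real computation.
    return [torch.empty(batch_size, dim) for dim in dims]


@torch.library.register_fake("demo::ir_regroup_like")
def _(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
) -> List[torch.Tensor]:
    # Fake kernel: lets export / FakeTensor tracing allocate outputs without data.
    return [torch.empty(batch_size, dim) for dim in dims]


outs = torch.ops.demo.ir_regroup_like([torch.ones(3, 2)], 3, [12, 7])
print([tuple(t.shape) for t in outs])  # [(3, 12), (3, 7)]
```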
