
Commit ea9b83c

TroyGarden authored and facebook-github-bot committed
register a custom op to keep the PEA module unflattened during torch.export (pytorch#1900)
Summary:
reference: D54009459

Differential Revision: D56282744
1 parent: f120e42 · commit: ea9b83c

File tree: 1 file changed, +20 -0 lines


torchrec/modules/embedding_modules.py

Lines changed: 20 additions & 0 deletions
@@ -8,10 +8,12 @@
 # pyre-strict

 import abc
+import threading
 from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
+from torch.library import Library
 from torchrec.modules.embedding_configs import (
     DataType,
     EmbeddingBagConfig,
@@ -20,6 +22,24 @@
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

+lib = Library("custom", "FRAGMENT")
+
+
+class OpRegistryState:
+    """
+    State of operator registry.
+
+    We can only register the op schema once. So if we're registering multiple
+    times we need a lock and check if they're the same schema
+    """
+
+    op_registry_lock = threading.Lock()
+    # operator schema: op_name: schema
+    op_registry_schema: Dict[str, str] = {}
+
+
+operator_registry_state = OpRegistryState()
+

 @torch.fx.wrap
 def reorder_inverse_indices(
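
The added OpRegistryState exists because an op schema can only be defined once on the "custom" library fragment: the lock serializes registration and the name-to-schema dict lets repeated callers (for example, several module instances being exported) skip or validate an already-registered op. Below is a minimal, self-contained sketch of that lock-and-check pattern. It mirrors the definitions from the diff, but the register_custom_op helper, the op name pea_unflatten_op, and its schema are hypothetical illustrations, not code from this commit.

# Minimal sketch of the lock-and-check registration pattern.
# Assumption: the helper, op name, and schema below are illustrative only.
import threading
from typing import Dict

import torch
from torch.library import Library

lib = Library("custom", "FRAGMENT")  # mirrors the fragment created in the diff


class OpRegistryState:
    """Tracks which op schemas have already been defined on `lib`."""

    op_registry_lock = threading.Lock()
    # operator schema: op_name -> schema string
    op_registry_schema: Dict[str, str] = {}


operator_registry_state = OpRegistryState()


def register_custom_op(op_name: str, schema: str) -> None:
    """Define `custom::op_name` once; later calls must pass the same schema."""
    with operator_registry_state.op_registry_lock:
        registered = operator_registry_state.op_registry_schema.get(op_name)
        if registered is not None:
            # The schema can only be defined once per op; a mismatch means two
            # callers disagree about the signature.
            assert registered == schema, f"conflicting schema for {op_name}"
            return
        lib.define(schema)
        # Illustrative CPU kernel: just produce an empty tensor of the given
        # shape so the op can be called and traced; a real kernel would
        # reconstruct the module's actual output here.
        lib.impl(op_name, lambda shape: torch.empty(shape), "CPU")
        operator_registry_state.op_registry_schema[op_name] = schema


# Hypothetical usage: after this, an exported graph can carry a single opaque
# torch.ops.custom.pea_unflatten_op call instead of the module's internals.
register_custom_op(
    "pea_unflatten_op",
    "pea_unflatten_op(SymInt[] shape) -> Tensor",
)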
