Simplify custom op naming for meta functionalization of sparse modules #1974

Closed · wants to merge 1 commit
torchrec/modules/embedding_modules.py (2 changes: 1 addition & 1 deletion)

@@ -217,7 +217,7 @@ def _non_strict_exporting_forward(
             features.offsets_or_none(),
         ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
         dims = [sum(self._lengths_per_embedding)]
-        ebc_op = register_custom_op(self, dims)
+        ebc_op = register_custom_op(type(self).__name__, dims)
         outputs = ebc_op(arg_list, batch_size)
         return KeyedTensor(
             keys=self._embedding_names,
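The call-site change above pairs with the new signature of register_custom_op in torchrec/modules/utils.py below: the op name is now derived from the class name plus the output dims, rather than from the module instance. A minimal sketch of the resulting naming scheme (the helper op_name_for is made up for illustration; the real logic lives inside register_custom_op):

from typing import List


def op_name_for(module_name: str, dims: List[int]) -> str:
    # Mirrors the new f"{module_name}_{dims_str}" naming in register_custom_op.
    return f"{module_name}_" + "_".join(str(d) for d in dims)


# The same module class with the same output dims always maps to the same op
# name, so re-exporting or reloading a model never mints a fresh name.
assert op_name_for("EmbeddingBagCollection", [96]) == "EmbeddingBagCollection_96"
assert op_name_for("EmbeddingBagCollection", [8, 16]) == "EmbeddingBagCollection_8_16"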
torchrec/modules/utils.py (112 changes: 64 additions & 48 deletions)

@@ -50,8 +50,6 @@ class OpRegistryState:
 
     # operator schema: {class}.{id} => op_name
     op_registry_schema: Dict[str, str] = {}
-    # operator counter: {class} => count
-    op_registry_counter: Dict[str, int] = defaultdict(int)
 
 
 operator_registry_state = OpRegistryState()
@@ -274,7 +272,8 @@ def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
 # a list of tensors as output. The operator is registered with the name of
 # {module_class_name}_{instance_count}
 def register_custom_op(
-    module: torch.nn.Module, dims: List[int]
+    module_name: str,
+    dims: List[int],
 ) -> Callable[[List[Optional[torch.Tensor]], int], List[torch.Tensor]]:
     """
     Register a customized operator.
@@ -286,51 +285,68 @@
 
     global operator_registry_state
 
-    m_name: str = type(module).__name__
-    op_id: str = f"{m_name}_{id(module)}"
+    dims_str = "_".join([str(d) for d in dims])
     with operator_registry_state.op_registry_lock:
-        if op_id in operator_registry_state.op_registry_schema:
-            op_name: str = operator_registry_state.op_registry_schema[op_id]
-        else:
-            operator_registry_state.op_registry_counter[m_name] += 1
-            op_name: str = (
-                f"{m_name}_{operator_registry_state.op_registry_counter[m_name]}"
-            )
-            operator_registry_state.op_registry_schema[op_id] = op_name
-
-        def custom_op(
-            values: List[Optional[torch.Tensor]],
-            batch_size: int,
-        ) -> List[torch.Tensor]:
-            device = None
-            for v in values:
-                if v is not None:
-                    device = v.device
-                    break
-            else:
-                raise AssertionError(
-                    f"Custom op {op_name} expects at least one input tensor"
-                )
-
-            return [
-                torch.empty(
-                    batch_size,
-                    dim,
-                    device=device,
-                )
-                for dim in dims
-            ]
-
-        schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
-        operator_registry_state.op_registry_schema[op_name] = schema_string
-        # Register schema
-        lib.define(schema_string)
-
-        # Register implementation
-        lib.impl(op_name, custom_op, "CPU")
-        lib.impl(op_name, custom_op, "CUDA")
-
-        # Register meta formula
-        lib.impl(op_name, custom_op, "Meta")
+        op_name: str = f"{module_name}_{dims_str}"
+
+        if op_name in operator_registry_state.op_registry_schema:
+            return getattr(torch.ops.custom, op_name)
+
+        def custom_op(
+            values: List[Optional[torch.Tensor]],
+            batch_size: int,
+        ) -> List[torch.Tensor]:
+            device = None
+            for v in values:
+                if v is not None:
+                    device = v.device
+                    break
+            else:
+                raise AssertionError(
+                    f"Custom op {op_name} expects at least one input tensor"
+                )
+
+            return [
+                torch.empty(
+                    batch_size,
+                    dim,
+                    device=device,
+                )
+                for dim in dims
+            ]
+
+        schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
+        operator_registry_state.op_registry_schema[op_name] = schema_string
+        # Register schema
+        lib.define(schema_string)
+
+        # Register implementation
+        lib.impl(op_name, custom_op, "CPU")
+        lib.impl(op_name, custom_op, "CUDA")
+
+        # Register meta formula
+        lib.impl(op_name, custom_op, "Meta")
 
     return getattr(torch.ops.custom, op_name)
+
+
+def register_custom_ops_for_nodes(
+    nodes: List[torch.fx.Node],
+) -> None:
+    """
+    Given a list of nodes, register custom ops if they exist in the nodes.
+    Required for deserialization if in different runtime environments
+
+    Args:
+        nodes: list of nodes
+    """
+
+    for node in nodes:
+        if "custom." in str(node.target):
+            # torch.ops.custom.EmbeddingBagCollection_100.default
+            # number represents dimension
+            op_name = str(node.target).split(".")[-2]
+            register_custom_op(
+                op_name.split("_")[0],
+                [int(dim) for dim in op_name.split("_")[1:]],
+            )
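For readers unfamiliar with the torch.library pattern the diff relies on, here is a minimal, self-contained sketch of the same registration flow: define a schema in the custom namespace, then install one shape-only function as the CPU, CUDA, and Meta implementation. The op name toy_op and its fixed output dim are made up for illustration:

from typing import List, Optional

import torch

# Same namespace the diff registers into; "FRAGMENT" appends ops to it.
lib = torch.library.Library("custom", "FRAGMENT")


def toy_op(
    values: List[Optional[torch.Tensor]],
    batch_size: int,
) -> List[torch.Tensor]:
    # Outputs are allocated but never filled (torch.empty), so the very same
    # function doubles as the Meta formula used during fake-tensor tracing.
    device = next((v.device for v in values if v is not None), None)
    if device is None:
        raise AssertionError("toy_op expects at least one input tensor")
    return [torch.empty(batch_size, dim, device=device) for dim in (96,)]


lib.define("toy_op(Tensor?[] values, int batch_size) -> Tensor[]")
for backend in ("CPU", "CUDA", "Meta"):
    lib.impl("toy_op", toy_op, backend)

out = torch.ops.custom.toy_op([torch.zeros(2, 3)], 4)
print(out[0].shape)  # torch.Size([4, 96])

The deterministic names are also what make register_custom_ops_for_nodes workable: the registration arguments round-trip through the serialized op name alone, with no live module instance required. A worked example of the parse above, using the name from the diff's comment:

# e.g. a node target of torch.ops.custom.EmbeddingBagCollection_100.default
op_name = "EmbeddingBagCollection_100"
module_name = op_name.split("_")[0]              # "EmbeddingBagCollection"
dims = [int(d) for d in op_name.split("_")[1:]]  # [100]
assert (module_name, dims) == ("EmbeddingBagCollection", [100])

Note that this simple split assumes the module class name itself contains no underscores.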