From 60b6d512a0029b2db3394a8666eedd131d7784f1 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 9 Apr 2025 09:02:06 -0700 Subject: [PATCH] [ez][release blocker fix] Insert `linalg_vector_norm` into decomp table used for Edge export Summary: ## Context Addresses this [release blocker](https://github.com/orgs/pytorch/projects/99/views/1?pane=issue&itemId=104088363&issue=pytorch%7Cpytorch%7C150207) issue. Some models cannot export because they use `linalg_vector_norm` which is not currently an ATen operator. I initially tried adding the op to the core decomp table, but the decomp is not passing pytorch correctness tests. Please see https://github.com/pytorch/pytorch/pull/150241 for more details. ## Changes Since we currently cannot include the op in PyTorch's decomp table, instead we can insert the op into the edge decomp table directly. This PR is a simple change to add `linalg_vector_norm` directly to the edge decomp table. Test Plan: Tested exporting and running a model with the `linalg_vector_norm` op via the following script. ``` import torch from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig from torch.export import Dim, export from executorch.extension.pybindings.portable_lib import ( # @manual _load_for_executorch_from_buffer, ) class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.linalg.vector_norm(x, 2) model = Model() inputs = (torch.randn(1,1,16,16),) dynamic_shapes = { "x": { 2: Dim("h", min=16, max=1024), 3: Dim("w", min=16, max=1024), } } exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes) executorch_program = to_edge_transform_and_lower( exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False), ).to_executorch() executorch_module = _load_for_executorch_from_buffer( executorch_program.buffer ) model_output = executorch_module.run_method( "forward", tuple(inputs) ) print(model_output) ``` --- exir/program/test/test_program.py | 12 +++++------- exir/tracer.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index fca8bd2212f..fd7c51fdccb 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -725,17 +725,17 @@ def count_nodes(graph_module, target): ) def test_edge_dialect_non_core_aten_ops(self): - class LinalgNorm(torch.nn.Module): + class LinalgRank(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.linalg.norm(x) + return torch.linalg.matrix_rank(x) from torch._export.verifier import SpecViolationError - input = torch.arange(9, dtype=torch.float) - 4 - ep = torch.export.export(LinalgNorm(), (input,), strict=True) + input = torch.ones((9, 9, 9), dtype=torch.float) + ep = torch.export.export(LinalgRank(), (input,), strict=True) # aten::linalg_norm is not a core op, so it should error out with self.assertRaises(SpecViolationError): @@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ep, compile_config=EdgeCompileConfig( _check_ir_validity=True, - _core_aten_ops_exception_list=[ - torch.ops.aten.linalg_vector_norm.default - ], + _core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default], ), ) except SpecViolationError: diff --git a/exir/tracer.py b/exir/tracer.py index 82f93424a14..c749df510ad 100644 --- a/exir/tracer.py +++ b/exir/tracer.py @@ -631,8 +631,18 @@ def _default_decomposition_table( ] # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e... return get_decompositions(decomp_opset) + + decomps = default_decompositions() + # Add edge specific decompositions + additional_decomp_ops = [ + # TODO: Eventually this op should be added to the core decompo table, and will not + # need to be added here. + torch.ops.aten.linalg_vector_norm.default, + ] + additional_decomps = get_decompositions(additional_decomp_ops) + decomps.update(additional_decomps) # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir.... - return default_decompositions() + return decomps def dynamo_trace(