diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index fca8bd2212f..fd7c51fdccb 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -725,17 +725,17 @@ def count_nodes(graph_module, target): ) def test_edge_dialect_non_core_aten_ops(self): - class LinalgNorm(torch.nn.Module): + class LinalgRank(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.linalg.norm(x) + return torch.linalg.matrix_rank(x) from torch._export.verifier import SpecViolationError - input = torch.arange(9, dtype=torch.float) - 4 - ep = torch.export.export(LinalgNorm(), (input,), strict=True) + input = torch.ones((9, 9, 9), dtype=torch.float) + ep = torch.export.export(LinalgRank(), (input,), strict=True) # aten::linalg_norm is not a core op, so it should error out with self.assertRaises(SpecViolationError): @@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ep, compile_config=EdgeCompileConfig( _check_ir_validity=True, - _core_aten_ops_exception_list=[ - torch.ops.aten.linalg_vector_norm.default - ], + _core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default], ), ) except SpecViolationError: diff --git a/exir/tracer.py b/exir/tracer.py index 82f93424a14..c749df510ad 100644 --- a/exir/tracer.py +++ b/exir/tracer.py @@ -631,8 +631,18 @@ def _default_decomposition_table( ] # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e... return get_decompositions(decomp_opset) + + decomps = default_decompositions() + # Add edge specific decompositions + additional_decomp_ops = [ + # TODO: Eventually this op should be added to the core decompo table, and will not + # need to be added here. + torch.ops.aten.linalg_vector_norm.default, + ] + additional_decomps = get_decompositions(additional_decomp_ops) + decomps.update(additional_decomps) # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir.... - return default_decompositions() + return decomps def dynamo_trace(