diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index cbb980f7a1f..cb362c2afd0 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -35,6 +35,7 @@ from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.serialize import serialize_to_flatbuffer +from executorch.exir.tracer import _default_decomposition_table # pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`. from executorch.extension.pybindings.portable import ( # @manual @@ -312,11 +313,9 @@ def quantize_and_test_model_with_quantizer( ): module.eval() # program capture - capture_config = exir.CaptureConfig( - pt2_mode=True, enable_functionalization=True + m = torch._export.capture_pre_autograd_graph( + module, example_inputs, decomp_table=_default_decomposition_table() ) - captured_program = exir.capture(module, example_inputs, config=capture_config) - m = captured_program.exported_program.graph_module quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config() @@ -324,7 +323,12 @@ def quantize_and_test_model_with_quantizer( prepared = prepare_pt2e(m, quantizer) converted = convert_pt2e(prepared) - captured_program.exported_program.graph_module = converted + captured_program = exir.capture( + converted, + example_inputs, + config=exir.CaptureConfig(enable_aot=True, _unlift=True), + ) + edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config()) delegated_module = self.lower_module_and_test_output( module=edge_program, diff --git a/exir/capture/TARGETS b/exir/capture/TARGETS index 2c4c35193e6..ad49025c5e7 100644 --- a/exir/capture/TARGETS +++ b/exir/capture/TARGETS @@ -8,7 +8,6 @@ python_library( deps = [ ":capture", ":config", - ":unlift", ], ) @@ -19,7 +18,6 @@ python_library( ], deps = [ ":config", - ":unlift", "//caffe2:torch", "//executorch/exir:error", "//executorch/exir:tracer", @@ -40,13 +38,3 @@ python_library( "//executorch/exir/passes:lib", ], ) - -python_library( - name = "unlift", - srcs = [ - "_unlift.py", - ], - deps = [ - "//caffe2:torch", - ], -) diff --git a/exir/capture/__init__.py b/exir/capture/__init__.py index c39bcba5e1e..ca7277d3272 100644 --- a/exir/capture/__init__.py +++ b/exir/capture/__init__.py @@ -12,7 +12,6 @@ EdgeCompileConfig, ExecutorchBackendConfig, ) -from executorch.exir.capture._unlift import unlift_exported_program_lifted_states __all__ = [ "capture", @@ -20,5 +19,4 @@ "CaptureConfig", "EdgeCompileConfig", "ExecutorchBackendConfig", - "unlift_exported_program_lifted_states", ] diff --git a/exir/capture/_capture.py b/exir/capture/_capture.py index 2c6529c2158..623ca260a29 100644 --- a/exir/capture/_capture.py +++ b/exir/capture/_capture.py @@ -13,7 +13,6 @@ import torch import torch._export from executorch.exir.capture._config import CaptureConfig -from executorch.exir.capture._unlift import unlift_exported_program_lifted_states from executorch.exir.error import ExportError, ExportErrorType, InternalError from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram from executorch.exir.tracer import ( @@ -75,7 +74,7 @@ def capture( ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass()) if not config._unlift: return ExirExportedProgram(ep, False) - graph_module = unlift_exported_program_lifted_states(ep) + graph_module = ep.module() elif config.enable_dynamic_shape: graph_module, _ = dynamo_trace( diff --git a/exir/delegate.py b/exir/delegate.py index 8c164facdb9..d07b38a1bd1 100644 --- a/exir/delegate.py +++ b/exir/delegate.py @@ -107,8 +107,9 @@ def call_delegate_autograd(lowered_module, *args): def fake_requires_grad(var): if var is not None: var = var.detach() - var.requires_grad = True - return err_fn(var) + if torch.is_floating_point(var) or torch.is_complex(var): + var.requires_grad = True + return var return pytree.tree_map(fake_requires_grad, res)