diff --git a/examples/export/test/TARGETS b/examples/export/test/TARGETS index c440f9a2037..1acac1cb034 100644 --- a/examples/export/test/TARGETS +++ b/examples/export/test/TARGETS @@ -10,5 +10,6 @@ python_unittest( "//executorch/examples/export:utils", "//executorch/examples/models:models", "//executorch/exir:lib", + "//executorch/extension/pybindings:portable", # @manual ], ) diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py index 64bb69f57b7..35655b9dfd2 100644 --- a/examples/export/test/test_export.py +++ b/examples/export/test/test_export.py @@ -8,9 +8,14 @@ import torch -from executorch.examples.export.utils import _EDGE_COMPILE_CONFIG +from executorch.examples.export.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG from executorch.examples.models import MODEL_NAME_TO_MODEL +# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`. +from executorch.extension.pybindings.portable import ( # @manual + _load_for_executorch_from_buffer, +) + class ExportTest(unittest.TestCase): def _assert_eager_lowered_same_result( @@ -18,16 +23,18 @@ def _assert_eager_lowered_same_result( ): import executorch.exir as exir - capture_config = exir.CaptureConfig(enable_dynamic_shape=False) - edge_model = exir.capture(eager_model, example_inputs, capture_config).to_edge( + edge_model = exir.capture(eager_model, example_inputs, _CAPTURE_CONFIG).to_edge( _EDGE_COMPILE_CONFIG ) executorch_model = edge_model.to_executorch() + # pyre-ignore + pte_model = _load_for_executorch_from_buffer(executorch_model.buffer) + with torch.no_grad(): eager_output = eager_model(*example_inputs) with torch.no_grad(): - executorch_output = executorch_model.graph_module(*example_inputs) + executorch_output = pte_model.forward(example_inputs) self.assertTrue( torch.allclose(eager_output, executorch_output[0], rtol=1e-5, atol=1e-5) ) diff --git a/examples/export/utils.py b/examples/export/utils.py index 1a38acf984a..0255f64509f 100644 --- a/examples/export/utils.py +++ b/examples/export/utils.py @@ -11,7 +11,7 @@ # Reason is that there memory allocation ops with symbolic shape nodes. # and when evaulating shape, it doesnt seem that we presenting them with shape env # that contain those variables. -_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True, _unlift=False) +_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True) _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( _check_ir_validity=False, ) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 7e008268b06..2b23e1d9f81 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -19,8 +19,8 @@ class CaptureConfig: pt2_mode: bool = True enable_functionalization: bool = True - enable_dynamic_shape: bool = False - enable_aot: bool = False + enable_dynamic_shape: bool = False # This flag does nothing if enable_aot is True + enable_aot: bool = False # When it's true it implies automatic dynamic shapes via default dynamo config _dynamo_config: "ExirDynamoConfig" = field(default_factory=ExirDynamoConfig) _unlift: bool = False _use_old_decomp_table: bool = False