From 762259b537595064b7a75a5abc25f3a736e8fc5f Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Fri, 11 Aug 2023 14:21:37 -0700 Subject: [PATCH] Add torchvision_vit model to the examples (#33) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/33 Info about the model: https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16 Reviewed By: kimishpatel, kirklandsign Differential Revision: D48012005 fbshipit-source-id: fe8d8006c2936e6c9e04e9222cefac167edd8705 --- examples/export/test/test_export.py | 12 ++++++--- examples/models/TARGETS | 1 + examples/models/models.py | 7 +++++ examples/models/torchvision_vit/TARGETS | 14 ++++++++++ examples/models/torchvision_vit/__init__.py | 11 ++++++++ examples/models/torchvision_vit/export.py | 30 +++++++++++++++++++++ extension/pybindings/module.cpp | 2 +- 7 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 examples/models/torchvision_vit/TARGETS create mode 100644 examples/models/torchvision_vit/__init__.py create mode 100644 examples/models/torchvision_vit/export.py diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py index bb6a8f286fe..737d4d76811 100644 --- a/examples/export/test/test_export.py +++ b/examples/export/test/test_export.py @@ -27,14 +27,14 @@ def _assert_eager_lowered_same_result( _EDGE_COMPILE_CONFIG ) - executorch_model = edge_model.to_executorch() + executorch_prog = edge_model.to_executorch() # pyre-ignore - pte_model = _load_for_executorch_from_buffer(executorch_model.buffer) + pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer) with torch.no_grad(): eager_output = eager_model(*example_inputs) with torch.no_grad(): - executorch_output = pte_model.forward(example_inputs) + executorch_output = pte_model.run_method("forward", example_inputs) if isinstance(eager_output, tuple): # TODO: Allow validating other items @@ -65,3 +65,9 @@ def 
test_emformer_export_to_executorch(self): eager_model = eager_model.eval() self._assert_eager_lowered_same_result(eager_model, example_inputs) + + def test_vit_export_to_executorch(self): + eager_model, example_inputs = MODEL_NAME_TO_MODEL["vit"]() + eager_model = eager_model.eval() + + self._assert_eager_lowered_same_result(eager_model, example_inputs) diff --git a/examples/models/TARGETS b/examples/models/TARGETS index ee6ae9bee77..26c4c1b8a85 100644 --- a/examples/models/TARGETS +++ b/examples/models/TARGETS @@ -11,6 +11,7 @@ python_library( "//executorch/examples/models/emformer:emformer_export", "//executorch/examples/models/mobilenet_v2:mv2_export", "//executorch/examples/models/mobilenet_v3:mv3_export", + "//executorch/examples/models/torchvision_vit:vit_export", "//executorch/exir/backend:compile_spec_schema", ], ) diff --git a/examples/models/models.py b/examples/models/models.py index ae229ee86c2..a3e8b41bcaf 100644 --- a/examples/models/models.py +++ b/examples/models/models.py @@ -95,6 +95,12 @@ def gen_emformer_model_inputs() -> Tuple[torch.nn.Module, Any]: return EmformerModel.get_model(), EmformerModel.get_example_inputs() +def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]: + from ..models.torchvision_vit import TorchVisionViTModel + + return TorchVisionViTModel.get_model(), TorchVisionViTModel.get_example_inputs() + + MODEL_NAME_TO_MODEL = { "mul": lambda: (MulModule(), MulModule.get_example_inputs()), "linear": lambda: (LinearModule(), LinearModule.get_example_inputs()), @@ -103,4 +109,5 @@ def gen_emformer_model_inputs() -> Tuple[torch.nn.Module, Any]: "mv2": gen_mobilenet_v2_model_inputs, "mv3": gen_mobilenet_v3_model_inputs, "emformer": gen_emformer_model_inputs, + "vit": gen_torchvision_vit_model_and_inputs, } diff --git a/examples/models/torchvision_vit/TARGETS b/examples/models/torchvision_vit/TARGETS new file mode 100644 index 00000000000..6664752b75d --- /dev/null +++ b/examples/models/torchvision_vit/TARGETS @@ -0,0 
+1,14 @@ +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") + +python_library( + name = "vit_export", + srcs = [ + "__init__.py", + "export.py", + ], + base_module = "executorch.examples.models.torchvision_vit", + deps = [ + "//caffe2:torch", + "//pytorch/vision:torchvision", + ], +) diff --git a/examples/models/torchvision_vit/__init__.py b/examples/models/torchvision_vit/__init__.py new file mode 100644 index 00000000000..723193abc6c --- /dev/null +++ b/examples/models/torchvision_vit/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .export import TorchVisionViTModel + +__all__ = [ + TorchVisionViTModel, +] diff --git a/examples/models/torchvision_vit/export.py b/examples/models/torchvision_vit/export.py new file mode 100644 index 00000000000..d407eb45ffe --- /dev/null +++ b/examples/models/torchvision_vit/export.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch +from torchvision import models + +FORMAT = "[%(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(format=FORMAT) + + +class TorchVisionViTModel: + def __init__(self): + pass + + @staticmethod + def get_model(): + logging.info("loading torchvision vit_b_16 model") + vit_b_16 = models.vit_b_16(weights="IMAGENET1K_V1") + logging.info("loaded torchvision vit_b_16 model") + return vit_b_16 + + @staticmethod + def get_example_inputs(): + input_shape = (1, 3, 224, 224) + return (torch.randn(input_shape),) diff --git a/extension/pybindings/module.cpp b/extension/pybindings/module.cpp index fbfeef42431..c811e07fcb0 100644 --- a/extension/pybindings/module.cpp +++ b/extension/pybindings/module.cpp @@ -283,7 +283,7 @@ py::object pyFromEValue(const EValue& v, KeepAlive& keep_alive) { } static constexpr size_t kDEFAULT_NON_CONSTANT_POOL_SIZE = - 256 * 1024U * 1024U; // 256 MB + 2 * 256 * 1024U * 1024U; // 512 MB static constexpr size_t kRUNTIME_POOL_SIZE = 256 * 1024U * 1024U; // 256 MB static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;