Skip to content

Add torchvision_vit model to the examples #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions examples/export/test/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ def _assert_eager_lowered_same_result(
_EDGE_COMPILE_CONFIG
)

executorch_model = edge_model.to_executorch()
executorch_prog = edge_model.to_executorch()
# pyre-ignore
pte_model = _load_for_executorch_from_buffer(executorch_model.buffer)
pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer)

with torch.no_grad():
eager_output = eager_model(*example_inputs)
with torch.no_grad():
executorch_output = pte_model.forward(example_inputs)
executorch_output = pte_model.run_method("forward", example_inputs)

if isinstance(eager_output, tuple):
# TODO: Allow validating other items
Expand Down Expand Up @@ -65,3 +65,9 @@ def test_emformer_export_to_executorch(self):
eager_model = eager_model.eval()

self._assert_eager_lowered_same_result(eager_model, example_inputs)

def test_vit_export_to_executorch(self):
eager_model, example_inputs = MODEL_NAME_TO_MODEL["vit"]()
eager_model = eager_model.eval()

self._assert_eager_lowered_same_result(eager_model, example_inputs)
1 change: 1 addition & 0 deletions examples/models/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ python_library(
"//executorch/examples/models/emformer:emformer_export",
"//executorch/examples/models/mobilenet_v2:mv2_export",
"//executorch/examples/models/mobilenet_v3:mv3_export",
"//executorch/examples/models/torchvision_vit:vit_export",
"//executorch/exir/backend:compile_spec_schema",
],
)
7 changes: 7 additions & 0 deletions examples/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def gen_emformer_model_inputs() -> Tuple[torch.nn.Module, Any]:
return EmformerModel.get_model(), EmformerModel.get_example_inputs()


def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
from ..models.torchvision_vit import TorchVisionViTModel

return TorchVisionViTModel.get_model(), TorchVisionViTModel.get_example_inputs()


MODEL_NAME_TO_MODEL = {
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
Expand All @@ -103,4 +109,5 @@ def gen_emformer_model_inputs() -> Tuple[torch.nn.Module, Any]:
"mv2": gen_mobilenet_v2_model_inputs,
"mv3": gen_mobilenet_v3_model_inputs,
"emformer": gen_emformer_model_inputs,
"vit": gen_torchvision_vit_model_and_inputs,
}
14 changes: 14 additions & 0 deletions examples/models/torchvision_vit/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "vit_export",
srcs = [
"__init__.py",
"export.py",
],
base_module = "executorch.examples.models.torchvision_vit",
deps = [
"//caffe2:torch",
"//pytorch/vision:torchvision",
],
)
11 changes: 11 additions & 0 deletions examples/models/torchvision_vit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .export import TorchVisionViTModel

__all__ = [
TorchVisionViTModel,
]
30 changes: 30 additions & 0 deletions examples/models/torchvision_vit/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torchvision import models

FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT)


class TorchVisionViTModel:
def __init__(self):
pass

@staticmethod
def get_model():
logging.info("loading torchvision vit_b_16 model")
vit_b_16 = models.vit_b_16(weights="IMAGENET1K_V1")
logging.info("loaded torchvision vit_b_16 model")
return vit_b_16

@staticmethod
def get_example_inputs():
input_shape = (1, 3, 224, 224)
return (torch.randn(input_shape),)
2 changes: 1 addition & 1 deletion extension/pybindings/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ py::object pyFromEValue(const EValue& v, KeepAlive& keep_alive) {
}

static constexpr size_t kDEFAULT_NON_CONSTANT_POOL_SIZE =
256 * 1024U * 1024U; // 256 MB
2 * 256 * 1024U * 1024U; // 512 MB
static constexpr size_t kRUNTIME_POOL_SIZE = 256 * 1024U * 1024U; // 256 MB
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;

Expand Down