Add Inception V3 model to examples #74

Closed · wants to merge 2 commits
8 changes: 8 additions & 0 deletions examples/export/test/test_export.py
@@ -87,3 +87,11 @@ def test_w2l_export_to_executorch(self):
        self._assert_eager_lowered_same_result(
            eager_model, example_inputs, self.validate_tensor_allclose
        )

    def test_ic3_export_to_executorch(self):
        eager_model, example_inputs = MODEL_NAME_TO_MODEL["ic3"]()
        eager_model = eager_model.eval()

        self._assert_eager_lowered_same_result(
            eager_model, example_inputs, self.validate_tensor_allclose
        )
1 change: 1 addition & 0 deletions examples/models/TARGETS
@@ -8,6 +8,7 @@ python_library(
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/examples/models/inception_v3:ic3_export",
        "//executorch/examples/models/mobilenet_v2:mv2_export",
        "//executorch/examples/models/mobilenet_v3:mv3_export",
        "//executorch/examples/models/torchvision_vit:vit_export",
14 changes: 14 additions & 0 deletions examples/models/inception_v3/TARGETS
@@ -0,0 +1,14 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
    name = "ic3_export",
    srcs = [
        "__init__.py",
        "export.py",
    ],
    base_module = "executorch.examples.models.inception_v3",
    deps = [
        "//caffe2:torch",
        "//pytorch/vision:torchvision",
    ],
)
11 changes: 11 additions & 0 deletions examples/models/inception_v3/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .export import InceptionV3Model

__all__ = [
    "InceptionV3Model",
]
30 changes: 30 additions & 0 deletions examples/models/inception_v3/export.py
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torchvision import models

FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT)


class InceptionV3Model:
    def __init__(self):
        pass

    @staticmethod
    def get_model():
        logging.info("loading torchvision inception_v3 model")
        inception_v3 = models.inception_v3(weights="IMAGENET1K_V1")
        logging.info("loaded torchvision inception_v3 model")
        return inception_v3

    @staticmethod
    def get_example_inputs():
        input_shape = (1, 3, 224, 224)
        return (torch.randn(input_shape),)
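Note that the example input is 224x224 rather than Inception V3's canonical 299x299 training resolution; eval-mode inference still works because torchvision's implementation pools adaptively before the classifier. For a quick eager sanity check of this wrapper, a minimal sketch (not part of this PR; assumes torchvision can fetch the IMAGENET1K_V1 weights):

    import torch
    from executorch.examples.models.inception_v3 import InceptionV3Model

    model = InceptionV3Model.get_model().eval()
    (x,) = InceptionV3Model.get_example_inputs()
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # expected: torch.Size([1, 1000])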
7 changes: 7 additions & 0 deletions examples/models/models.py
@@ -102,6 +102,12 @@ def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
    return model.get_model(), model.get_example_inputs()


def gen_inception_v3_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
    from ..models.inception_v3 import InceptionV3Model

    return InceptionV3Model.get_model(), InceptionV3Model.get_example_inputs()


MODEL_NAME_TO_MODEL = {
    "mul": lambda: (MulModule(), MulModule.get_example_inputs()),
    "linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -111,4 +117,5 @@ def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
    "mv3": gen_mobilenet_v3_model_inputs,
    "vit": gen_torchvision_vit_model_and_inputs,
    "w2l": gen_wav2letter_model_and_inputs,
    "ic3": gen_inception_v3_model_and_inputs,
}
114 changes: 114 additions & 0 deletions kernels/portable/cpu/op_avg_pool2d.cpp
@@ -0,0 +1,114 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using IntArrayRef = exec_aten::ArrayRef<int64_t>;

Tensor& avg_pool2d_out(
    RuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    exec_aten::optional<int64_t> divisor_override,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_avg_pool2d_args(
          in,
          kernel_size,
          stride,
          padding,
          ceil_mode,
          count_include_pad,
          divisor_override,
          out),
      InvalidArgument,
      out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_avg_pool2d_out_target_size(
      in, kernel_size, stride, padding, ceil_mode, output_sizes, &output_ndim);

  ET_KERNEL_CHECK(
      ctx,
      output_size_is_valid({output_sizes, output_ndim}),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  ScalarType in_type = in.scalar_type();
  ET_SWITCH_FLOAT_TYPES_AND(Long, in_type, ctx, __func__, CTYPE, [&]() {
    if (divisor_override.has_value()) {
      int64_t divisor = divisor_override.value();
      // If divisor_override is specified, then we don't need to use `count` in
      // the calculation. Simply sum x / divisor to get the output.
      apply_kernel_2d_reduce_then_map_fn<CTYPE>(
          [](const CTYPE in_val,
             int64_t in_idx,
             CTYPE accum,
             int64_t accum_idx) {
            // Average pooling does not track indexes, so return 0 for accum_idx
            return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
          },
          [divisor](const int64_t count, const CTYPE accum) {
            return accum / static_cast<CTYPE>(divisor);
          },
          count_include_pad,
          in,
          kernel_size,
          stride,
          padding,
          {},
          out);
    } else {
      apply_kernel_2d_reduce_then_map_fn<CTYPE>(
          [](const CTYPE in_val,
             int64_t in_idx,
             CTYPE accum,
             int64_t accum_idx) {
            // Average pooling does not track indexes, so return 0 for accum_idx
            return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
          },
          [](const int64_t count, const CTYPE accum) {
            return accum / static_cast<CTYPE>(count);
          },
          count_include_pad,
          in,
          kernel_size,
          stride,
          padding,
          {},
          out);
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
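The two branches above differ only in the final map step: each window's running sum is divided either by the caller-supplied divisor or by the element count. A PyTorch reference sketch of the same semantics (not part of this PR; assumes torch is installed):

    import torch
    import torch.nn.functional as F

    x = torch.arange(16.0).reshape(1, 1, 4, 4)
    # Default: each output is the window sum divided by the window element count.
    out_count = F.avg_pool2d(x, kernel_size=2)
    # With divisor_override, the window sum is divided by the fixed divisor.
    out_div = F.avg_pool2d(x, kernel_size=2, divisor_override=8)
    assert torch.allclose(out_count * 4 / 8, out_div)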
13 changes: 9 additions & 4 deletions kernels/portable/cpu/op_max_pool2d_with_indices.cpp
@@ -9,7 +9,6 @@
 #include <cstring>
 
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
-#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -55,7 +54,7 @@ std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
       ctx,
       output_size_is_valid({output_sizes, output_ndim}),
       InvalidArgument,
-      out);
+      ret_val);
 
   ET_KERNEL_CHECK(
       ctx,
@@ -71,13 +70,19 @@ std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
 
   ScalarType in_type = in.scalar_type();
   ET_SWITCH_REAL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
-    apply_kernel_2d_reduce_fn<CTYPE>(
-        [](const CTYPE in_val, int64_t in_idx, CTYPE accum, int64_t accum_idx) {
+    apply_kernel_2d_reduce_then_map_fn<CTYPE>(
+        [](const CTYPE in_val,
+           const int64_t in_idx,
+           const CTYPE accum,
+           const int64_t accum_idx) {
           if (in_val > accum) {
             return std::tuple<CTYPE, int64_t>(in_val, in_idx);
           }
           return std::tuple<CTYPE, int64_t>(accum, accum_idx);
         },
+        // Max pooling does not need to post-process the accumulated output
+        [](const int64_t count, const CTYPE accum) { return accum; },
+        /*include_pad=*/false,
         in,
         kernel_size,
         stride,
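Here the reduce function threads a (value, flat index) pair through each window, and the new map step is the identity since max pooling has nothing to post-process. The intended behavior matches PyTorch's max_pool2d with return_indices=True; a small reference sketch (not part of this PR):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[[1.0, 3.0], [4.0, 2.0]]]])
    vals, idxs = F.max_pool2d(x, kernel_size=2, return_indices=True)
    # The max of the single 2x2 window is 4.0, at flat index 2 of the input plane.
    assert vals.item() == 4.0 and idxs.item() == 2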
6 changes: 6 additions & 0 deletions kernels/portable/cpu/targets.bzl
@@ -119,6 +119,12 @@ _ATEN_OPS = (
            "//executorch/kernels/portable/cpu/pattern:pattern",
        ],
    ),
    op_target(
        name = "op_avg_pool2d",
        deps = [
            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
        ],
    ),
    op_target(
        name = "op_bitwise_and",
        deps = [
52 changes: 52 additions & 0 deletions kernels/portable/cpu/util/kernel_ops_util.cpp
@@ -192,6 +192,58 @@ void calculate_kernel_output_sizes(
  }
}

bool check_avg_pool2d_args(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    const bool count_include_pad,
    const exec_aten::optional<int64_t>& divisor_override,
    const Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
  if (stride.size() > 0) {
    ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(stride, 2));
  }
  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));

  if (divisor_override.has_value()) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        divisor_override.value() > 0,
        "divisor_override must be > 0, but found %" PRId64,
        divisor_override.value());
  }

  return true;
}

void get_avg_pool2d_out_target_size(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    exec_aten::SizesType* const out_sizes,
    size_t* const out_ndim) {
  *out_ndim = in.dim();

  // Batch dim is optional, so in can be either 3 or 4 dim.
  if (in.dim() == 4) {
    out_sizes[0] = in.size(0);
    out_sizes[1] = in.size(1);
  } else {
    out_sizes[0] = in.size(0);
  }

  calculate_kernel_output_sizes(
      in, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
}

bool check_convolution_args(
    const Tensor& in,
    const Tensor& weight,
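For reference, the spatial sizes that get_avg_pool2d_out_target_size fills in via calculate_kernel_output_sizes follow the standard pooling shape formula. A sketch of one spatial dimension (not part of this PR; it ignores PyTorch's extra ceil_mode rule that the last window must start inside the input or left padding):

    import math

    def pool_out_size(in_size, kernel, stride, pad, ceil_mode=False):
        # floor((in + 2*pad - kernel) / stride) + 1, with ceil when ceil_mode is set
        rounding = math.ceil if ceil_mode else math.floor
        return rounding((in_size + 2 * pad - kernel) / stride) + 1

    assert pool_out_size(224, kernel=3, stride=2, pad=0) == 111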