Skip to content

Commit e2b42a9

Browse files
committed
Add portable adaptive avg pool2d kernel
1 parent 7b5c60d commit e2b42a9

File tree

8 files changed

+467
-0
lines changed

8 files changed

+467
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cmath>
10+
11+
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
18+
using Tensor = executorch::aten::Tensor;
19+
using ScalarType = executorch::aten::ScalarType;
20+
using IntArrayRef = executorch::aten::ArrayRef<int64_t>;
21+
22+
namespace {

// Start of the pooling window for output index `out_idx`:
// floor(out_idx * in_size / out_size), computed exactly in integer
// arithmetic. The previous float-based floor loses precision once
// out_idx * in_size exceeds 2^24 (float mantissa), which can shift
// window boundaries by one for large spatial sizes.
inline int64_t
adaptive_start_index(int64_t out_idx, int64_t out_size, int64_t in_size) {
  // out_idx >= 0, in_size >= 0, out_size > 0, so plain integer
  // division is the mathematical floor.
  return (out_idx * in_size) / out_size;
}

// One past the end of the pooling window for output index `out_idx`:
// ceil((out_idx + 1) * in_size / out_size), computed exactly via the
// standard (a + b - 1) / b integer-ceiling identity.
inline int64_t
adaptive_end_index(int64_t out_idx, int64_t out_size, int64_t in_size) {
  const int64_t numer = (out_idx + 1) * in_size;
  return (numer + out_size - 1) / out_size;
}

} // namespace
37+
38+
// Portable CPU kernel for _adaptive_avg_pool2d.out.
//
// Resizes `out` to the target shape derived from `in` and `output_size`,
// then averages each adaptively-sized input window into the matching
// output element. Supports 3D (CHW) and 4D (NCHW) inputs in default dim
// order; dtype coverage is float/half/bfloat16 plus Long.
Tensor& _adaptive_avg_pool2d_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef output_size,
    Tensor& out) {
  // Argument validation: dtypes, dim orders, rank, and output_size shape.
  ET_KERNEL_CHECK(
      ctx,
      check_adaptive_avg_pool2d_args(in, output_size, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  // Derive the expected output shape and resize `out` to match it.
  size_t expected_ndim = 0;
  executorch::aten::SizesType expected_sizes[kTensorDimensionLimit];
  get_adaptive_avg_pool2d_out_target_size(
      in, output_size, expected_sizes, &expected_ndim);

  ET_KERNEL_CHECK(
      ctx,
      output_size_is_valid({expected_sizes, expected_ndim}, 2),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_sizes, expected_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  const ScalarType in_type = in.scalar_type();

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "_adaptive_avg_pool2d.out";

  ET_SWITCH_FLOATHBF16_TYPES_AND(Long, in_type, ctx, op_name, CTYPE, [&]() {
    const CTYPE* const input_data = in.const_data_ptr<CTYPE>();
    CTYPE* const output_data = out.mutable_data_ptr<CTYPE>();

    const size_t rank = in.dim();
    const int64_t in_height = in.size(rank - 2);
    const int64_t in_width = in.size(rank - 1);
    const int64_t out_height = output_size[0];
    const int64_t out_width = output_size[1];

    // A "plane" is one (batch, channel) HxW slice; iterate them flat.
    const size_t channels = in.size(rank - 3);
    const size_t batches = rank == 4 ? in.size(0) : 1;
    const size_t num_planes = batches * channels;

    const size_t in_plane_numel = in_height * in_width;
    const size_t out_plane_numel = out_height * out_width;

    for (size_t plane = 0; plane < num_planes; ++plane) {
      const CTYPE* const src = input_data + plane * in_plane_numel;
      CTYPE* const dst = output_data + plane * out_plane_numel;

      for (int64_t oh = 0; oh < out_height; ++oh) {
        const int64_t h_begin = adaptive_start_index(oh, out_height, in_height);
        const int64_t h_end = adaptive_end_index(oh, out_height, in_height);

        for (int64_t ow = 0; ow < out_width; ++ow) {
          const int64_t w_begin = adaptive_start_index(ow, out_width, in_width);
          const int64_t w_end = adaptive_end_index(ow, out_width, in_width);

          // Accumulate the window in float, then average.
          float sum = 0;
          for (int64_t ih = h_begin; ih < h_end; ++ih) {
            for (int64_t iw = w_begin; iw < w_end; ++iw) {
              sum += src[ih * in_width + iw];
            }
          }

          const int64_t count = (h_end - h_begin) * (w_end - w_begin);
          dst[oh * out_width + ow] =
              static_cast<CTYPE>(sum / static_cast<float>(count));
        }
      }
    }
  });

  return out;
}
124+
125+
} // namespace native
126+
} // namespace executor
127+
} // namespace torch

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,53 @@ bool check_arange_args(double start, double end, double step, Tensor& out) {
262262
return true;
263263
}
264264

265+
bool check_adaptive_avg_pool2d_args(
266+
const Tensor& in,
267+
const IntArrayRef output_size,
268+
const Tensor& out) {
269+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
270+
271+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
272+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
273+
274+
ET_CHECK_OR_RETURN_FALSE(
275+
(in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
276+
(in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
277+
"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input; in.dim() = %" ET_PRI_TENSOR_DIM,
278+
in.dim());
279+
280+
ET_CHECK_OR_RETURN_FALSE(
281+
output_size.size() == 2,
282+
"output_size must have exactly 2 elements, but got %zu",
283+
output_size.size());
284+
285+
ET_CHECK_OR_RETURN_FALSE(
286+
output_size[0] > 0 && output_size[1] > 0,
287+
"output_size must be positive, but got (%" PRId64 ", %" PRId64 ")",
288+
output_size[0],
289+
output_size[1]);
290+
291+
return true;
292+
}
293+
294+
void get_adaptive_avg_pool2d_out_target_size(
295+
const Tensor& in,
296+
const IntArrayRef output_size,
297+
executorch::aten::SizesType* const out_sizes,
298+
size_t* const out_ndim) {
299+
*out_ndim = in.dim();
300+
301+
if (in.dim() == 4) {
302+
out_sizes[0] = in.size(0);
303+
out_sizes[1] = in.size(1);
304+
} else {
305+
out_sizes[0] = in.size(0);
306+
}
307+
308+
out_sizes[*out_ndim - 2] = output_size[0];
309+
out_sizes[*out_ndim - 1] = output_size[1];
310+
}
311+
265312
bool check_avg_pool2d_args(
266313
const Tensor& in,
267314
const IntArrayRef kernel_size,

kernels/portable/cpu/util/kernel_ops_util.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ void apply_kernel_2d_reduce_then_map_fn(
384384
// Operator specific utility functions
385385
//
386386

387+
/// Validates arguments for the _adaptive_avg_pool2d.out operator:
/// `in` and `out` share a dtype, both use a default or channels-last
/// dim order, `in` is 3D or 4D (only the batch dim may be 0), and
/// `output_size` holds exactly two positive values. Logs and returns
/// false on the first violation.
bool check_adaptive_avg_pool2d_args(
    const Tensor& in,
    const IntArrayRef output_size,
    const Tensor& out);

/// Computes the expected output shape for _adaptive_avg_pool2d.out:
/// leading (batch/channel) dims are copied from `in` and the trailing
/// two dims are taken from `output_size`. The rank is written to
/// `out_ndim`; `out_sizes` must hold at least kTensorDimensionLimit
/// entries.
void get_adaptive_avg_pool2d_out_target_size(
    const Tensor& in,
    const IntArrayRef output_size,
    executorch::aten::SizesType* const out_sizes,
    size_t* const out_ndim);
397+
387398
bool check_arange_args(double start, double end, double step, Tensor& out);
388399

389400
bool check_avg_pool2d_args(

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
# See the README.md file in this directory for a description of the syntax used
1818
# by this file.
1919

20+
- op: _adaptive_avg_pool2d.out
21+
kernels:
22+
- arg_meta: null
23+
kernel_name: torch::executor::_adaptive_avg_pool2d_out
24+
2025
- op: _cdist_forward.out
2126
kernels:
2227
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ set(all_test_sources
157157
"BinaryLogicalOpTest.cpp"
158158
"op__to_dim_order_copy_test.cpp"
159159
"op__clone_dim_order_test.cpp"
160+
"op__adaptive_avg_pool2d_test.cpp"
160161
"op_abs_test.cpp"
161162
"op_acos_test.cpp"
162163
"op_acosh_test.cpp"

0 commit comments

Comments
 (0)