Skip to content

Commit d7dce4a

Browse files
SS-JIA authored and facebook-github-bot committed
Implement max_pool2d_with_indices
Summary: Implements `max_pool2d_with_indices.out`. Reviewed By: guangy10, kirklandsign Differential Revision: D48405101 fbshipit-source-id: 3bc62b0337f101debfff0a44f47a90d8a4ad8f99
1 parent d8bb5d6 commit d7dce4a

7 files changed

+697
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cstring>
10+
11+
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
12+
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using Tensor = exec_aten::Tensor;
20+
using ScalarType = exec_aten::ScalarType;
21+
using IntArrayRef = exec_aten::ArrayRef<int64_t>;
22+
23+
/// Computes 2-D max pooling over `in`, writing the pooled values to `out`
/// and the flat index (within each input plane) of each selected element
/// to `indices`.
///
/// @param ctx         Kernel runtime context used for error reporting.
/// @param in          Input tensor; 3-D (C, H, W) or 4-D (N, C, H, W) —
///                    assumed, based on the batch-dim handling in the size
///                    helper; confirm against check_* for the full contract.
/// @param kernel_size Pooling window size (2 elements).
/// @param stride      Window stride; may be empty (defaults to kernel_size
///                    per the ATen contract — presumed, handled downstream).
/// @param padding     Implicit zero padding on each spatial side.
/// @param dilation    Spacing between window elements; may be empty.
/// @param ceil_mode   If true, use ceil instead of floor for output sizes.
/// @param out         Output tensor; resized to the computed target size.
/// @param indices     Long tensor of argmax indices; resized like `out`.
/// @return A tuple referencing (out, indices).
std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
    RuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& out,
    Tensor& indices) {
  std::tuple<Tensor&, Tensor&> ret_val(out, indices);

  ET_KERNEL_CHECK(
      ctx,
      check_max_pool2d_with_indices_args(
          in, kernel_size, stride, padding, dilation, ceil_mode, out, indices),
      InvalidArgument,
      ret_val);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_max_pool2d_with_indices_out_target_size(
      in,
      kernel_size,
      stride,
      padding,
      dilation,
      ceil_mode,
      output_sizes,
      &output_ndim);

  // Fix: on failure this must return `ret_val`, the function's declared
  // std::tuple<Tensor&, Tensor&> return value, consistent with every other
  // check in this kernel. The original returned `out` (a lone Tensor&) here.
  ET_KERNEL_CHECK(
      ctx,
      output_size_is_valid({output_sizes, output_ndim}),
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(indices, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      ret_val);

  ScalarType in_type = in.scalar_type();
  ET_SWITCH_REAL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
    // Reduction: keep the larger value and its index. Strict `>` means ties
    // keep the earliest index encountered, matching ATen behavior.
    apply_kernel_2d_reduce_fn<CTYPE>(
        [](const CTYPE in_val, int64_t in_idx, CTYPE accum, int64_t accum_idx) {
          if (in_val > accum) {
            return std::tuple<CTYPE, int64_t>(in_val, in_idx);
          }
          return std::tuple<CTYPE, int64_t>(accum, accum_idx);
        },
        in,
        kernel_size,
        stride,
        padding,
        dilation,
        out,
        {indices});
  });

  return ret_val;
}
92+
93+
} // namespace native
94+
} // namespace executor
95+
} // namespace torch

kernels/portable/cpu/targets.bzl

+6
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,12 @@ _ATEN_OPS = (
480480
"//executorch/kernels/portable/cpu/util:reduce_util",
481481
],
482482
),
483+
# Portable max_pool2d_with_indices.out kernel; depends on the shared
# pooling/convolution helpers in kernel_ops_util.
op_target(
    name = "op_max_pool2d_with_indices",
    deps = [
        "//executorch/kernels/portable/cpu/util:kernel_ops_util",
    ],
),
483489
op_target(
484490
name = "op_mean",
485491
deps = [

kernels/portable/cpu/util/kernel_ops_util.cpp

+63
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ bool int_array_all_ge(IntArrayRef array, int64_t val) {
4040
return true;
4141
}
4242

43+
/// Validates a pooling/convolution kernel-size argument: it must have exactly
/// `kernel_ndim` elements, each >= 1.
///
/// @param kernel_size The user-supplied kernel size array.
/// @param kernel_ndim Expected number of spatial dims (e.g. 2 for 2-D ops).
/// @return true if valid; false (after logging) otherwise.
bool kernel_size_is_valid(IntArrayRef kernel_size, size_t kernel_ndim) {
  // Fix: size() returns an unsigned size_t, so the format specifier must be
  // %zu; the original used %zd (signed), a printf specifier mismatch.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      kernel_size.size() == kernel_ndim,
      "Expected kernel_size to have size %zu but got %zu",
      kernel_ndim,
      kernel_size.size());
  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(kernel_size, 1));
  return true;
}
52+
4353
bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim) {
4454
ET_LOG_MSG_AND_RETURN_IF_FALSE(
4555
stride.size() > 0 && stride.size() <= kernel_ndim,
@@ -267,5 +277,58 @@ void get_convolution_out_target_size(
267277
in, {kernel_size, kernel_ndim}, stride, padding, dilation, out_sizes);
268278
}
269279

280+
bool check_max_pool2d_with_indices_args(
281+
const Tensor& in,
282+
IntArrayRef kernel_size,
283+
IntArrayRef stride,
284+
IntArrayRef padding,
285+
IntArrayRef dilation,
286+
bool ceil_mode,
287+
Tensor& out,
288+
Tensor& indices) {
289+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
290+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
291+
indices.scalar_type() == ScalarType::Long,
292+
"Expected indices to have type of Long, but found %s",
293+
toString(indices.scalar_type()));
294+
295+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
296+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
297+
298+
ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
299+
if (stride.size() > 0) {
300+
ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(kernel_size, 2));
301+
}
302+
ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));
303+
if (dilation.size() > 0) {
304+
ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, 2));
305+
}
306+
307+
return true;
308+
}
309+
310+
void get_max_pool2d_with_indices_out_target_size(
311+
const Tensor& in,
312+
IntArrayRef kernel_size,
313+
IntArrayRef stride,
314+
IntArrayRef padding,
315+
IntArrayRef dilation,
316+
bool ceil_mode,
317+
exec_aten::SizesType* out_sizes,
318+
size_t* out_ndim) {
319+
*out_ndim = in.dim();
320+
321+
// Batch dim is optional, so in can be either 3 or 4 dim.
322+
if (in.dim() == 4) {
323+
out_sizes[0] = in.size(0);
324+
out_sizes[1] = in.size(1);
325+
} else {
326+
out_sizes[0] = in.size(0);
327+
}
328+
329+
calculate_kernel_output_sizes(
330+
in, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
331+
}
332+
270333
} // namespace executor
271334
} // namespace torch

0 commit comments

Comments
 (0)