From 4dbd4606266e3e7968a9651941d5886ee94cad42 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 4 Mar 2025 14:44:09 -0800
Subject: [PATCH] Add max_pool2d_with_indices_backward (#8940)

Summary:
Adds a portable CPU kernel and tests for
aten::max_pool2d_with_indices_backward.grad_input. The out/indices
arguments of check_max_pool2d_with_indices_args are constified so the
existing forward-shape checks can be reused by the backward kernel.

Pull Request resolved: https://github.com/pytorch/executorch/pull/8940

Reviewed By: JacobSzwejbka

Differential Revision: D70577129
---
 kernels/aten/functions.yaml                   |   2 +
 .../op_max_pool2d_with_indices_backward.cpp   | 180 ++++++++++++++
 kernels/portable/cpu/util/kernel_ops_util.cpp |   4 +-
 kernels/portable/cpu/util/kernel_ops_util.h   |   4 +-
 kernels/portable/functions.yaml               |   5 +
 ..._max_pool2d_with_indices_backward_test.cpp | 229 ++++++++++++++++++
 kernels/test/targets.bzl                      |   1 +
 .../kernels/portable/op_registration_util.bzl |   6 +
 8 files changed, 427 insertions(+), 4 deletions(-)
 create mode 100644 kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp
 create mode 100644 kernels/test/op_max_pool2d_with_indices_backward_test.cpp

diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index 463ef0f9d32..072b7504050 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -249,6 +249,8 @@
 
 - op: max_pool2d_with_indices.out
 
+- op: max_pool2d_with_indices_backward.grad_input
+
 - op: max.dim_max
 
 - op: max.unary_out
diff --git a/kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp b/kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp
new file mode 100644
index 00000000000..5edce5a2c67
--- /dev/null
+++ b/kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp
@@ -0,0 +1,180 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using IntArrayRef = executorch::aten::ArrayRef<int64_t>;
+
+namespace {
+
+bool check_max_pool2d_backward_args(
+    const Tensor& grad_output,
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode,
+    const Tensor& indices,
+    const Tensor& grad_input) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      check_max_pool2d_with_indices_args(
+          input,
+          kernel_size,
+          stride,
+          padding,
+          dilation,
+          ceil_mode,
+          grad_output,
+          indices),
+      "Invalid max_pool2d arguments");
+
+  size_t output_ndim = 0;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  executorch::aten::SizesType output_sizes[kTensorDimensionLimit];
+  get_max_pool2d_with_indices_out_target_size(
+      input,
+      kernel_size,
+      stride,
+      padding,
+      dilation,
+      ceil_mode,
+      output_sizes,
+      &output_ndim);
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      output_size_is_valid({output_sizes, output_ndim}, 2));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      grad_output.dim() == input.dim(),
+      "grad_output should have same number of dimensions as input");
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));
+
+  return true;
+}
+
+template <typename CTYPE, bool is_3d>
+void max_pool_backward_impl(
+    const Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  const CTYPE* grad_output_data = grad_output.const_data_ptr<CTYPE>();
+  const int64_t* indices_data = indices.const_data_ptr<int64_t>();
+  CTYPE* grad_input_data = grad_input.mutable_data_ptr<CTYPE>();
+
+  // treat batch size and channels as one dimension
+  //
+  // MaxPool2d:
+  //   ndim == 3: CHW
+  //   ndim == 4: NCHW
+  //
+  // MaxPool3d:
+  //   ndim == 4: CDHW
+  //   ndim == 5: NCDHW
+  int64_t ndim = grad_output.dim();
+  int64_t channels;
+  if (is_3d) {
+    channels = ndim == 4 ? grad_output.size(0)
+                         : grad_output.size(0) * grad_output.size(1);
+  } else {
+    channels = ndim == 3 ? grad_output.size(0)
+                         : grad_output.size(0) * grad_output.size(1);
+  }
+  int64_t input_depth = is_3d ? grad_input.size(ndim - 3) : 1;
+
+  int64_t input_height = grad_input.size(ndim - 2);
+  int64_t input_width = grad_input.size(ndim - 1);
+  int64_t output_depth = is_3d ? grad_output.size(ndim - 3) : 1;
+  int64_t output_height = grad_output.size(ndim - 2);
+  int64_t output_width = grad_output.size(ndim - 1);
+
+  for (int64_t c = 0; c < channels; ++c) {
+    CTYPE* grad_input_ptr =
+        grad_input_data + c * input_depth * input_height * input_width;
+    const CTYPE* grad_output_ptr =
+        grad_output_data + c * output_depth * output_height * output_width;
+    const int64_t* indices_ptr =
+        indices_data + c * output_depth * output_height * output_width;
+
+    for (int64_t od = 0; od < output_depth; od++) {
+      for (int64_t oh = 0; oh < output_height; oh++) {
+        for (int64_t ow = 0; ow < output_width; ow++) {
+          // retrieve position of max
+          int64_t index =
+              od * output_height * output_width + oh * output_width + ow;
+          int64_t maxindex = indices_ptr[index];
+          if (maxindex != -1) {
+            // update gradient, accumulating since overlapping pooling
+            // windows can select the same input element
+            grad_input_ptr[maxindex] += grad_output_ptr[index];
+          }
+        }
+      }
+    }
+  }
+}
+
+} // namespace
+
+Tensor& max_pool2d_with_indices_backward_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& grad_output,
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode,
+    const Tensor& indices,
+    Tensor& grad_input) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_max_pool2d_backward_args(
+          grad_output,
+          input,
+          kernel_size,
+          stride,
+          padding,
+          dilation,
+          ceil_mode,
+          indices,
+          grad_input),
+      InvalidArgument,
+      grad_input);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(grad_input, input.sizes()) == Error::Ok,
+      InvalidArgument,
+      grad_input);
+
+  constexpr auto name = "max_pool2d_with_indices_backward.grad_input";
+
+  ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
+    max_pool_backward_impl<CTYPE, /*is_3d=*/false>(
+        grad_input, grad_output, indices);
+  });
+
+  return grad_input;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/util/kernel_ops_util.cpp b/kernels/portable/cpu/util/kernel_ops_util.cpp
index c6a38fbb2f0..00b088a5cec 100644
--- a/kernels/portable/cpu/util/kernel_ops_util.cpp
+++ b/kernels/portable/cpu/util/kernel_ops_util.cpp
@@ -470,8 +470,8 @@ bool check_max_pool2d_with_indices_args(
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode,
-    Tensor& out,
-    Tensor& indices) {
+    const Tensor& out,
+    const Tensor& indices) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_CHECK_OR_RETURN_FALSE(
       indices.scalar_type() == ScalarType::Long,
diff --git a/kernels/portable/cpu/util/kernel_ops_util.h b/kernels/portable/cpu/util/kernel_ops_util.h
index 5951b7d0492..8028f254eb4 100644
--- a/kernels/portable/cpu/util/kernel_ops_util.h
+++ b/kernels/portable/cpu/util/kernel_ops_util.h
@@ -442,8 +442,8 @@ bool check_max_pool2d_with_indices_args(
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode,
-    Tensor& out,
-    Tensor& indices);
+    const Tensor& out,
+    const Tensor& indices);
 
 void get_max_pool2d_with_indices_out_target_size(
     const Tensor& in,
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 3221b8fe349..21cf2f198bb 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -572,6 +572,11 @@
     - arg_meta: null
       kernel_name: torch::executor::max_pool2d_with_indices_out
 
+- op: max_pool2d_with_indices_backward.grad_input
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::max_pool2d_with_indices_backward_out
+
 - op: mean.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/op_max_pool2d_with_indices_backward_test.cpp b/kernels/test/op_max_pool2d_with_indices_backward_test.cpp
new file mode 100644
index 00000000000..c647ad05c5f
--- /dev/null
+++ b/kernels/test/op_max_pool2d_with_indices_backward_test.cpp
@@ -0,0 +1,229 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+
+class OpMaxPool2DWithIndicesBackwardOutTest : public OperatorTest {
+ protected:
+  executorch::aten::Tensor& op_max_pool2d_with_indices_backward_out(
+      const executorch::aten::Tensor& grad_output,
+      const executorch::aten::Tensor& input,
+      executorch::aten::ArrayRef<int64_t> kernel_size,
+      executorch::aten::ArrayRef<int64_t> stride,
+      executorch::aten::ArrayRef<int64_t> padding,
+      executorch::aten::ArrayRef<int64_t> dilation,
+      bool ceil_mode,
+      const executorch::aten::Tensor& indices,
+      executorch::aten::Tensor& grad_input) {
+    return torch::executor::aten::max_pool2d_with_indices_backward_outf(
+        context_,
+        grad_output,
+        input,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        ceil_mode,
+        indices,
+        grad_input);
+  }
+
+  template <executorch::aten::ScalarType DTYPE>
+  void test_4d_dtype() {
+    torch::executor::testing::TensorFactory<DTYPE> tf;
+    torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Long>
+        tfLong;
+
+    executorch::aten::Tensor grad_output = tf.make(
+        {2, 3, 4, 4},
+        {69, 97, 97, 99, 69, 97, 97, 99, 12, 79, 85, 85, 77, 77, 85, 85,
+         87, 73, 73, 68, 87, 94, 94, 68, -30, 94, 94, 8, 71, 74, 77, 77,
+         4, -8, -12, -46, 87, 90, 90, -45, 87, 90, 90, 17, 63, 28, 88, 88,
+         83, 83, 61, 61, 83, 83, 47, 49, 16, 47, 47, 74, 90, 90, 73, 74,
+         41, 81, 81, 29, 84, 81, 81, 17, 84, 45, 99, 99, 16, 45, 99, 99,
+         54, 54, 5, 29, 54, 68, 68, 29, 90, 90, 68, 90, 99, 99, 65, 90});
+
+    executorch::aten::Tensor input = tf.make(
+        {2, 3, 5, 5},
+        {28, -38, -7, -13, 70, 53, 69, 97, 25, 99, -72, -87, 79, 42,
+         -24, -15, 12, -86, 85, 0, 67, 77, 53, -61, 50, 3, 42, -37,
+         51, -60, 87, 32, 73, 68, -84, -98, -30, 94, 1, -86, -56, -68,
+         74, -51, 8, 71, -53, 4, 77, -89, 4, -46, -46, -92, -85, -23,
+         -8, -12, -46, -88, 66, 87, 90, -45, -78, 63, 28, 28, -30, 17,
+         -16, 5, 11, 88, -47, 72, 32, -7, 61, -63, -22, 83, -40, -78,
+         49, -39, -89, 47, -61, 7, 16, -96, -22, 8, 74, 12, 90, 73,
+         -71, -10, 41, 1, 10, -34, 29, -27, 26, 81, -8, 17, 84, -23,
+         -53, -26, -67, -90, 16, 45, 99, 56, -87, -65, -79, 31, 79, 6,
+         44, -55, -5, -68, -38, 54, -3, 5, 29, -39, 26, 68, -24, -53,
+         51, 90, 65, 43, 90, -41, 99, 6, -31, -94});
+
+    ::std::vector<int64_t> kernel_size_vec = {2, 2};
+    executorch::aten::ArrayRef<int64_t> kernel_size =
+        executorch::aten::ArrayRef<int64_t>(
+            kernel_size_vec.data(), kernel_size_vec.size());
+    ::std::vector<int64_t> stride_vec = {1, 1};
+    executorch::aten::ArrayRef<int64_t> stride =
+        executorch::aten::ArrayRef<int64_t>(
+            stride_vec.data(), stride_vec.size());
+    ::std::vector<int64_t> padding_vec = {0, 0};
+    executorch::aten::ArrayRef<int64_t> padding =
+        executorch::aten::ArrayRef<int64_t>(
+            padding_vec.data(), padding_vec.size());
+    ::std::vector<int64_t> dilation_vec = {1, 1};
+    executorch::aten::ArrayRef<int64_t> dilation =
+        executorch::aten::ArrayRef<int64_t>(
+            dilation_vec.data(), dilation_vec.size());
+    bool ceil_mode = false;
+    executorch::aten::Tensor indices = tfLong.make(
+        {2, 3, 4, 4},
+        {6, 7, 7, 9, 6, 7, 7, 9, 16, 12, 18, 18, 21, 21, 18, 18,
+         5, 7, 7, 8, 5, 12, 12, 8, 11, 12, 12, 19, 20, 17, 23, 23,
+         0, 6, 7, 8, 11, 12, 12, 13, 11, 12, 12, 19, 15, 16, 23, 23,
+         6, 6, 3, 3, 6, 6, 12, 9, 15, 12, 12, 19, 21, 21, 22, 19,
+         0, 7, 7, 4, 10, 7, 7, 9, 10, 17, 18, 18, 16, 17, 18, 18,
+         6, 6, 8, 9, 6, 12, 12, 9, 16, 16, 12, 19, 21, 21, 17, 19});
+    executorch::aten::Tensor grad_input = tf.zeros({2, 3, 5, 5});
+    executorch::aten::Tensor grad_input_expected = tf.make(
+        {2, 3, 5, 5},
+        {0, 0, 0, 0, 0, 0, 138, 388, 0, 198, 0, 0, 79, 0, 0,
+         0, 12, 0, 340, 0, 0, 154, 0, 0, 0, 0, 0, 0, 0, 0,
+         174, 0, 146, 136, 0, 0, -30, 376, 0, 0, 0, 0, 74, 0, 8,
+         71, 0, 0, 154, 0, 4, 0, 0, 0, 0, 0, -8, -12, -46, 0,
+         0, 174, 360, -45, 0, 63, 28, 0, 0, 17, 0, 0, 0, 176, 0,
+         0, 0, 0, 122, 0, 0, 332, 0, 0, 49, 0, 0, 141, 0, 0,
+         16, 0, 0, 0, 148, 0, 180, 73, 0, 0, 41, 0, 0, 0, 29,
+         0, 0, 324, 0, 17, 168, 0, 0, 0, 0, 0, 16, 90, 396, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 162, 0, 5, 58,
+         0, 0, 204, 0, 0, 0, 180, 65, 0, 180, 0, 198, 0, 0, 0});
+    op_max_pool2d_with_indices_backward_out(
+        grad_output,
+        input,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        ceil_mode,
+        indices,
+        grad_input);
+    EXPECT_TENSOR_CLOSE(grad_input, grad_input_expected);
+  }
+
+  template <executorch::aten::ScalarType DTYPE>
+  void test_3d_dtype() {
+    torch::executor::testing::TensorFactory<DTYPE> tf;
+    torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Long>
+        tfLong;
+
+    executorch::aten::Tensor grad_output =
+        tf.make({2, 5, 5}, {89, 89, 89, 20, 20, 89, 89, 86, 49, 80, 89, 89, 99,
+                            99, 99, 84, 84, 86, 86, 86, 51, 86, 86, 86, 62, 42,
+                            67, 85, 85, 85, 75, 75, 42, 42, 74, 75, 98, 98, 98,
+                            61, 95, 98, 98, 98, 93, 88, 88, 13, 13, 67});
+
+    executorch::aten::Tensor input = tf.make(
+        {2, 12, 12},
+        {73, 15, 30, 89, -55, -62, 25, -50, -47, 12, -73, -89, 53, -63,
+         -44, 86, 53, -84, -6, 20, -24, -43, -11, -34, -7, -13, 74, 33,
+         -44, 49, -59, -88, -46, -33, 48, 80, 38, -58, 0, -48, -46, -87,
+         -66, 14, -68, -77, -50, -15, 86, 89, -37, 7, -16, -6, 55, 40,
+         -83, -77, -55, 32, -17, -83, 43, 17, 2, -51, 20, -77, -68, -72,
+         -47, -78, -49, -52, -7, -25, -77, -8, -3, 99, 71, 19, 21, -47,
+         44, -90, -75, -87, 79, -42, -90, 22, 2, 73, -65, -50, -71, 19,
+         -60, -91, -43, -60, 16, 86, -93, -78, 82, 14, 20, 19, 33, 84,
+         60, 41, 2, -4, -52, 74, -40, -60, 88, 51, -59, 49, -81, -93,
+         43, -99, 40, -84, 76, 27, 59, -19, -55, -50, 81, 86, -19, 51,
+         70, -90, 74, 62, 0, -31, -71, 42, 42, 67, 26, 85, -11, -34,
+         -97, 5, -45, -50, 74, -62, -81, -84, 70, 33, -27, -54, 94, 74,
+         -30, 16, 39, 0, 0, -80, 85, 42, 13, -82, -30, -95, 34, -60,
+         -51, -10, -30, -65, -96, -95, 60, -33, 67, -88, -26, 75, 29, -27,
+         -28, 21, -2, -29, 11, -68, -36, -85, -4, 9, -31, -63, 98, -1,
+         17, 61, -50, 41, -18, -92, -50, -40, 14, 18, 22, 10, 58, -86,
+         -9, 5, -69, -50, -26, 26, 57, -94, -53, 98, 37, 35, -20, -9,
+         -13, -41, 41, 95, 82, -71, -43, -37, -91, -14, -55, 52, -30, 93,
+         -26, 83, 2, -63, 52, 31, 57, 42, -2, -45, 99, -18, 38, 88,
+         36, -36, -35, 13, -31, -50, 10, -38, 1, 67, 3, -87, 42, -31,
+         -77, -7, -94, -99, 24, -21, -98, 15});
+    ::std::vector<int64_t> kernel_size_vec = {4, 3};
+    executorch::aten::ArrayRef<int64_t> kernel_size =
+        executorch::aten::ArrayRef<int64_t>(
+            kernel_size_vec.data(), kernel_size_vec.size());
+    ::std::vector<int64_t> stride_vec = {3, 2};
+    executorch::aten::ArrayRef<int64_t> stride =
+        executorch::aten::ArrayRef<int64_t>(
+            stride_vec.data(), stride_vec.size());
+    ::std::vector<int64_t> padding_vec = {2, 1};
+    executorch::aten::ArrayRef<int64_t> padding =
+        executorch::aten::ArrayRef<int64_t>(
+            padding_vec.data(), padding_vec.size());
+    ::std::vector<int64_t> dilation_vec = {1, 2};
+    executorch::aten::ArrayRef<int64_t> dilation =
+        executorch::aten::ArrayRef<int64_t>(
+            dilation_vec.data(), dilation_vec.size());
+    bool ceil_mode = false;
+    executorch::aten::Tensor indices = tfLong.make(
+        {2, 5, 5},
+        {3, 3, 3, 19, 19, 49, 49, 15, 29, 35, 49, 49, 79,
+         79, 79, 111, 111, 103, 103, 103, 121, 137, 137, 137, 143, 3,
+         5, 7, 7, 7, 49, 49, 31, 31, 23, 49, 89, 89, 89,
+         67, 97, 89, 89, 89, 107, 121, 121, 125, 125, 131});
+    executorch::aten::Tensor grad_input = tf.zeros({2, 12, 12});
+    executorch::aten::Tensor grad_input_expected = tf.make(
+        {2, 12, 12},
+        {0, 0, 0, 267, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 86, 0, 0,
+         0, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 0, 0, 0, 0, 0, 80,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 356, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 297, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 258, 0, 0, 0, 0,
+         0, 0, 0, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 258, 0, 0, 0, 0, 0, 62,
+         0, 0, 0, 42, 0, 67, 0, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 84, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 225, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 61, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 588,
+         0, 0, 0, 0, 0, 0, 0, 95, 0, 0, 0, 0, 0, 0, 0, 0, 0, 93,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 176, 0, 0, 0, 26,
+         0, 0, 0, 0, 0, 67, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+    op_max_pool2d_with_indices_backward_out(
+        grad_output,
+        input,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        ceil_mode,
+        indices,
+        grad_input);
+    EXPECT_TENSOR_CLOSE(grad_input, grad_input_expected);
+  }
+};
+
+TEST_F(OpMaxPool2DWithIndicesBackwardOutTest, SanityTest4D) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_4d_dtype<executorch::aten::ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpMaxPool2DWithIndicesBackwardOutTest, SanityTest3D) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_3d_dtype<executorch::aten::ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 91f2121bebc..14fc7323858 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -261,6 +261,7 @@ def define_common_targets():
     _common_op_test("op_masked_select_test", ["aten", "portable"])
     _common_op_test("op_max_test", ["aten", "portable"])
     _common_op_test("op_max_pool2d_with_indices_test", ["aten", "portable"])
+    _common_op_test("op_max_pool2d_with_indices_backward_test", ["aten", "portable"])
     _common_op_test("op_maximum_test", ["aten", "portable"])
     _common_op_test("op_mean_test", ["aten", "portable"])
     _common_op_test("op_min_test", ["aten", "portable"])
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index f5ddae06b6a..3907b39279a 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -817,6 +817,12 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
         ],
     ),
+    op_target(
+        name = "op_max_pool2d_with_indices_backward",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
+        ],
+    ),
     op_target(
         name = "op_mean",
         deps = [
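
Reviewer note (illustration only, not part of the applied patch): the backward
pass of max pooling needs no window arithmetic, because the forward pass
already recorded, for each output element, the flat offset of the chosen
maximum within its channel's plane. The kernel therefore just scatter-adds
each output gradient into grad_input at that offset, accumulating with +=
since overlapping windows can select the same input element, and skipping the
-1 sentinel the inner loop guards against. Because it accumulates rather than
assigns, grad_input must start zeroed; the new tests ensure this by allocating
it with tf.zeros. The standalone sketch below mirrors the inner loop of
max_pool_backward_impl on plain arrays; the single-channel 3x3 input with a
2x2/stride-1 pool and all of its values are hypothetical, not taken from the
patch.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // A 2x2 window with stride 1 over a 3x3 plane yields a 2x2 output.
  // indices[o] is the flat offset of the max chosen for output element o.
  const int64_t H = 3, W = 3, OH = 2, OW = 2;
  std::vector<float> grad_output = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<int64_t> indices = {4, 4, 4, 8};  // (1,1) thrice, then (2,2)
  std::vector<float> grad_input(H * W, 0.0f);   // must start at zero

  // Scatter-add each output gradient to its recorded max position,
  // skipping -1 ("no max"), exactly as the kernel's inner loop does.
  for (int64_t o = 0; o < OH * OW; ++o) {
    if (indices[o] != -1) {
      grad_input[indices[o]] += grad_output[o];
    }
  }

  // Overlapping windows accumulate: position (1,1) receives 1 + 2 + 3 = 6.
  for (int64_t h = 0; h < H; ++h) {
    for (int64_t w = 0; w < W; ++w) {
      std::printf("%5.1f ", grad_input[h * W + w]);
    }
    std::printf("\n");
  }
  return 0;
}

Since targets.bzl registers the new test for both the "aten" and "portable"
flavors, the hard-coded expected tensors above must agree with what ATen's
max_pool2d_with_indices_backward produces for the same inputs, giving direct
parity coverage between the two kernels.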