From 77504acb23d82d539152a7a9357dab9891dfa8ae Mon Sep 17 00:00:00 2001
From: morelos
Date: Fri, 13 Jun 2025 15:49:26 -0700
Subject: [PATCH] [ET-VK][Ops] quantize ops skeleton test framework

Pull Request resolved: https://github.com/pytorch/executorch/pull/11366

# Context

In this diff we create the skeleton test framework for quantization. This is necessary because we need a way to test our Vulkan implementation of the quantization operators against an existing CPU implementation as a reference. This test framework is heavily inspired by [sdpa_test.cpp](https://github.com/pytorch/executorch/blob/main/backends/vulkan/test/op_tests/sdpa_test.cpp). We make use of the [op_quantize.cpp](https://github.com/pytorch/executorch/blob/main/kernels/quantized/cpu/op_quantize.cpp) CPU implementation of the `quantize_per_tensor` and `quantize_per_token` operators. An explanation of the operators is included in a future diff along this stack, where the actual Vulkan implementation is created.
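For orientation, both operators apply the same affine mapping to each element; `quantize_per_token` simply uses a separate scale and zero point per token. The sketch below shows that reference math only; it is written for this description (the helper name `quantize_element` is invented) and is not code from op_quantize.cpp:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference affine quantization of a single element:
//   q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
int64_t quantize_element(
    float x,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  const int64_t q =
      static_cast<int64_t>(std::nearbyint(x / scale)) + zero_point;
  return std::min(std::max(q, quant_min), quant_max);
}
```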
# Changes

The main change in this diff is the creation of a new test file, `quantize_test.cpp`, which is also registered in targets.bzl so that the test can be invoked properly. As this is inspired by sdpa_test.cpp, we follow a similar format. First we have forward declarations of the functions we wish to test against (`quantize_per_tensor_out` and `quantize_per_token_out`). Next come wrappers that call those functions without a runtime context (the context is not needed since this is merely for testing), and finally ATen-facing wrappers for the same operators built with the `WRAP_TO_ATEN` macro. We also have a utility function, `check_quantize_args`, that sanity-checks the quantization arguments and will be used once the Vulkan implementation is exercised. A hypothetical example of how a future test case could use these wrappers is sketched below.
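The sketch below is illustrative only; the test name, input values, and expected outputs are invented for this description and are not part of this diff. The expected values follow q = clamp(round(x / scale) + zero_point, quant_min, quant_max):

```cpp
// Hypothetical test case (not in this diff): compare the ATen wrapper
// against hand-computed affine quantization results.
TEST(VulkanQuantizeTest, per_tensor_matches_reference) {
  at::Tensor input = at::tensor(std::vector<float>{-1.0f, 0.0f, 0.5f, 1.0f});
  double scale = 0.5;
  int64_t zero_point = 10;

  at::Tensor out = torch::executor::native::quantize_per_tensor_aten(
      input, scale, zero_point, /*quant_min=*/0, /*quant_max=*/255, at::kByte);

  // round(x / 0.5) + 10, clamped to [0, 255] -> {8, 10, 11, 12}
  at::Tensor expected = at::tensor(
      std::vector<int64_t>{8, 10, 11, 12},
      at::TensorOptions().dtype(at::kByte));
  EXPECT_TRUE(at::equal(out, expected));
}
```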
ghstack-source-id: 290376488
@exported-using-ghexport

Differential Revision: [D75959066](https://our.internmc.facebook.com/intern/diff/D75959066/)
---
 .../vulkan/test/op_tests/quantize_test.cpp | 158 ++++++++++++++++++
 backends/vulkan/test/op_tests/targets.bzl  |   9 +
 2 files changed, 167 insertions(+)
 create mode 100644 backends/vulkan/test/op_tests/quantize_test.cpp

diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp
new file mode 100644
index 0000000000..6ff61dac19
--- /dev/null
+++ b/backends/vulkan/test/op_tests/quantize_test.cpp
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+#include "test_utils.h"
+
+#include <cassert>
+#include <limits>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+// Forward declarations of the functions we're testing
+Tensor& quantize_per_tensor_out(
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out);
+
+Tensor& quantize_per_token_out(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out);
+
+// Wrapper function for quantize_per_tensor_out without context
+Tensor& quantize_per_tensor_out_no_context(
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  return torch::executor::native::quantize_per_tensor_out(
+      input, scale, zero_point, quant_min, quant_max, dtype, out);
+}
+
+// Wrapper function for quantize_per_token_out without context
+Tensor& quantize_per_token_out_no_context(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  return torch::executor::native::quantize_per_token_out(
+      input, scale, zero_point, quant_min, quant_max, dtype, out);
+}
+
+// ATen wrapper for quantize_per_tensor
+at::Tensor quantize_per_tensor_aten(
+    const at::Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  auto out = at::empty_like(input, dtype);
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+
+  WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6)
+  (input, scale, zero_point, quant_min, quant_max, et_dtype, out);
+  return out;
+}
+
+// ATen wrapper for quantize_per_token
+at::Tensor quantize_per_token_aten(
+    const at::Tensor& input,
+    const at::Tensor& scale,
+    const at::Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  auto out = at::empty_like(input, dtype);
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+
+  WRAP_TO_ATEN(quantize_per_token_out_no_context, 6)
+  (input, scale, zero_point, quant_min, quant_max, et_dtype, out);
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+void check_quantize_args(
+    int64_t quant_min,
+    int64_t quant_max,
+    c10::ScalarType out_dtype) {
+  using namespace vkcompute;
+  int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0;
+  switch (out_dtype) {
+    case c10::kByte:
+      quant_min_lower_bound =
+          static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
+      quant_max_upper_bound =
+          static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
+      break;
+    case c10::kChar:
+      quant_min_lower_bound =
+          static_cast<int32_t>(std::numeric_limits<int8_t>::min());
+      quant_max_upper_bound =
+          static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+      break;
+    case c10::kBits16:
+    case c10::kUInt16:
+      quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
+      quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
+      break;
+    case c10::kShort:
+      quant_min_lower_bound = std::numeric_limits<int16_t>::min();
+      quant_max_upper_bound = std::numeric_limits<int16_t>::max();
+      break;
+    case c10::kInt:
+      quant_min_lower_bound = std::numeric_limits<int32_t>::min();
+      quant_max_upper_bound = std::numeric_limits<int32_t>::max();
+      break;
+    default:
+      VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype));
+  }
+  VK_CHECK_COND(
+      quant_min >= quant_min_lower_bound,
+      "quant_min out of bound for dtype, expected quant_min_lower_bound: ",
+      quant_min_lower_bound,
+      " actual quant_min: ",
+      quant_min);
+
+  VK_CHECK_COND(
+      quant_max <= quant_max_upper_bound,
+      "quant_max out of bound for dtype, expected quant_max_upper_bound: ",
+      quant_max_upper_bound,
+      " actual quant_max: ",
+      quant_max);
+}
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index 6fcf2d8353..cb5b49c390 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -177,6 +177,15 @@ def define_common_targets(is_fbcode = False):
             "//executorch/extension/tensor:tensor",
         ]
     )
+    define_test_targets(
+        "quantize_test",
+        extra_deps = [
+            ":test_utils",
+            "//executorch/kernels/quantized/cpu:op_quantize",
+            "//executorch/extension/tensor:tensor",
+            "//executorch/extension/aten_util:aten_bridge",
+        ]
+    )
     define_test_targets(
         "linear_weight_int4_test",
         extra_deps = [