diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 4665c3d665b..d0b7c882f8e 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, @@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 704d8d06c5c..5cd17223d80 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -49,6 +49,32 @@ void test_dtype() { EXPECT_TENSOR_EQ(out, expected); } +template +void test_input_dtype() { + TensorFactory tf_input; + + Tensor input = tf_input.full({3, 5}, 4); + double scale = 0.5; + int64_t zero_point = 108; + int64_t quant_min = 0; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // 4 / 0.5 + 108 = 116 + Tensor expected = tfo.full({3, 5}, 116); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, AllInputDtypesSupported) { + test_input_dtype(); + test_input_dtype(); + test_input_dtype(); +} + TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); test_dtype(); @@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); } +TEST(OpQuantizeOutTest, DoubleInputTest) { + TensorFactory tf_double; + + // Test with a more complex value that might have precision differences + Tensor input = tf_double.full({2, 3}, 3.14159265359); + double scale = 0.01; + int64_t zero_point = -100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // 3.14159265359 / 0.01 - 100 = 214.159265359 + Tensor expected = tfo.full({2, 3}, 214); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, HalfInputTest) { + TensorFactory tf_half; + + Tensor input = tf_half.full({2, 3}, 2.5); + double scale = 0.5; + int64_t zero_point = 10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // 2.5 / 0.5 + 10 = 15 + Tensor expected = tfo.full({2, 3}, 15); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpQuantizeOutTest, TensorArgOverload) { TensorFactory tf_float; TensorFactory tf_double;