From 1eb4e2b072b1bed535e8ac3da83d2c6525e84a14 Mon Sep 17 00:00:00 2001
From: Salil Desai
Date: Tue, 22 Aug 2023 10:07:31 -0700
Subject: [PATCH] Dtype compliance: clamp (#83)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/83

Reland of D47573238 (Adjusting clamp portable op to have near-complete Dtype compliance with aten version), but with the supernova build fixed.

In the original diff, we were doing something like:
```
if isFloatingType(out_type)
  assert (double(min_val) >= out_type_min_value)
```
One problem with this is that if out_type is `long`, then this comparison causes the long min value to be converted to double, which is unsafe. Even though it's impossible for this to happen when running the code due to the if statement, the compiler isn't smart enough to know that, so it was giving errors.

The fix is to wrap these checks in ET_SWITCH_INT_TYPES or ET_SWITCH_FLOAT_TYPES macros within the if statements.
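As a standalone illustration (not part of this patch; plain standard C++, independent of the ExecuTorch macros), the sketch below shows why that implicit conversion is a problem: converting a 64-bit integer bound to double can change its value, so the range check silently compares against the wrong number.
```
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // 2^63 - 1 has no exact double representation; it rounds up to exactly 2^63,
  // so distinct integers become indistinguishable after the conversion.
  const int64_t int64_max = std::numeric_limits<int64_t>::max();
  const double as_double = static_cast<double>(int64_max);
  std::cout << std::boolalpha
            << (as_double == static_cast<double>(int64_max - 1)) << "\n"; // true

  // A mixed double/integer comparison performs the same conversion implicitly,
  // so the bound actually checked is off by one; a compiler with strict
  // conversion warnings promoted to errors (e.g. clang's
  // -Wimplicit-const-int-float-conversion) can reject such a line outright.
  const double min_val = 1.0;
  std::cout << (min_val >= std::numeric_limits<int64_t>::max()) << "\n"; // false
  return 0;
}
```
With the checks wrapped in the dtype-switch macros, the integral branch compares in `long` and only the floating-point branch compares in `double`, so the lossy long-to-double comparison is never even instantiated for integral out types.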
Differential Revision: D48491102

fbshipit-source-id: f24318fcd332343c6529551c981f0a94d24fd9af
---
 kernels/portable/cpu/op_clamp.cpp | 114 +++++++++++++++++++++++-------
 kernels/test/op_clamp_test.cpp    |   4 +-
 2 files changed, 92 insertions(+), 26 deletions(-)

diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index 8a824444814..39c4765d322 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -73,6 +73,43 @@ using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
 using Tensor = exec_aten::Tensor;
 
+namespace {
+
+template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
+/** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
+bool is_out_of_bounds(CTYPE_VAL val) {
+  const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
+  return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
+      val_cast > std::numeric_limits<CTYPE_OUT>::max();
+}
+
+void check_bounds(
+    const Scalar& val_scalar,
+    const torch::executor::native::ScalarType& val_type,
+    const torch::executor::native::ScalarType& out_type,
+    const char* val_name) {
+  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp", CTYPE_VAL, [&]() {
+    CTYPE_VAL val = 0;
+    ET_EXTRACT_SCALAR(val_scalar, val);
+    if (isIntegralType(out_type, /*includeBool=*/false)) {
+      ET_SWITCH_INT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
+        if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
+          ET_CHECK_MSG(false, "%s value out of bounds", val_name);
+        }
+      });
+    } else if (isFloatingType(out_type)) {
+      ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
+        if (std::isfinite(val) &&
+            is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
+          ET_CHECK_MSG(false, "%s value out of bounds", val_name);
+        }
+      });
+    }
+  });
+}
+
+} // namespace
+
 Tensor& clamp_out(
     RuntimeContext& ctx,
     const Tensor& in,
@@ -84,38 +121,67 @@ Tensor& clamp_out(
 
   Error err = resize_tensor(out, in.sizes());
   ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
 
-  ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
+  ScalarType in_type = in.scalar_type();
+  ScalarType min_type = in_type;
+  ScalarType max_type = in_type;
+  ScalarType common_type = in_type;
+  ScalarType out_type = out.scalar_type();
+
+  bool has_min = min_opt.has_value();
+  if (has_min) {
+    min_type = utils::get_scalar_dtype(min_opt.value());
+    common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
+    check_bounds(min_opt.value(), min_type, out_type, "minimum");
+  }
+  bool has_max = max_opt.has_value();
+  if (has_max) {
+    max_type = utils::get_scalar_dtype(max_opt.value());
+    common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
+    check_bounds(max_opt.value(), max_type, out_type, "maximum");
+  }
 
-  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() {
+  ET_CHECK_MSG(
+      has_min || has_max, "At least one of 'min' or 'max' must not be None");
+
+  ET_CHECK(common_type == out_type);
+
+  ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
     // Extract optional min value
-    CTYPE min = 0;
-    bool has_min = min_opt.has_value();
+    CTYPE_OUT min = 0;
     if (has_min) {
-      bool ok = utils::extract_scalar(min_opt.value(), &min);
-      ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range");
+      ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
+        CTYPE_MIN min_val = 0;
+        ET_EXTRACT_SCALAR(min_opt.value(), min_val);
+        min = static_cast<CTYPE_OUT>(min_val);
+      });
     }
+
     // Extract optional max value
-    CTYPE max = 0;
-    bool has_max = max_opt.has_value();
+    CTYPE_OUT max = 0;
     if (has_max) {
-      bool ok = utils::extract_scalar(max_opt.value(), &max);
-      ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range");
+      ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
+        CTYPE_MAX max_val = 0;
+        ET_EXTRACT_SCALAR(max_opt.value(), max_val);
+        max = static_cast<CTYPE_OUT>(max_val);
+      });
     }
-    apply_unary_map_fn(
-        [has_min, min, has_max, max](const CTYPE val_in) {
-          CTYPE val_out = val_in;
-          if (has_min) {
-            val_out = max_override(val_out, min);
-          }
-          if (has_max) {
-            val_out = min_override(val_out, max);
-          }
-          return val_out;
-        },
-        in.const_data_ptr<CTYPE>(),
-        out.mutable_data_ptr<CTYPE>(),
-        in.numel());
+    ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
+      apply_unary_map_fn(
+          [has_min, min, has_max, max](const CTYPE_IN val_in) {
+            CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
+            if (has_min) {
+              val_out = max_override(val_out, min);
+            }
+            if (has_max) {
+              val_out = min_override(val_out, max);
+            }
+            return val_out;
+          },
+          in.const_data_ptr<CTYPE_IN>(),
+          out.mutable_data_ptr<CTYPE_OUT>(),
+          in.numel());
+    });
   });
 
   return out;
diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp
index b505f91bc85..08d898733e1 100644
--- a/kernels/test/op_clamp_test.cpp
+++ b/kernels/test/op_clamp_test.cpp
@@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) {
 
 #ifndef USE_ATEN_LIB
 TEST(OpClampOutTest, IntTensorTooSmallClampDies) {
-  // Cannot be represented by a uint32_t.
+  // Cannot be represented by an int32_t.
   expect_bad_clamp_value_dies<ScalarType::Int>(-2147483649);
 }
 
 TEST(OpClampOutTest, IntTensorTooLargeClampDies) {
-  // Cannot be represented by a uint32_t.
+  // Cannot be represented by an int32_t.
   expect_bad_clamp_value_dies<ScalarType::Int>(2147483648);
 }
 #endif