diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 8a824444814..39c4765d322 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -73,6 +73,43 @@ using Scalar = exec_aten::Scalar; using ScalarType = exec_aten::ScalarType; using Tensor = exec_aten::Tensor; +namespace { + +template +/** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */ +bool is_out_of_bounds(CTYPE_VAL val) { + const CTYPE_CAST val_cast = static_cast(val); + return val_cast < std::numeric_limits::lowest() || + val_cast > std::numeric_limits::max(); +} + +void check_bounds( + const Scalar& val_scalar, + const torch::executor::native::ScalarType& val_type, + const torch::executor::native::ScalarType& out_type, + const char* val_name) { + ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp", CTYPE_VAL, [&]() { + CTYPE_VAL val = 0; + ET_EXTRACT_SCALAR(val_scalar, val); + if (isIntegralType(out_type, /*includeBool=*/false)) { + ET_SWITCH_INT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { + if (is_out_of_bounds(val)) { + ET_CHECK_MSG(false, "%s value out of bounds", val_name); + } + }); + } else if (isFloatingType(out_type)) { + ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { + if (std::isfinite(val) && + is_out_of_bounds(val)) { + ET_CHECK_MSG(false, "%s value out of bounds", val_name); + } + }); + } + }); +} + +} // namespace + Tensor& clamp_out( RuntimeContext& ctx, const Tensor& in, @@ -84,38 +121,67 @@ Tensor& clamp_out( Error err = resize_tensor(out, in.sizes()); ET_CHECK_MSG(err == Error::Ok, "Could not resize output"); - ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out); + ScalarType in_type = in.scalar_type(); + ScalarType min_type = in_type; + ScalarType max_type = in_type; + ScalarType common_type = in_type; + ScalarType out_type = out.scalar_type(); + + bool has_min = min_opt.has_value(); + if (has_min) { + min_type = utils::get_scalar_dtype(min_opt.value()); + common_type = utils::promote_type_with_scalar(common_type, min_opt.value()); + check_bounds(min_opt.value(), min_type, out_type, "minimum"); + } + bool has_max = max_opt.has_value(); + if (has_max) { + max_type = utils::get_scalar_dtype(max_opt.value()); + common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); + check_bounds(max_opt.value(), max_type, out_type, "maximum"); + } - ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() { + ET_CHECK_MSG( + has_min || has_max, "At least one of 'min' or 'max' must not be None"); + + ET_CHECK(common_type == out_type); + + ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { // Extract optional min value - CTYPE min = 0; - bool has_min = min_opt.has_value(); + CTYPE_OUT min = 0; if (has_min) { - bool ok = utils::extract_scalar(min_opt.value(), &min); - ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range"); + ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() { + CTYPE_MIN min_val = 0; + ET_EXTRACT_SCALAR(min_opt.value(), min_val); + min = static_cast(min_val); + }); } + // Extract optional max value - CTYPE max = 0; - bool has_max = max_opt.has_value(); + CTYPE_OUT max = 0; if (has_max) { - bool ok = utils::extract_scalar(max_opt.value(), &max); - ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range"); + ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() { + CTYPE_MAX max_val = 0; + ET_EXTRACT_SCALAR(max_opt.value(), max_val); + max = static_cast(max_val); + }); } - apply_unary_map_fn( - [has_min, min, has_max, max](const CTYPE val_in) { - CTYPE val_out = val_in; - if (has_min) { - val_out = max_override(val_out, min); - } - if (has_max) { - val_out = min_override(val_out, max); - } - return val_out; - }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); + ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() { + apply_unary_map_fn( + [has_min, min, has_max, max](const CTYPE_IN val_in) { + CTYPE_OUT val_out = static_cast(val_in); + if (has_min) { + val_out = max_override(val_out, min); + } + if (has_max) { + val_out = min_override(val_out, max); + } + return val_out; + }, + in.const_data_ptr(), + out.mutable_data_ptr(), + in.numel()); + }); }); return out; diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp index b505f91bc85..08d898733e1 100644 --- a/kernels/test/op_clamp_test.cpp +++ b/kernels/test/op_clamp_test.cpp @@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) { #ifndef USE_ATEN_LIB TEST(OpClampOutTest, IntTensorTooSmallClampDies) { - // Cannot be represented by a uint32_t. + // Cannot be represented by a int32_t. expect_bad_clamp_value_dies(-2147483649); } TEST(OpClampOutTest, IntTensorTooLargeClampDies) { - // Cannot be represented by a uint32_t. + // Cannot be represented by a int32_t. expect_bad_clamp_value_dies(2147483648); } #endif