Skip to content

Dtype compliance: clamp #83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 90 additions & 24 deletions kernels/portable/cpu/op_clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,43 @@ using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
using Tensor = exec_aten::Tensor;

namespace {

template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
/** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
bool is_out_of_bounds(CTYPE_VAL val) {
const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
val_cast > std::numeric_limits<CTYPE_OUT>::max();
}

void check_bounds(
const Scalar& val_scalar,
const torch::executor::native::ScalarType& val_type,
const torch::executor::native::ScalarType& out_type,
const char* val_name) {
ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp", CTYPE_VAL, [&]() {
CTYPE_VAL val = 0;
ET_EXTRACT_SCALAR(val_scalar, val);
if (isIntegralType(out_type, /*includeBool=*/false)) {
ET_SWITCH_INT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
ET_CHECK_MSG(false, "%s value out of bounds", val_name);
}
});
} else if (isFloatingType(out_type)) {
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
if (std::isfinite(val) &&
is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
ET_CHECK_MSG(false, "%s value out of bounds", val_name);
}
});
}
});
}

} // namespace

Tensor& clamp_out(
RuntimeContext& ctx,
const Tensor& in,
Expand All @@ -84,38 +121,67 @@ Tensor& clamp_out(
Error err = resize_tensor(out, in.sizes());
ET_CHECK_MSG(err == Error::Ok, "Could not resize output");

ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
ScalarType in_type = in.scalar_type();
ScalarType min_type = in_type;
ScalarType max_type = in_type;
ScalarType common_type = in_type;
ScalarType out_type = out.scalar_type();

bool has_min = min_opt.has_value();
if (has_min) {
min_type = utils::get_scalar_dtype(min_opt.value());
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
check_bounds(min_opt.value(), min_type, out_type, "minimum");
}
bool has_max = max_opt.has_value();
if (has_max) {
max_type = utils::get_scalar_dtype(max_opt.value());
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
check_bounds(max_opt.value(), max_type, out_type, "maximum");
}

ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() {
ET_CHECK_MSG(
has_min || has_max, "At least one of 'min' or 'max' must not be None");

ET_CHECK(common_type == out_type);

ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
// Extract optional min value
CTYPE min = 0;
bool has_min = min_opt.has_value();
CTYPE_OUT min = 0;
if (has_min) {
bool ok = utils::extract_scalar<CTYPE>(min_opt.value(), &min);
ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range");
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
CTYPE_MIN min_val = 0;
ET_EXTRACT_SCALAR(min_opt.value(), min_val);
min = static_cast<CTYPE_OUT>(min_val);
});
}

// Extract optional max value
CTYPE max = 0;
bool has_max = max_opt.has_value();
CTYPE_OUT max = 0;
if (has_max) {
bool ok = utils::extract_scalar<CTYPE>(max_opt.value(), &max);
ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range");
ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
CTYPE_MAX max_val = 0;
ET_EXTRACT_SCALAR(max_opt.value(), max_val);
max = static_cast<CTYPE_OUT>(max_val);
});
}

apply_unary_map_fn(
[has_min, min, has_max, max](const CTYPE val_in) {
CTYPE val_out = val_in;
if (has_min) {
val_out = max_override(val_out, min);
}
if (has_max) {
val_out = min_override(val_out, max);
}
return val_out;
},
in.const_data_ptr<CTYPE>(),
out.mutable_data_ptr<CTYPE>(),
in.numel());
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
apply_unary_map_fn(
[has_min, min, has_max, max](const CTYPE_IN val_in) {
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
if (has_min) {
val_out = max_override(val_out, min);
}
if (has_max) {
val_out = min_override(val_out, max);
}
return val_out;
},
in.const_data_ptr<CTYPE_IN>(),
out.mutable_data_ptr<CTYPE_OUT>(),
in.numel());
});
});

return out;
Expand Down
4 changes: 2 additions & 2 deletions kernels/test/op_clamp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) {

#ifndef USE_ATEN_LIB
TEST(OpClampOutTest, IntTensorTooSmallClampDies) {
// Cannot be represented by a uint32_t.
// Cannot be represented by a int32_t.
expect_bad_clamp_value_dies<ScalarType::Int>(-2147483649);
}

TEST(OpClampOutTest, IntTensorTooLargeClampDies) {
// Cannot be represented by a uint32_t.
// Cannot be represented by a int32_t.
expect_bad_clamp_value_dies<ScalarType::Int>(2147483648);
}
#endif
Expand Down