Skip to content

Commit 5dcfc0f

Browse files
salilsdesaifacebook-github-bot
authored andcommitted
Dtype compliance: clamp (#83)
Summary: Pull Request resolved: #83 Reland of D47573238 (Adjusting clamp portable op have near-complete Dtype compliance with aten version), but with the supernova build fixed. In the original diff, we were doing something like: ``` if isFloatingType(out_type) assert (double(min_val) >= out_type_min_value) ``` One problem with this is that if out_type is ```long```, then this comparison causes the long min value to be converted to double, which is unsafe. Even though it's impossible for this to happen when running the code due to the if statement, the compiler isn't smart enough to know that, so it was giving errors. The fix is to wrap these checks within ET_SWITCH_INT_TYPES or ET_SWITCH_FLOAT_TYPES macros within the if statements Differential Revision: D48491102 fbshipit-source-id: 3b0d0370a86e656a33b4e6acdab7ca9d3d622563
1 parent 50f8971 commit 5dcfc0f

File tree

2 files changed

+92
-26
lines changed

2 files changed

+92
-26
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,43 @@ using Scalar = exec_aten::Scalar;
7373
using ScalarType = exec_aten::ScalarType;
7474
using Tensor = exec_aten::Tensor;
7575

76+
namespace {
77+
78+
template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
79+
/** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
80+
bool is_out_of_bounds(CTYPE_VAL val) {
81+
const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
82+
return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
83+
val_cast > std::numeric_limits<CTYPE_OUT>::max();
84+
}
85+
86+
void check_bounds(
87+
const Scalar& val_scalar,
88+
const torch::executor::native::ScalarType& val_type,
89+
const torch::executor::native::ScalarType& out_type,
90+
const char* val_name) {
91+
ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp", CTYPE_VAL, [&]() {
92+
CTYPE_VAL val = 0;
93+
ET_EXTRACT_SCALAR(val_scalar, val);
94+
if (isIntegralType(out_type, /*includeBool=*/false)) {
95+
ET_SWITCH_INT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
96+
if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
97+
ET_CHECK_MSG(false, "%s value out of bounds", val_name);
98+
}
99+
});
100+
} else if (isFloatingType(out_type)) {
101+
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
102+
if (std::isfinite(val) &&
103+
is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
104+
ET_CHECK_MSG(false, "%s value out of bounds", val_name);
105+
}
106+
});
107+
}
108+
});
109+
}
110+
111+
} // namespace
112+
76113
Tensor& clamp_out(
77114
RuntimeContext& ctx,
78115
const Tensor& in,
@@ -84,38 +121,67 @@ Tensor& clamp_out(
84121
Error err = resize_tensor(out, in.sizes());
85122
ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
86123

87-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
124+
ScalarType in_type = in.scalar_type();
125+
ScalarType min_type = in_type;
126+
ScalarType max_type = in_type;
127+
ScalarType common_type = in_type;
128+
ScalarType out_type = out.scalar_type();
129+
130+
bool has_min = min_opt.has_value();
131+
if (has_min) {
132+
min_type = utils::get_scalar_dtype(min_opt.value());
133+
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
134+
check_bounds(min_opt.value(), min_type, out_type, "minimum");
135+
}
136+
bool has_max = max_opt.has_value();
137+
if (has_max) {
138+
max_type = utils::get_scalar_dtype(max_opt.value());
139+
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
140+
check_bounds(max_opt.value(), max_type, out_type, "maximum");
141+
}
88142

89-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() {
143+
ET_CHECK_MSG(
144+
has_min || has_max, "At least one of 'min' or 'max' must not be None");
145+
146+
ET_CHECK(common_type == out_type);
147+
148+
ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
90149
// Extract optional min value
91-
CTYPE min = 0;
92-
bool has_min = min_opt.has_value();
150+
CTYPE_OUT min = 0;
93151
if (has_min) {
94-
bool ok = utils::extract_scalar<CTYPE>(min_opt.value(), &min);
95-
ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range");
152+
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
153+
CTYPE_MIN min_val = 0;
154+
ET_EXTRACT_SCALAR(min_opt.value(), min_val);
155+
min = static_cast<CTYPE_OUT>(min_val);
156+
});
96157
}
158+
97159
// Extract optional max value
98-
CTYPE max = 0;
99-
bool has_max = max_opt.has_value();
160+
CTYPE_OUT max = 0;
100161
if (has_max) {
101-
bool ok = utils::extract_scalar<CTYPE>(max_opt.value(), &max);
102-
ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range");
162+
ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
163+
CTYPE_MAX max_val = 0;
164+
ET_EXTRACT_SCALAR(max_opt.value(), max_val);
165+
max = static_cast<CTYPE_OUT>(max_val);
166+
});
103167
}
104168

105-
apply_unary_map_fn(
106-
[has_min, min, has_max, max](const CTYPE val_in) {
107-
CTYPE val_out = val_in;
108-
if (has_min) {
109-
val_out = max_override(val_out, min);
110-
}
111-
if (has_max) {
112-
val_out = min_override(val_out, max);
113-
}
114-
return val_out;
115-
},
116-
in.const_data_ptr<CTYPE>(),
117-
out.mutable_data_ptr<CTYPE>(),
118-
in.numel());
169+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
170+
apply_unary_map_fn(
171+
[has_min, min, has_max, max](const CTYPE_IN val_in) {
172+
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
173+
if (has_min) {
174+
val_out = max_override(val_out, min);
175+
}
176+
if (has_max) {
177+
val_out = min_override(val_out, max);
178+
}
179+
return val_out;
180+
},
181+
in.const_data_ptr<CTYPE_IN>(),
182+
out.mutable_data_ptr<CTYPE_OUT>(),
183+
in.numel());
184+
});
119185
});
120186

121187
return out;

kernels/test/op_clamp_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) {
303303

304304
#ifndef USE_ATEN_LIB
305305
TEST(OpClampOutTest, IntTensorTooSmallClampDies) {
306-
// Cannot be represented by a uint32_t.
306+
// Cannot be represented by a int32_t.
307307
expect_bad_clamp_value_dies<ScalarType::Int>(-2147483649);
308308
}
309309

310310
TEST(OpClampOutTest, IntTensorTooLargeClampDies) {
311-
// Cannot be represented by a uint32_t.
311+
// Cannot be represented by a int32_t.
312312
expect_bad_clamp_value_dies<ScalarType::Int>(2147483648);
313313
}
314314
#endif

0 commit comments

Comments
 (0)