Skip to content

Commit 2f6683b

Browse files
salilsdesaifacebook-github-bot
authored andcommitted
Dtype compliance: clamp (#69)
Summary: Pull Request resolved: #69 Adjust clamp portable op have near-complete Dtype compliance with aten version. The Aten version of clamp has a weird quirk where it allows you to pass in a value for the min and/or max args which is below the normal range of the input/output tensor dataype when that datatype is uint8 specifically. But it doesn't allow ABOVE the range, and it doesn't allow below or above for any other datatype. We are choosing to leave a discrepancy between aten and portable by making uint8 behave like the rest of the datatypes for portable (not allowing below the range). This is already tested by the ByteTensorNegativeClampDies test (which is skipped when running the aten tests). Reviewed By: SS-JIA, manuelcandales Differential Revision: D47573238 fbshipit-source-id: bdd6a305085c1a7b648ff9a2d104a5f9e75a71bf
1 parent d1563b6 commit 2f6683b

File tree

2 files changed

+87
-26
lines changed

2 files changed

+87
-26
lines changed

kernels/portable/cpu/op_clamp.cpp

+85-24
Original file line numberDiff line numberDiff line change
@@ -84,38 +84,99 @@ Tensor& clamp_out(
8484
Error err = resize_tensor(out, in.sizes());
8585
ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
8686

87-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
87+
ScalarType in_type = in.scalar_type();
88+
ScalarType min_type = in_type;
89+
ScalarType max_type = in_type;
90+
ScalarType common_type = in_type;
91+
ScalarType out_type = out.scalar_type();
92+
93+
bool has_min = min_opt.has_value();
94+
if (has_min) {
95+
min_type = utils::get_scalar_dtype(min_opt.value());
96+
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
97+
}
98+
bool has_max = max_opt.has_value();
99+
if (has_max) {
100+
max_type = utils::get_scalar_dtype(max_opt.value());
101+
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
102+
}
103+
104+
ET_CHECK_MSG(
105+
has_min || has_max, "At least one of 'min' or 'max' must not be None");
88106

89-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() {
107+
ET_CHECK(common_type == out_type);
108+
109+
ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
90110
// Extract optional min value
91-
CTYPE min = 0;
92-
bool has_min = min_opt.has_value();
111+
CTYPE_OUT min = 0;
93112
if (has_min) {
94-
bool ok = utils::extract_scalar<CTYPE>(min_opt.value(), &min);
95-
ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range");
113+
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
114+
CTYPE_MIN min_val = 0;
115+
ET_EXTRACT_SCALAR(min_opt.value(), min_val);
116+
if (isIntegralType(out_type, /*includeBool=*/false)) {
117+
if (static_cast<long>(min_val) <
118+
std::numeric_limits<CTYPE_OUT>::lowest() ||
119+
static_cast<long>(min_val) >
120+
std::numeric_limits<CTYPE_OUT>::max()) {
121+
ET_CHECK_MSG(false, "minimum value out of bounds");
122+
}
123+
}
124+
if (isFloatingType(out_type)) {
125+
if (std::isfinite(min_val) &&
126+
(static_cast<double>(min_val) <
127+
std::numeric_limits<CTYPE_OUT>::lowest() ||
128+
static_cast<double>(min_val) >
129+
std::numeric_limits<CTYPE_OUT>::max())) {
130+
ET_CHECK_MSG(false, "minimum value out of bounds");
131+
}
132+
}
133+
min = static_cast<CTYPE_OUT>(min_val);
134+
});
96135
}
136+
97137
// Extract optional max value
98-
CTYPE max = 0;
99-
bool has_max = max_opt.has_value();
138+
CTYPE_OUT max = 0;
100139
if (has_max) {
101-
bool ok = utils::extract_scalar<CTYPE>(max_opt.value(), &max);
102-
ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range");
103-
}
104-
105-
apply_unary_map_fn(
106-
[has_min, min, has_max, max](const CTYPE val_in) {
107-
CTYPE val_out = val_in;
108-
if (has_min) {
109-
val_out = max_override(val_out, min);
140+
ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
141+
CTYPE_MAX max_val = 0;
142+
ET_EXTRACT_SCALAR(max_opt.value(), max_val);
143+
if (isIntegralType(out_type, /*includeBool=*/false)) {
144+
if (static_cast<long>(max_val) <
145+
std::numeric_limits<CTYPE_OUT>::lowest() ||
146+
static_cast<long>(max_val) >
147+
std::numeric_limits<CTYPE_OUT>::max()) {
148+
ET_CHECK_MSG(false, "maximum value out of bounds");
110149
}
111-
if (has_max) {
112-
val_out = min_override(val_out, max);
150+
}
151+
if (isFloatingType(out_type)) {
152+
if (std::isfinite(max_val) &&
153+
(static_cast<double>(max_val) <
154+
std::numeric_limits<CTYPE_OUT>::lowest() ||
155+
static_cast<double>(max_val) >
156+
std::numeric_limits<CTYPE_OUT>::max())) {
157+
ET_CHECK_MSG(false, "maximum value out of bounds");
113158
}
114-
return val_out;
115-
},
116-
in.const_data_ptr<CTYPE>(),
117-
out.mutable_data_ptr<CTYPE>(),
118-
in.numel());
159+
}
160+
max = static_cast<CTYPE_OUT>(max_val);
161+
});
162+
}
163+
164+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
165+
apply_unary_map_fn(
166+
[has_min, min, has_max, max](const CTYPE_IN val_in) {
167+
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
168+
if (has_min) {
169+
val_out = max_override(val_out, min);
170+
}
171+
if (has_max) {
172+
val_out = min_override(val_out, max);
173+
}
174+
return val_out;
175+
},
176+
in.const_data_ptr<CTYPE_IN>(),
177+
out.mutable_data_ptr<CTYPE_OUT>(),
178+
in.numel());
179+
});
119180
});
120181

121182
return out;

kernels/test/op_clamp_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) {
303303

304304
#ifndef USE_ATEN_LIB
305305
TEST(OpClampOutTest, IntTensorTooSmallClampDies) {
306-
// Cannot be represented by a uint32_t.
306+
// Cannot be represented by a int32_t.
307307
expect_bad_clamp_value_dies<ScalarType::Int>(-2147483649);
308308
}
309309

310310
TEST(OpClampOutTest, IntTensorTooLargeClampDies) {
311-
// Cannot be represented by a uint32_t.
311+
// Cannot be represented by a int32_t.
312312
expect_bad_clamp_value_dies<ScalarType::Int>(2147483648);
313313
}
314314
#endif

0 commit comments

Comments
 (0)