@@ -73,6 +73,43 @@ using Scalar = exec_aten::Scalar;
73
73
using ScalarType = exec_aten::ScalarType;
74
74
using Tensor = exec_aten::Tensor;
75
75
76
+ namespace {
77
+
78
+ template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
79
+ /* * Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
80
+ bool is_out_of_bounds (CTYPE_VAL val) {
81
+ const CTYPE_CAST val_cast = static_cast <CTYPE_CAST>(val);
82
+ return val_cast < std::numeric_limits<CTYPE_OUT>::lowest () ||
83
+ val_cast > std::numeric_limits<CTYPE_OUT>::max ();
84
+ }
85
+
86
+ void check_bounds (
87
+ const Scalar& val_scalar,
88
+ const torch::executor::native::ScalarType& val_type,
89
+ const torch::executor::native::ScalarType& out_type,
90
+ const char * val_name) {
91
+ ET_SWITCH_SCALAR_OBJ_TYPES (val_type, ctx, " clamp" , CTYPE_VAL, [&]() {
92
+ CTYPE_VAL val = 0 ;
93
+ ET_EXTRACT_SCALAR (val_scalar, val);
94
+ if (isIntegralType (out_type, /* includeBool=*/ false )) {
95
+ ET_SWITCH_INT_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
96
+ if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long >(val)) {
97
+ ET_CHECK_MSG (false , " %s value out of bounds" , val_name);
98
+ }
99
+ });
100
+ } else if (isFloatingType (out_type)) {
101
+ ET_SWITCH_FLOAT_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
102
+ if (std::isfinite (val) &&
103
+ is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double >(val)) {
104
+ ET_CHECK_MSG (false , " %s value out of bounds" , val_name);
105
+ }
106
+ });
107
+ }
108
+ });
109
+ }
110
+
111
+ } // namespace
112
+
76
113
Tensor& clamp_out (
77
114
RuntimeContext& ctx,
78
115
const Tensor& in,
@@ -84,38 +121,67 @@ Tensor& clamp_out(
84
121
Error err = resize_tensor (out, in.sizes ());
85
122
ET_CHECK_MSG (err == Error::Ok, " Could not resize output" );
86
123
87
- ET_CHECK_SAME_SHAPE_AND_DTYPE2 (in, out);
124
+ ScalarType in_type = in.scalar_type ();
125
+ ScalarType min_type = in_type;
126
+ ScalarType max_type = in_type;
127
+ ScalarType common_type = in_type;
128
+ ScalarType out_type = out.scalar_type ();
129
+
130
+ bool has_min = min_opt.has_value ();
131
+ if (has_min) {
132
+ min_type = utils::get_scalar_dtype (min_opt.value ());
133
+ common_type = utils::promote_type_with_scalar (common_type, min_opt.value ());
134
+ check_bounds (min_opt.value (), min_type, out_type, " minimum" );
135
+ }
136
+ bool has_max = max_opt.has_value ();
137
+ if (has_max) {
138
+ max_type = utils::get_scalar_dtype (max_opt.value ());
139
+ common_type = utils::promote_type_with_scalar (common_type, max_opt.value ());
140
+ check_bounds (max_opt.value (), max_type, out_type, " maximum" );
141
+ }
88
142
89
- ET_SWITCH_REAL_TYPES (in.scalar_type (), ctx, " clamp" , CTYPE, [&]() {
143
+ ET_CHECK_MSG (
144
+ has_min || has_max, " At least one of 'min' or 'max' must not be None" );
145
+
146
+ ET_CHECK (common_type == out_type);
147
+
148
+ ET_SWITCH_REAL_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
90
149
// Extract optional min value
91
- CTYPE min = 0 ;
92
- bool has_min = min_opt.has_value ();
150
+ CTYPE_OUT min = 0 ;
93
151
if (has_min) {
94
- bool ok = utils::extract_scalar<CTYPE>(min_opt.value (), &min);
95
- ET_CHECK_MSG (ok, " Invalid min value: wrong type or out of range" );
152
+ ET_SWITCH_SCALAR_OBJ_TYPES (min_type, ctx, " clamp" , CTYPE_MIN, [&]() {
153
+ CTYPE_MIN min_val = 0 ;
154
+ ET_EXTRACT_SCALAR (min_opt.value (), min_val);
155
+ min = static_cast <CTYPE_OUT>(min_val);
156
+ });
96
157
}
158
+
97
159
// Extract optional max value
98
- CTYPE max = 0 ;
99
- bool has_max = max_opt.has_value ();
160
+ CTYPE_OUT max = 0 ;
100
161
if (has_max) {
101
- bool ok = utils::extract_scalar<CTYPE>(max_opt.value (), &max);
102
- ET_CHECK_MSG (ok, " Invalid max value: wrong type or out of range" );
162
+ ET_SWITCH_SCALAR_OBJ_TYPES (max_type, ctx, " clamp" , CTYPE_MAX, [&]() {
163
+ CTYPE_MAX max_val = 0 ;
164
+ ET_EXTRACT_SCALAR (max_opt.value (), max_val);
165
+ max = static_cast <CTYPE_OUT>(max_val);
166
+ });
103
167
}
104
168
105
- apply_unary_map_fn (
106
- [has_min, min, has_max, max](const CTYPE val_in) {
107
- CTYPE val_out = val_in;
108
- if (has_min) {
109
- val_out = max_override (val_out, min);
110
- }
111
- if (has_max) {
112
- val_out = min_override (val_out, max);
113
- }
114
- return val_out;
115
- },
116
- in.const_data_ptr <CTYPE>(),
117
- out.mutable_data_ptr <CTYPE>(),
118
- in.numel ());
169
+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, " clamp" , CTYPE_IN, [&]() {
170
+ apply_unary_map_fn (
171
+ [has_min, min, has_max, max](const CTYPE_IN val_in) {
172
+ CTYPE_OUT val_out = static_cast <CTYPE_OUT>(val_in);
173
+ if (has_min) {
174
+ val_out = max_override (val_out, min);
175
+ }
176
+ if (has_max) {
177
+ val_out = min_override (val_out, max);
178
+ }
179
+ return val_out;
180
+ },
181
+ in.const_data_ptr <CTYPE_IN>(),
182
+ out.mutable_data_ptr <CTYPE_OUT>(),
183
+ in.numel ());
184
+ });
119
185
});
120
186
121
187
return out;
0 commit comments