Skip to content

Commit c8c45bd

Browse files
committed
Support Half/BFloat16 in hardtanh
Partial fix for #7748.

ghstack-source-id: ce0dcbf
ghstack-comment-id: 2610865043
Pull Request resolved: #7899
1 parent 65e3c00 commit c8c45bd

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

kernels/portable/cpu/op_hardtanh.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Tensor& hardtanh_out(
4646

4747
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4848

49-
ET_SWITCH_REAL_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
49+
ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
5050
CTYPE min_casted;
5151
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
5252
CTYPE_MIN min_val;

kernels/portable/cpu/util/math_util.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ INT_T max_override(INT_T a, INT_T b) {
9696

9797
template <
9898
typename T,
99-
typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
100-
type = true>
99+
typename std::enable_if_t<
100+
std::is_same_v<T, exec_aten::Half> ||
101+
std::is_same_v<T, exec_aten::BFloat16>,
102+
bool> = true>
101103
T min_override(T a, T b) {
102104
const auto float_a = static_cast<float>(a);
103105
if (std::isnan(float_a)) {
@@ -116,8 +118,10 @@ T min_override(T a, T b) {
116118

117119
template <
118120
typename T,
119-
typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
120-
type = true>
121+
typename std::enable_if_t<
122+
std::is_same_v<T, exec_aten::Half> ||
123+
std::is_same_v<T, exec_aten::BFloat16>,
124+
bool> = true>
121125
T max_override(T a, T b) {
122126
const auto float_a = static_cast<float>(a);
123127
if (std::isnan(float_a)) {

kernels/test/op_hardtanh_test.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,31 @@ class OpHardTanhTest : public OperatorTest {
3030
return torch::executor::aten::hardtanh_outf(
3131
context_, self, min_val, max_val, out);
3232
}
33-
};
3433

35-
TEST_F(OpHardTanhTest, SanityCheck) {
36-
TensorFactory<ScalarType::Float> tf;
37-
Tensor in = tf.ones({2, 2});
38-
Tensor out = tf.zeros({2, 2});
34+
template <typename CTYPE, ScalarType DTYPE>
35+
void test_dtype() {
36+
TensorFactory<DTYPE> tf;
37+
CTYPE lowest_test_element;
38+
CTYPE lower_bound;
39+
if constexpr (std::numeric_limits<CTYPE>::is_signed) {
40+
lowest_test_element = -3;
41+
lower_bound = -2;
42+
} else {
43+
lowest_test_element = 0;
44+
lower_bound = 0;
45+
}
46+
Tensor in = tf.make({2, 2}, {lowest_test_element, 0, 1, 100});
47+
Tensor out = tf.zeros({2, 2});
48+
49+
Tensor ret = op_hardtanh_out(in, lower_bound, 2, out);
3950

40-
Tensor ret = op_hardtanh_out(in, -2, 2, out);
51+
EXPECT_TENSOR_EQ(out, ret);
52+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2}));
53+
}
54+
};
4155

42-
EXPECT_TENSOR_EQ(out, ret);
43-
EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
56+
TEST_F(OpHardTanhTest, SanityCheck) {
57+
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
58+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
59+
#undef TEST_ENTRY
4460
}

0 commit comments

Comments
 (0)