Skip to content

Commit c5afb4d

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in gelu (#7888)
Partial fix for #7748.
1 parent 2160ef9 commit c5afb4d

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

kernels/portable/cpu/op_gelu.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Tensor& gelu_out(
3737
ET_KERNEL_CHECK(
3838
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3939

40-
ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
40+
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
4141
if (approximate == "tanh") {
4242
apply_unary_map_fn(
4343
[](const CTYPE x) {

kernels/test/op_gelu_test.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ TEST_F(OpGeluTest, FloatTensors) {
7070
test_gelu_execution<ScalarType::Float>();
7171
}
7272

73+
TEST_F(OpGeluTest, HalfTensors) {
74+
test_gelu_execution<ScalarType::Half>();
75+
}
76+
77+
TEST_F(OpGeluTest, BFloat16Tensors) {
78+
test_gelu_execution<ScalarType::BFloat16>();
79+
}
80+
7381
TEST_F(OpGeluTest, DoubleTensors) {
7482
if (!SupportedFeatures::get()->op_gelu_dtype_double) {
7583
GTEST_SKIP();

0 commit comments

Comments
 (0)