Skip to content

Commit 58054e3

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in glu (#7824)
Partial fix for #7748.
1 parent 21358cf commit 58054e3

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

kernels/portable/cpu/op_glu.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,10 @@ Tensor& glu_out(
155155
const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim;
156156
const auto in_dtype = self.scalar_type();
157157

158-
ET_SWITCH_FLOAT_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
159-
if (out.scalar_type() == ScalarType::Float) {
160-
glu_out_tensor<CTYPE_IN, float>(self, non_negative_dim, out);
161-
} else {
162-
glu_out_tensor<CTYPE_IN, double>(self, non_negative_dim, out);
163-
}
158+
ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
159+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
160+
glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
161+
});
164162
});
165163

166164
return out;

kernels/test/op_glu_test.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,28 @@ class OpGluOutTest : public OperatorTest {
117117
TEST_F(OpGluOutTest, AllInputFloatOutputSupport) {
118118
#define TEST_ENTRY(ctype, dtype) \
119119
test_glu_out<ScalarType::dtype, ScalarType::Float>();
120-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
120+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
121121
#undef TEST_ENTRY
122122
}
123123

124124
TEST_F(OpGluOutTest, AllInputDoubleOutputSupport) {
125125
#define TEST_ENTRY(ctype, dtype) \
126126
test_glu_out<ScalarType::dtype, ScalarType::Double>();
127-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
127+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
128+
#undef TEST_ENTRY
129+
}
130+
131+
TEST_F(OpGluOutTest, AllInputHalfOutputSupport) {
132+
#define TEST_ENTRY(ctype, dtype) \
133+
test_glu_out<ScalarType::dtype, ScalarType::Half>();
134+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
135+
#undef TEST_ENTRY
136+
}
137+
138+
TEST_F(OpGluOutTest, AllInputBFloat16OutputSupport) {
139+
#define TEST_ENTRY(ctype, dtype) \
140+
test_glu_out<ScalarType::dtype, ScalarType::BFloat16>();
141+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
128142
#undef TEST_ENTRY
129143
}
130144

0 commit comments

Comments
 (0)