Skip to content

Commit 2aa6d3a

Browse files
committed
Support Half/BFloat16 in glu
Partial fix for #7748. ghstack-source-id: b25fafc ghstack-comment-id: 2605988746 Pull Request resolved: #7824
1 parent 466d98f commit 2aa6d3a

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

kernels/portable/cpu/op_glu.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,10 @@ Tensor& glu_out(
155155
const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim;
156156
const auto in_dtype = self.scalar_type();
157157

158-
ET_SWITCH_FLOAT_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
159-
if (out.scalar_type() == ScalarType::Float) {
160-
glu_out_tensor<CTYPE_IN, float>(self, non_negative_dim, out);
161-
} else {
162-
glu_out_tensor<CTYPE_IN, double>(self, non_negative_dim, out);
163-
}
158+
ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
159+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
160+
glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
161+
});
164162
});
165163

166164
return out;

kernels/test/op_glu_test.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ TEST_F(OpGluOutTest, AllInputDoubleOutputSupport) {
128128
#undef TEST_ENTRY
129129
}
130130

131+
TEST_F(OpGluOutTest, AllInputHalfOutputSupport) {
132+
#define TEST_ENTRY(ctype, dtype) \
133+
test_glu_out<ScalarType::dtype, ScalarType::Half>();
134+
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
135+
#undef TEST_ENTRY
136+
}
137+
138+
TEST_F(OpGluOutTest, AllInputBFloat16OutputSupport) {
139+
#define TEST_ENTRY(ctype, dtype) \
140+
test_glu_out<ScalarType::dtype, ScalarType::BFloat16>();
141+
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
142+
#undef TEST_ENTRY
143+
}
144+
131145
TEST_F(OpGluOutTest, InfinityAndNANTest) {
132146
TensorFactory<ScalarType::Float> tf;
133147
const std::vector<int32_t> sizes = {4, 2};

0 commit comments

Comments
 (0)