Skip to content

Commit c954fbc

Browse files
authored
Support Half/BFloat16 in tril (#7893)
Partial fix for #7748.
1 parent 64c2556 commit c954fbc

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

kernels/portable/cpu/op_tril.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ Tensor& tril_out(
158158
clear_out(out);
159159

160160
ScalarType out_type = out.scalar_type();
161-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&]() {
161+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, __func__, CTYPE, [&]() {
162162
tril_kernel<CTYPE>(ctx, self, diagonal, out);
163163
});
164164

kernels/test/op_tril_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ class OpTrilTest : public OperatorTest {
727727
test_tril_out_multi_unequal_dim<ScalarType::DTYPE>(); \
728728
}
729729

730-
ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_GENERIC_TEST)
730+
ET_FORALL_REALHBBF16_TYPES(GENERATE_GENERIC_TEST)
731731

732732
// Create generic tests for real dtypes. Tensors have diverse values.
733733
#define GENERATE_REAL_TEST(_, DTYPE) \
@@ -738,7 +738,7 @@ ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_GENERIC_TEST)
738738
test_tril_out_randint_multi_unequal<ScalarType::DTYPE>(); \
739739
}
740740

741-
ET_FORALL_REAL_TYPES(GENERATE_REAL_TEST)
741+
ET_FORALL_REALHBBF16_TYPES(GENERATE_REAL_TEST)
742742

743743
TEST_F(OpTrilTest, InvalidInputShapesDies) {
744744
TensorFactory<ScalarType::Int> tf;

0 commit comments

Comments
 (0)