Skip to content

Commit ae38be7

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in logit (#7890)
Partial fix for #7748.
1 parent c8f48db commit ae38be7

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

kernels/portable/cpu/op_logit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Tensor& logit_out(
3535

3636
ScalarType in_type = in.scalar_type();
3737
ScalarType out_type = out.scalar_type();
38-
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "logit.out", CTYPE_IN, [&] {
38+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "logit.out", CTYPE_IN, [&] {
3939
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "logit.out", CTYPE_OUT, [&] {
4040
apply_unary_map_fn(
4141
[eps](const CTYPE_IN val_in) {

kernels/test/op_logit_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,27 +100,27 @@ void OpLogitOutTest::
100100
TEST_F(OpLogitOutTest, AllRealInputFloatOutputSupport) {
101101
#define TEST_ENTRY(ctype, dtype) \
102102
test_integer_logit_out<ScalarType::dtype, ScalarType::Float>();
103-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
103+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
104104
#undef TEST_ENTRY
105105
}
106106

107107
TEST_F(OpLogitOutTest, AllRealInputDoubleOutputSupport) {
108108
#define TEST_ENTRY(ctype, dtype) \
109109
test_integer_logit_out<ScalarType::dtype, ScalarType::Double>();
110-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
110+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
111111
#undef TEST_ENTRY
112112
}
113113
TEST_F(OpLogitOutTest, AllRealInputFloatOutputSupportEpsSet) {
114114
#define TEST_ENTRY(ctype, dtype) \
115115
test_integer_logit_out_eps_set<ScalarType::dtype, ScalarType::Float>();
116-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
116+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
117117
#undef TEST_ENTRY
118118
}
119119

120120
TEST_F(OpLogitOutTest, AllRealInputDoubleOutputSupportEpsSet) {
121121
#define TEST_ENTRY(ctype, dtype) \
122122
test_integer_logit_out_eps_set<ScalarType::dtype, ScalarType::Double>();
123-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
123+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
124124
#undef TEST_ENTRY
125125
}
126126

0 commit comments

Comments
 (0)