Skip to content

Commit 48ac4ee

Browse files
committed
Support Half/BFloat16 in any
Partial fix for #7748. ghstack-source-id: 02a1dc7 ghstack-comment-id: 2599483099 Pull Request resolved: #7769
1 parent 9836b39 commit 48ac4ee

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

kernels/portable/cpu/op_any.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Tensor& any_all_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
2929
ScalarType out_type = out.scalar_type();
3030
constexpr auto name = "any.all_out";
3131

32-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
32+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
3333
ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
3434
const auto data_in = in.const_data_ptr<CTYPE_IN>();
3535
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
@@ -78,7 +78,7 @@ Tensor& any_dims_out(
7878
ScalarType out_type = out.scalar_type();
7979
constexpr auto name = "any.dims_out";
8080

81-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
81+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
8282
ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
8383
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
8484
if (dim_list.has_value() && dim_list.value().empty()) {
@@ -135,7 +135,7 @@ Tensor& any_out(
135135
ScalarType out_type = out.scalar_type();
136136
constexpr auto name = "any.out";
137137

138-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
138+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
139139
ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
140140
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
141141
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {

kernels/test/op_any_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ TEST_F(OpAnyOutTest, InvalidDtypeDies) {
120120

121121
TEST_F(OpAnyOutTest, AllRealInputTypePasses) {
122122
#define TEST_ENTRY(ctype, dtype) test_any_all_out<ScalarType::dtype>();
123-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
123+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
124124
#undef TEST_ENTRY
125125
}
126126

0 commit comments

Comments
 (0)