@@ -73,22 +73,20 @@ Tensor& select_scatter_out(
73
73
ScalarType in_type = in.scalar_type ();
74
74
ScalarType src_type = src.scalar_type ();
75
75
76
- ET_SWITCH_REAL_TYPES_AND (
77
- Bool, in_type, ctx, " select_scatter.out" , CTYPE, [&]() {
78
- ET_SWITCH_REAL_TYPES_AND (
79
- Bool, src_type, ctx, " select_scatter.out" , CTYPE_SRC, [&]() {
80
- CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
81
- const CTYPE_SRC* const src_data = src.const_data_ptr <CTYPE_SRC>();
82
-
83
- for (size_t i = 0 ; i < leading_dims; ++i) {
84
- for (size_t j = 0 ; j < trailing_stride; ++j) {
85
- out_data[start_offset + i * out_step + j] =
86
- convert<CTYPE, CTYPE_SRC>(
87
- src_data[i * trailing_stride + j]);
88
- }
89
- }
90
- });
91
- });
76
+ ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " select_scatter.out" , CTYPE, [&]() {
77
+ ET_SWITCH_REALHBBF16_TYPES (
78
+ src_type, ctx, " select_scatter.out" , CTYPE_SRC, [&]() {
79
+ CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
80
+ const CTYPE_SRC* const src_data = src.const_data_ptr <CTYPE_SRC>();
81
+
82
+ for (size_t i = 0 ; i < leading_dims; ++i) {
83
+ for (size_t j = 0 ; j < trailing_stride; ++j) {
84
+ out_data[start_offset + i * out_step + j] =
85
+ convert<CTYPE, CTYPE_SRC>(src_data[i * trailing_stride + j]);
86
+ }
87
+ }
88
+ });
89
+ });
92
90
93
91
return out;
94
92
}
0 commit comments