Skip to content

Commit 7ccfc94

Browse files
authored
Support Half/BFloat16 in avg_pool_2d (#7794)
Partial fix for #7748.
1 parent e2c64ab commit 7ccfc94

File tree

3 files changed

+1142
-1045
lines changed

3 files changed

+1142
-1045
lines changed

kernels/portable/cpu/op_avg_pool2d.cpp

+47-44
Original file line numberDiff line numberDiff line change
@@ -67,50 +67,53 @@ Tensor& avg_pool2d_out(
6767
out);
6868

6969
ScalarType in_type = in.scalar_type();
70-
ET_SWITCH_FLOAT_TYPES_AND(Long, in_type, ctx, "avg_pool2d.out", CTYPE, [&]() {
71-
if (divisor_override.has_value()) {
72-
int64_t divisor = divisor_override.value();
73-
// If divisor_override is specified, then we don't need to use `count` in
74-
// the calculation. Simply sum x / divisor to get the output.
75-
apply_kernel_2d_reduce_then_map_fn<CTYPE>(
76-
[](const CTYPE in_val,
77-
int64_t in_idx,
78-
CTYPE accum,
79-
int64_t accum_idx) {
80-
// Average pooling does not track indexes, so return 0 for accum_idx
81-
return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
82-
},
83-
[divisor](const int64_t count, const CTYPE accum) {
84-
return accum / static_cast<CTYPE>(divisor);
85-
},
86-
count_include_pad,
87-
in,
88-
kernel_size,
89-
stride,
90-
padding,
91-
{},
92-
out);
93-
} else {
94-
apply_kernel_2d_reduce_then_map_fn<CTYPE>(
95-
[](const CTYPE in_val,
96-
int64_t in_idx,
97-
CTYPE accum,
98-
int64_t accum_idx) {
99-
// Average pooling does not track indexes, so return 0 for accum_idx
100-
return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
101-
},
102-
[](const int64_t count, const CTYPE accum) {
103-
return accum / static_cast<CTYPE>(count);
104-
},
105-
count_include_pad,
106-
in,
107-
kernel_size,
108-
stride,
109-
padding,
110-
{},
111-
out);
112-
}
113-
});
70+
ET_SWITCH_FLOATHBF16_TYPES_AND(
71+
Long, in_type, ctx, "avg_pool2d.out", CTYPE, [&]() {
72+
if (divisor_override.has_value()) {
73+
int64_t divisor = divisor_override.value();
74+
// If divisor_override is specified, then we don't need to use `count`
75+
// in the calculation. Simply sum x / divisor to get the output.
76+
apply_kernel_2d_reduce_then_map_fn<CTYPE>(
77+
[](const CTYPE in_val,
78+
int64_t in_idx,
79+
CTYPE accum,
80+
int64_t accum_idx) {
81+
// Average pooling does not track indexes, so return 0 for
82+
// accum_idx
83+
return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
84+
},
85+
[divisor](const int64_t count, const CTYPE accum) {
86+
return accum / static_cast<CTYPE>(divisor);
87+
},
88+
count_include_pad,
89+
in,
90+
kernel_size,
91+
stride,
92+
padding,
93+
{},
94+
out);
95+
} else {
96+
apply_kernel_2d_reduce_then_map_fn<CTYPE>(
97+
[](const CTYPE in_val,
98+
int64_t in_idx,
99+
CTYPE accum,
100+
int64_t accum_idx) {
101+
// Average pooling does not track indexes, so return 0 for
102+
// accum_idx
103+
return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
104+
},
105+
[](const int64_t count, const CTYPE accum) {
106+
return accum / static_cast<CTYPE>(count);
107+
},
108+
count_include_pad,
109+
in,
110+
kernel_size,
111+
stride,
112+
padding,
113+
{},
114+
out);
115+
}
116+
});
114117

115118
return out;
116119
}

0 commit comments

Comments
 (0)