@@ -67,50 +67,53 @@ Tensor& avg_pool2d_out(
67
67
out);
68
68
69
69
ScalarType in_type = in.scalar_type ();
70
- ET_SWITCH_FLOAT_TYPES_AND (Long, in_type, ctx, " avg_pool2d.out" , CTYPE, [&]() {
71
- if (divisor_override.has_value ()) {
72
- int64_t divisor = divisor_override.value ();
73
- // If divisor_override is specified, then we don't need to use `count` in
74
- // the calculation. Simply sum x / divisor to get the output.
75
- apply_kernel_2d_reduce_then_map_fn<CTYPE>(
76
- [](const CTYPE in_val,
77
- int64_t in_idx,
78
- CTYPE accum,
79
- int64_t accum_idx) {
80
- // Average pooling does not track indexes, so return 0 for accum_idx
81
- return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
82
- },
83
- [divisor](const int64_t count, const CTYPE accum) {
84
- return accum / static_cast <CTYPE>(divisor);
85
- },
86
- count_include_pad,
87
- in,
88
- kernel_size,
89
- stride,
90
- padding,
91
- {},
92
- out);
93
- } else {
94
- apply_kernel_2d_reduce_then_map_fn<CTYPE>(
95
- [](const CTYPE in_val,
96
- int64_t in_idx,
97
- CTYPE accum,
98
- int64_t accum_idx) {
99
- // Average pooling does not track indexes, so return 0 for accum_idx
100
- return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
101
- },
102
- [](const int64_t count, const CTYPE accum) {
103
- return accum / static_cast <CTYPE>(count);
104
- },
105
- count_include_pad,
106
- in,
107
- kernel_size,
108
- stride,
109
- padding,
110
- {},
111
- out);
112
- }
113
- });
70
+ ET_SWITCH_FLOATHBF16_TYPES_AND (
71
+ Long, in_type, ctx, " avg_pool2d.out" , CTYPE, [&]() {
72
+ if (divisor_override.has_value ()) {
73
+ int64_t divisor = divisor_override.value ();
74
+ // If divisor_override is specified, then we don't need to use `count`
75
+ // in the calculation. Simply sum x / divisor to get the output.
76
+ apply_kernel_2d_reduce_then_map_fn<CTYPE>(
77
+ [](const CTYPE in_val,
78
+ int64_t in_idx,
79
+ CTYPE accum,
80
+ int64_t accum_idx) {
81
+ // Average pooling does not track indexes, so return 0 for
82
+ // accum_idx
83
+ return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
84
+ },
85
+ [divisor](const int64_t count, const CTYPE accum) {
86
+ return accum / static_cast <CTYPE>(divisor);
87
+ },
88
+ count_include_pad,
89
+ in,
90
+ kernel_size,
91
+ stride,
92
+ padding,
93
+ {},
94
+ out);
95
+ } else {
96
+ apply_kernel_2d_reduce_then_map_fn<CTYPE>(
97
+ [](const CTYPE in_val,
98
+ int64_t in_idx,
99
+ CTYPE accum,
100
+ int64_t accum_idx) {
101
+ // Average pooling does not track indexes, so return 0 for
102
+ // accum_idx
103
+ return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
104
+ },
105
+ [](const int64_t count, const CTYPE accum) {
106
+ return accum / static_cast <CTYPE>(count);
107
+ },
108
+ count_include_pad,
109
+ in,
110
+ kernel_size,
111
+ stride,
112
+ padding,
113
+ {},
114
+ out);
115
+ }
116
+ });
114
117
115
118
return out;
116
119
}
0 commit comments