
Commit 1a66163

jerryzh168 authored and pytorchmergebot committed
[quant] Support integer implementations for adaptive_avg_pool2d (pytorch#104226)
Summary: This is needed for representing quantized models in the pt2 export quantization flow.

Test Plan: Tested by OpInfo: python test/test_ops.py

Pull Request resolved: pytorch#104226
Approved by: https://github.com/jgong5, https://github.com/andrewor14
1 parent 98e14ac · commit 1a66163
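
As a quick illustration of what the commit enables (my own sketch, not part of the diff), a minimal call of the op on an integer CPU tensor, assuming the output keeps the input's dtype:

import torch
import torch.nn.functional as F

# int8 activations, as a pt2 export quantization flow would produce
x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)

# Before this commit the CPU kernel only dispatched floating types;
# with the widened dispatch below, this call takes the integer path.
out = F.adaptive_avg_pool2d(x, (2, 2))
print(out.dtype)  # expected: torch.int8, matching the input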

File tree: 4 files changed (+10, −3)


aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp

Lines changed: 2 additions & 2 deletions
@@ -367,7 +367,7 @@ void adaptive_avg_pool2d_kernel_impl(
     IntArrayRef output_size) {
   switch (input.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "adaptive_avg_pool2d", [&] {
+      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "adaptive_avg_pool2d", [&] {
         if (input.scalar_type() == ScalarType::BFloat16) {
           cpu_adaptive_avg_pool<BFloat16, /*accscalar_t*/float>(output, input, output_size);
         } else {
@@ -377,7 +377,7 @@ void adaptive_avg_pool2d_kernel_impl(
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{
+      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{
         cpu_adaptive_avg_pool_channels_last<scalar_t>(output, input, output_size);
       });
       break;
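
The only change here is the dispatch macro: AT_DISPATCH_FLOATING_TYPES_AND covers float and double plus the extra listed type, while AT_DISPATCH_ALL_TYPES_AND also covers the standard integer types, so both the Contiguous and ChannelsLast branches now reach integer kernels. A sketch (mine, not from the commit) exercising both memory-format branches:

import torch
import torch.nn.functional as F

x = torch.arange(2 * 4 * 4, dtype=torch.int32).reshape(1, 2, 4, 4)

# Contiguous branch of the switch above
y = F.adaptive_avg_pool2d(x, (2, 2))

# ChannelsLast branch; suggest_memory_format() selects it for this layout
x_cl = x.to(memory_format=torch.channels_last)
y_cl = F.adaptive_avg_pool2d(x_cl, (2, 2))

# Both branches should agree on the values
torch.testing.assert_close(y, y_cl)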

test/nn/test_pooling.py

Lines changed: 3 additions & 0 deletions
@@ -754,6 +754,9 @@ def test_adaptive_avg_pool3d_output_size_one(self, device):
     def test_adaptive_pooling_no_suppot_input(self, device, dtype):
         for numel in (2, 3):
             for pool_type in ('Max', 'Avg'):
+                # adaptive_avg_pool2d for int is implemented
+                if numel == 2 and pool_type == 'Avg':
+                    continue
                 cls_name = 'Adaptive{}Pool{}d'.format(pool_type, numel)
                 module_cls = getattr(nn, cls_name)
                 output_size = (2,) * numel
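
The skip reflects the new kernel: the (numel == 2, 'Avg') combination no longer raises, while the other combinations still should. A small sketch of the behavior the remaining test cases assert, assuming integer max pooling is still unimplemented:

import torch
import torch.nn as nn

x = torch.ones(1, 3, 4, 4, dtype=torch.int64)

nn.AdaptiveAvgPool2d((2, 2))(x)  # now succeeds on integer input

try:
    nn.AdaptiveMaxPool2d((2, 2))(x)  # assumed to still be unsupported
except RuntimeError as e:
    print("still raises:", e)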

torch/_decomp/decompositions.py

Lines changed: 4 additions & 0 deletions
@@ -1842,6 +1842,10 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
         f"non-batch dimensions, but input has shape {tuple(shape)}.",
     )
 
+    # TODO: decompose integer path
+    if input.dtype in [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64]:
+        return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
+
     # Optimisation (we should also do this in the kernel implementation)
     if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
         stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
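
This guard means that in the decomposed (export) path, integer inputs skip the floating-point decomposition below and fall back to the ATen op, whose CPU kernel gained integer support above; the TODO marks the intent to decompose the integer path later. A sketch of the routing, where uses_integer_fallback is a hypothetical helper of mine:

import torch

_INT_DTYPES = (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)

def uses_integer_fallback(t: torch.Tensor) -> bool:
    # Mirrors the guard added above: integer dtypes return early and
    # keep the opaque ATen call instead of being decomposed.
    return t.dtype in _INT_DTYPES

print(uses_integer_fallback(torch.zeros(1, dtype=torch.uint8)))    # True
print(uses_integer_fallback(torch.zeros(1, dtype=torch.float32)))  # False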

torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 1 deletion
@@ -12009,7 +12009,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
         error_inputs_func=error_inputs_adaptive_avg_pool1d,
         sample_inputs_func=sample_inputs_adaptive_avg_pool1d),
     OpInfo('nn.functional.adaptive_avg_pool2d',
-           dtypes=floating_types_and(torch.bfloat16),
+           dtypes=all_types_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            decorators=(
                # RuntimeError:
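
Widening dtypes from floating_types_and to all_types_and is what makes the OpInfo suite (python test/test_ops.py) cover the new integer kernel on CPU, while CUDA stays floating-point only. A sketch of the dtype difference, assuming the helpers in torch.testing._internal.common_dtype behave as in current PyTorch:

import torch
from torch.testing._internal.common_dtype import all_types_and, floating_types_and

old = set(floating_types_and(torch.bfloat16))
new = set(all_types_and(torch.bfloat16))

# The delta is exactly the integer dtypes the widened kernel now handles.
print(sorted(str(d) for d in new - old))
# expected: the torch integer dtypes, e.g. int8, uint8, int16, int32, int64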
