Skip to content

Unsqueeze unbatched input of avg_pool#2646

Merged
justinchuby merged 4 commits into
microsoft:mainfrom
wodesuck:patch-3
Oct 27, 2025
Merged

Unsqueeze unbatched input of avg_pool#2646
justinchuby merged 4 commits into
microsoft:mainfrom
wodesuck:patch-3

Conversation

@wodesuck
Copy link
Copy Markdown
Contributor

Onnx's AveragePool require input shape as N,C,H,W, but torch accept both N,C,H,W and C,H,W. Unsqueeze if input is unbatched, just like what max_pool does.

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 20, 2025
@justinchuby
Copy link
Copy Markdown
Collaborator

@codecov
Copy link
Copy Markdown

codecov Bot commented Oct 20, 2025

Codecov Report

❌ Patch coverage is 63.63636% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.45%. Comparing base (8a94ad6) to head (f5ec077).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/nn.py 63.63% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2646      +/-   ##
==========================================
- Coverage   70.46%   70.45%   -0.01%     
==========================================
  Files         224      224              
  Lines       26572    26577       +5     
  Branches     2637     2639       +2     
==========================================
+ Hits        18723    18724       +1     
- Misses       6928     6930       +2     
- Partials      921      923       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@wodesuck
Copy link
Copy Markdown
Contributor Author

@justinchuby Test added.

Comment thread tests/function_libs/torch_lib/e2e_ops_tests.py Fixed
Comment thread tests/function_libs/torch_lib/e2e_ops_tests.py Fixed
Comment thread tests/function_libs/torch_lib/e2e_ops_tests.py Fixed
Comment thread tests/function_libs/torch_lib/e2e_ops_tests.py Fixed
Comment thread tests/function_libs/torch_lib/e2e_ops_tests.py Fixed
Comment thread tests/function_libs/torch_lib/e2e_ops_tests.py Fixed
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for unbatched input tensors in average pooling operations to match PyTorch's behavior. While ONNX's AveragePool requires NCHW format, PyTorch accepts both batched (NCHW) and unbatched (CHW) inputs. The changes handle unbatched inputs by automatically unsqueezing/squeezing dimensions, similar to the existing max_pool implementation.

Key Changes:

  • Introduced a helper function _aten_avg_pool_onnx that handles both batched and unbatched inputs
  • Refactored avg_pool1d, avg_pool2d, and avg_pool3d to use the new helper function
  • Added comprehensive tests covering all pooling dimensions with both batched and unbatched inputs

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
onnxscript/function_libs/torch_lib/ops/nn.py Refactored avg_pool operations to support unbatched inputs via new helper function
tests/function_libs/torch_lib/e2e_ops_tests.py Added test cases for avg_pool operations with various input dimensions

Comment thread onnxscript/function_libs/torch_lib/ops/nn.py
Copy link
Copy Markdown
Contributor

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. Minor comments and please check this page for lint: https://github.com/microsoft/onnxscript#coding-style

Comment thread onnxscript/function_libs/torch_lib/ops/nn.py Outdated
Comment thread onnxscript/function_libs/torch_lib/ops/nn.py Outdated
@titaiwangms
Copy link
Copy Markdown
Contributor

There is still something wrong with lint. Would you check?

@wodesuck
Copy link
Copy Markdown
Contributor Author

@titaiwangms Pylint says "torch.nn.functional.avg_pool1d is not callable", that's not true. I have run lintrunner locally without wrong, don't known why it still blame.

@titaiwangms
Copy link
Copy Markdown
Contributor

@titaiwangms Pylint says "torch.nn.functional.avg_pool1d is not callable", that's not true. I have run lintrunner locally without wrong, don't known why it still blame.

You can go ahead and disable it: To disable, use # pylint: disable=not-callable

@justinchuby justinchuby added this to the 0.5.5 milestone Oct 25, 2025
Copy link
Copy Markdown
Contributor

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@justinchuby justinchuby merged commit 04a9da4 into microsoft:main Oct 27, 2025
30 of 32 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

5 participants