-
Notifications
You must be signed in to change notification settings - Fork 81
Added padding_idx=None option and new test cases for aten_embedding_bag #2549
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@microsoft-github-policy-service agree |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2549 +/- ##
==========================================
+ Coverage 70.00% 70.33% +0.33%
==========================================
Files 215 218 +3
Lines 25992 26430 +438
Branches 2606 2647 +41
==========================================
+ Hits 18196 18590 +394
- Misses 6896 6936 +40
- Partials 900 904 +4 ☔ View full report in Codecov by Sentry. |
@justinchuby I can handle the linting issues, but I’m confused about the other CI failures — could you help? |
opinfo_core.OpInfo( | ||
"test_embedding_bag_with_padding_idx_int", | ||
op=torch.nn.functional.embedding_bag, | ||
dtypes=(torch.float32,), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it’s this line: you need the all_float_types() etc. construct for specifying supported dtypes. See other existing op infos for reference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed it to common_dtype.floating_types_and_half()
, similar to other test cases
Fix Issues #2219, #2385 and the first part of #2489
This commit adds new test cases and the necessary implementation changes to correctly support the
padding_idx=None
option in theaten_embedding_bag
operator. This aligns the ONNX Script operator with PyTorch's native behavior and expands test coverage for this feature.Key Changes:
core.py
: Theaten_embedding_bag_padding_idx
function has been updated to handlepadding_idx=None
. This new code routes the operation to the standardaten_embedding_bag
implementation when no padding indices are specified.extra_opinfo.py
: Two newOpInfo
definitions,test_embedding_bag_with_padding_idx_none
andtest_embedding_bag_with_padding_idx_int
, have been added to theOP_DB
list. These provide input samples to test the new and existingpadding_idx
functionality.ops_test_data.py
: TheTESTED_TORCHLIB_OPS
tuple has been updated to include the new tests, ensuring they are discovered and executed by the test runner.