-
Notifications
You must be signed in to change notification settings - Fork 63
FixOp(embeddign_bag) change to use nn.functional.embeding_bag function | feat(torchlib) #1067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
per_sample_weights = op.Expand(1, op.Shape(indices_1d)) | ||
# Dtype of per_sample_weights is the same as weight | ||
per_sample_weights = op.CastLike(per_sample_weights, weight) | ||
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( |
Check warning
Code scanning / lintrunner
PYLINT/W0612
per_sample_weights = op.Expand(1, op.Shape(indices_1d)) | ||
# Dtype of per_sample_weights is the same as weight | ||
per_sample_weights = op.CastLike(per_sample_weights, weight) | ||
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( |
Check warning
Code scanning / lintrunner
PYLINT/W0612
per_sample_weights = op.Expand(1, op.Shape(indices_1d)) | ||
# Dtype of per_sample_weights is the same as weight | ||
per_sample_weights = op.CastLike(per_sample_weights, weight) | ||
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( |
Check warning
Code scanning / lintrunner
PYLINT/W0612
# if per_sample_weights is None: | ||
# # Set per_sample_weights to 1.0, because cannot check 'None' in ONNX-Script | ||
# # Size of persample_weights is the same as indices, and should be 1d tensor | ||
# indices_1d = op.Reshape(indices, [-1]) | ||
# per_sample_weights = op.Expand(1, op.Shape(indices_1d)) | ||
# # Dtype of per_sample_weights is the same as weight | ||
# per_sample_weights = op.CastLike(per_sample_weights, weight) |
Check notice
Code scanning / CodeQL
Commented-out code
Codecov Report
@@ Coverage Diff @@
## main #1067 +/- ##
==========================================
- Coverage 77.94% 77.76% -0.18%
==========================================
Files 115 115
Lines 14684 14706 +22
Branches 1558 1563 +5
==========================================
- Hits 11445 11436 -9
- Misses 2871 2901 +30
- Partials 368 369 +1
|
Test Results 18 files ± 0 18 suites ±0 1h 13m 51s ⏱️ + 6m 21s For more details on these failures and errors, see this check. Results for commit 1121b4d. ± Comparison against base commit b6c5e3b. This pull request removes 4 and adds 5 tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results. |
We should still return four outputs and test using the aten op as explained in teams |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocking for now. Thanks!
Not this guess, try another way. |
For issue #1056, no idea what test case it is using and which target function is using ( nn.functional.embedding_bag or ops.aten.embedding_bag?)
This PR changed to use nn.functional.embedding_bag as reference.
The difference are: