Skip to content

FixOp(embeddign_bag) change to use nn.functional.embeding_bag function | feat(torchlib) #1067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

xiaowuhu
Copy link
Contributor

@xiaowuhu xiaowuhu commented Sep 20, 2023

For issue #1056, no idea what test case it is using and which target function is using ( nn.functional.embedding_bag or ops.aten.embedding_bag?)
This PR changed to use nn.functional.embedding_bag as reference.
The difference are:

  • Only return 1 output
  • Support when indices is 2d tensor. In this case, the 'offsets' argument can be None.
  • The padding_idx argument not support, need another overload to support.

per_sample_weights = op.Expand(1, op.Shape(indices_1d))
# Dtype of per_sample_weights is the same as weight
per_sample_weights = op.CastLike(per_sample_weights, weight)
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(

Check warning

Code scanning / lintrunner

PYLINT/W0612

Unused variable 'offset2bag' (unused-variable) See [unused-variable](https://pylint.pycqa.org/en/latest/user_guide/messages/warning/unused-variable.html). To disable, use ` # pylint: disable=unused-variable`
per_sample_weights = op.Expand(1, op.Shape(indices_1d))
# Dtype of per_sample_weights is the same as weight
per_sample_weights = op.CastLike(per_sample_weights, weight)
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(

Check warning

Code scanning / lintrunner

PYLINT/W0612

Unused variable 'bag_size' (unused-variable) See [unused-variable](https://pylint.pycqa.org/en/latest/user_guide/messages/warning/unused-variable.html). To disable, use ` # pylint: disable=unused-variable`
per_sample_weights = op.Expand(1, op.Shape(indices_1d))
# Dtype of per_sample_weights is the same as weight
per_sample_weights = op.CastLike(per_sample_weights, weight)
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(

Check warning

Code scanning / lintrunner

PYLINT/W0612

Unused variable 'max_indices' (unused-variable) See [unused-variable](https://pylint.pycqa.org/en/latest/user_guide/messages/warning/unused-variable.html). To disable, use ` # pylint: disable=unused-variable`
Comment on lines +2661 to +2667
# if per_sample_weights is None:
# # Set per_sample_weights to 1.0, because cannot check 'None' in ONNX-Script
# # Size of persample_weights is the same as indices, and should be 1d tensor
# indices_1d = op.Reshape(indices, [-1])
# per_sample_weights = op.Expand(1, op.Shape(indices_1d))
# # Dtype of per_sample_weights is the same as weight
# per_sample_weights = op.CastLike(per_sample_weights, weight)

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
@codecov
Copy link

codecov bot commented Sep 20, 2023

Codecov Report

Merging #1067 (1121b4d) into main (b6c5e3b) will decrease coverage by 0.18%.
The diff coverage is 87.09%.

@@            Coverage Diff             @@
##             main    #1067      +/-   ##
==========================================
- Coverage   77.94%   77.76%   -0.18%     
==========================================
  Files         115      115              
  Lines       14684    14706      +22     
  Branches     1558     1563       +5     
==========================================
- Hits        11445    11436       -9     
- Misses       2871     2901      +30     
- Partials      368      369       +1     
Files Changed Coverage Δ
...ript/tests/function_libs/torch_lib/extra_opinfo.py 88.09% <0.00%> (-10.21%) ⬇️
...ipt/tests/function_libs/torch_lib/ops_test_data.py 95.89% <85.71%> (-0.28%) ⬇️
onnxscript/function_libs/torch_lib/ops/core.py 79.80% <100.00%> (+0.10%) ⬆️

@github-actions
Copy link

github-actions bot commented Sep 20, 2023

Test Results

         18 files  ±         0         18 suites  ±0   1h 13m 51s ⏱️ + 6m 21s
  10 842 tests +         1    8 216 ✔️  -        4      2 617 💤 +         4         8 ±0  1 🔥 +1 
170 728 runs  +16 438  38 651 ✔️ +3 831  130 720 💤 +12 606  1 356 ±0  1 🔥 +1 

For more details on these failures and errors, see this check.

Results for commit 1121b4d. ± Comparison against base commit b6c5e3b.

This pull request removes 4 and adds 5 tests. Note that renamed tests count towards both.
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__ops_aten_embedding_bag_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__ops_aten_embedding_bag_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__ops_aten_embedding_bag_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__ops_aten_embedding_bag_cpu_float32
onnxscript.function_libs.tools.torch_lib.deduce_type_constraints_test.TestDeduceTypeConstraints ‑ test_deduce_type_constraints_does_not_crash_for_onnx_function__aten_embedding_bag_2d_onnx
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__nn_functional_embedding_bag_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__nn_functional_embedding_bag_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__nn_functional_embedding_bag_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__nn_functional_embedding_bag_cpu_float32

♻️ This comment has been updated with latest results.

@justinchuby
Copy link
Collaborator

justinchuby commented Sep 20, 2023

We should still return four outputs and test using the aten op as explained in teams

@justinchuby justinchuby self-assigned this Sep 25, 2023
Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocking for now. Thanks!

@xiaowuhu
Copy link
Contributor Author

Not this guess, try another way.

@justinchuby justinchuby deleted the xiaowu/FixOp(embedding_bag_fun) branch January 27, 2025 18:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants