[torchlib] _embedding_bag #1021

Closed
justinchuby opened this issue Aug 22, 2023 · 1 comment · Fixed by #1022
Labels
module: torchlib Related to the torch/aten function lib in development

Comments

@justinchuby
Collaborator

justinchuby commented Aug 22, 2023

We need to implement _embedding_bag for these two signatures, following #909:

  • func: embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
  • func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
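For reference, the forward semantics both signatures need to reproduce can be sketched in plain Python. This is a minimal, non-authoritative sketch based on our reading of the PyTorch nn.EmbeddingBag documentation, not the torchlib implementation. embedding_bag_reference is a hypothetical name; it computes only the first of the four outputs, assumes 1D indices and mode values 0/1/2 meaning sum/mean/max, and ignores padding_idx, scale_grad_by_freq and sparse (the last two only affect gradients).

```python
from typing import List, Optional, Sequence


def embedding_bag_reference(
    weight: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
    mode: int = 0,
    per_sample_weights: Optional[Sequence[float]] = None,
    include_last_offset: bool = False,
) -> List[List[float]]:
    """Hypothetical reference: returns only the first output (one row per bag)."""
    if mode not in (0, 1, 2):
        raise ValueError(f"unsupported mode: {mode}")
    if per_sample_weights is not None and mode != 0:
        raise ValueError("per_sample_weights is only supported with mode=0 (sum)")
    dim = len(weight[0])
    # Bag b covers indices[starts[b]:ends[b]]. With include_last_offset=True the
    # final offset marks the end of the last bag rather than the start of a new one.
    if include_last_offset:
        starts, ends = list(offsets[:-1]), list(offsets[1:])
    else:
        starts, ends = list(offsets), list(offsets[1:]) + [len(indices)]

    result: List[List[float]] = []
    for start, end in zip(starts, ends):
        rows = []
        for j in range(start, end):
            scale = 1.0 if per_sample_weights is None else per_sample_weights[j]
            rows.append([scale * x for x in weight[indices[j]]])
        if not rows:
            # Empty bag (e.g. from repeated offsets): all-zero row in every mode.
            result.append([0.0] * dim)
        elif mode == 0:  # sum
            result.append([sum(col) for col in zip(*rows)])
        elif mode == 1:  # mean
            result.append([sum(col) / len(rows) for col in zip(*rows)])
        else:  # mode == 2, max
            result.append([max(col) for col in zip(*rows)])
    return result
```

With offsets [0, 2, 2], the middle bag is empty and comes out as a zero row, which is the kind of input an ONNX implementation also has to handle.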

Notably, we need to support padding_idx=-1.
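Because the ATen-level ops take a plain int, -1 is the value that reaches them when no padding index is set. As far as we can tell, torch.nn.functional.embedding_bag passes -1 for padding_idx=None and converts other negative values to weight.size(0) + padding_idx before calling the op. Entries equal to padding_idx are excluded from the reduction (and, as we read the docs, from the count in mean mode), so one way to handle padding is to drop those entries and re-base the offsets before reducing. This is a sketch under those assumptions, with include_last_offset=False; drop_padding is a hypothetical helper, not a torchlib function.

```python
from typing import List, Sequence, Tuple


def drop_padding(
    indices: Sequence[int], offsets: Sequence[int], padding_idx: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: remove padding entries and re-base the offsets.

    Assumes include_last_offset=False and the ATen-level convention that
    padding_idx == -1 means "no padding index".
    """
    if padding_idx == -1:
        return list(indices), list(offsets)
    kept: List[int] = []
    new_offsets: List[int] = []
    bounds = list(offsets) + [len(indices)]
    for start, end in zip(bounds[:-1], bounds[1:]):
        # Each bag now starts wherever the filtered indices currently end.
        new_offsets.append(len(kept))
        kept.extend(i for i in indices[start:end] if i != padding_idx)
    return kept, new_offsets
```

A bag that contains only padding becomes empty after filtering, so it reduces to a zero row.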

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Aug 22, 2023
@xiaowuhu
Contributor

#1022

justinchuby added a commit that referenced this issue Aug 29, 2023
1. It works for common test cases, e.g. when offsets is something like
[0,2,3], but fails for something like [0,2,2,4]. We suspect this is due to
a bug in op.While.
2. Only the shapes of the last 3 outputs are checked.
3. The test cases do not contain 2D indices; we will wait to see whether they are necessary.
4. Loop/scan constraint checking needs to be disabled.

Fixes #1021

---------

Co-authored-by: Justin Chu <[email protected]>
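Repeated offsets such as [0,2,2,4] are legal and describe an empty bag. One loop-free way to express the bag structure, which might sidestep op.While, is to compute a bag id for every index position and reduce with a scatter-add (for example ONNX ScatterND with reduction="add"); we have not checked whether #1022 takes this route. The sketch below is plain Python with hypothetical helper names and assumes include_last_offset=False. It also shows how 2D indices of shape (B, N) could be turned into the 1D-plus-offsets form, assuming, as in torch.nn.functional.embedding_bag, that each row is treated as one bag of fixed length N.

```python
from bisect import bisect_right
from typing import List, Sequence, Tuple


def bag_ids_from_offsets(offsets: Sequence[int], num_indices: int) -> List[int]:
    """Map each position in indices to its bag (offsets non-decreasing, starting at 0).

    The bag of position j is the last b with offsets[b] <= j, so repeated offsets
    simply yield bags that no position maps to (empty bags).
    """
    return [bisect_right(offsets, j) - 1 for j in range(num_indices)]


def segment_sum(
    rows: Sequence[Sequence[float]], bag_ids: Sequence[int], num_bags: int
) -> List[List[float]]:
    """Scatter-add each gathered embedding row into its bag; empty bags stay zero."""
    dim = len(rows[0]) if rows else 0
    out = [[0.0] * dim for _ in range(num_bags)]
    for row, bag in zip(rows, bag_ids):
        for k, value in enumerate(row):
            out[bag][k] += value
    return out


def flatten_2d_indices(indices_2d: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Treat each row of a (B, N) indices input as one bag of length N."""
    n = len(indices_2d[0]) if indices_2d else 0
    flat = [i for row in indices_2d for i in row]
    offsets = [b * n for b in range(len(indices_2d))]
    return flat, offsets
```

With offsets [0,2,2,4] and six indices, bag 1 receives no positions and stays zero, the same offset pattern the commit message reports as failing.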