-
Notifications
You must be signed in to change notification settings - Fork 107
Add a new op embedding. #306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
479a3d1
d01e591
df2f9ea
c30e0c8
f80920b
78f075d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1511,16 +1511,15 @@ def aten_einsum( | |||||
| raise NotImplementedError() | ||||||
|
|
||||||
|
|
||||||
| @torch_op("aten::embedding") | ||||||
| def aten_embedding( | ||||||
| weight: TensorType, | ||||||
| indices: TensorType, | ||||||
| padding_idx: int = -1, | ||||||
| scale_grad_by_freq: bool = False, | ||||||
| sparse: bool = False, | ||||||
| ) -> TensorType: | ||||||
| weight: TTensor, | ||||||
| indices: TTensor, | ||||||
| **kwargs,# pylint: disable=unused-argument | ||||||
|
||||||
| **kwargs,# pylint: disable=unused-argument | |
| **_, |
Don’t know if ** works in onnxscript, but if it does:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to change behavior based on padding_idx?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to current implementation, it will only impact Training. I think ONNX Script won't know if current export is under Training mode, so the warning about possible wrong training result should be thrown by exporter, not onnx-script.
Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should also support any training behaviors when possible.
Uh oh!
There was an error while loading. Please reload this page.