feat(atenlib): add ops (new_empty, new_empty_strided) #436
Merged
Commits (10):
- `173d8dd` Update core.py (xiaowuhu)
- `f50cbdd` add new ops (xiaowuhu)
- `a09fc77` Update core.py (xiaowuhu)
- `1fad593` Merge branch 'main' into xiaowu/addOps(0214) (xiaowuhu)
- `1366519` fix comment (xiaowuhu)
- `9e73133` Update core.py (xiaowuhu)
- `d555ff1` Update core.py (xiaowuhu)
- `01fd43f` Update core.py (xiaowuhu)
- `eaf6ce3` Update core.py (xiaowuhu)
- `0c98ee8` Merge branch 'main' into xiaowu/addOps(0214) (xiaowuhu)
Changes to `core.py`:

```diff
@@ -3930,16 +3930,36 @@ def aten_negative(self: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_new_empty(self: TensorType, size: INT64) -> TensorType:
+@torch_op("aten::new_empty")
+def aten_new_empty(self: TTensor, size: INT64, dtype: int = -1) -> TTensor:
     # new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-    raise NotImplementedError()
+    # using zero to simulate empty array
+    zero = op.Constant(value_float=0.0)
+    result = op.Expand(zero, size)
+    if dtype == -1:
+        result = op.CastLike(result, self)
+    else:
+        result = op.Cast(result, to=dtype)
+    return result


-def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> TensorType:
+@torch_op("aten::new_empty_strided")
+def aten_new_empty_strided(
+    self: TTensor,
+    size: INT64,
+    stride: INT64,  # pylint: disable=unused-argument
+    dtype: int = -1,
+) -> TTensor:
     # new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-    raise NotImplementedError()
+    # using zero to simulate empty array
+    zero = op.ConstantOfShape(size)
+    if dtype == -1:
+        result = op.CastLike(zero, self)
+    else:
+        result = op.Cast(zero, to=dtype)
+    return result


 @torch_op("aten::new_full")
```

Collaborator comment on the `@torch_op("aten::new_empty_strided")` line:

> potentially trace only
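The "using zero to simulate empty" trick above can be illustrated outside onnxscript with a plain-Python stand-in (a sketch only: the hypothetical `expand_zero` mirrors `op.Expand` over a scalar `op.Constant`, and dtype handling is elided). Because `aten::new_empty` leaves the tensor contents unspecified, returning zeros of the requested shape is semantically valid, but it makes exact-value comparison against PyTorch's output unreliable, which is why the tests below mark these ops as expected failures.

```python
def expand_zero(size):
    """Sketch of op.Expand over a scalar zero constant: build a nested
    list of 0.0 values with the requested shape."""
    if not size:
        return 0.0  # scalar case: the constant itself
    return [expand_zero(size[1:]) for _ in range(size[0])]

# "new_empty" only guarantees shape and dtype, not contents, so zeros
# are a legal (if arbitrary) choice of values.
print(expand_zero([2, 3]))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```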
Changes to the op correctness test file:

```diff
@@ -302,6 +302,8 @@ def _where_input_wrangler(
     "mul": core_ops.aten_mul,
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
+    "new_empty": core_ops.aten_new_empty,
+    "new_empty_strided": core_ops.aten_new_empty_strided,
     "new_full": core_ops.aten_new_full,
     "nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
     "nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,

@@ -396,6 +398,8 @@ def _where_input_wrangler(
     skip("empty_like", reason="Using zeros_like to simulate empty_like"),
     xfail("logcumsumexp", reason="naive implementation not numerically stable"),
     xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
+    xfail("new_empty", reason="Using zeros to simulate empty"),
+    xfail("new_empty_strided", reason="Using zeros to simulate empty"),
     xfail(
         "nn.functional.upsample_nearest2d",
         reason="enable when ONNX Runtime does support opset18",
```

Collaborator comment on the `xfail("new_empty", ...)` line:

> @xiaowuhu I also realized this can succeed unexpectedly. In that case a skip may be a better choice than xfail.
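The skip-versus-xfail distinction the reviewer raises can be sketched with a hypothetical minimal harness (not the project's actual test runner): a skipped test is never executed, so its outcome is always stable, whereas an xfailed test still runs, and an unexpected pass is surfaced as "xpassed", which strict runners treat as an error. A nondeterministic comparison (zeros occasionally matching uninitialized memory) therefore flakes under xfail but stays quiet under skip.

```python
def run_with_expectation(test, expectation):
    """Sketch of skip vs. xfail semantics (hypothetical harness)."""
    if expectation == "skip":
        return "skipped"  # never executed: outcome cannot flake
    try:
        test()
    except AssertionError:
        return "xfailed" if expectation == "xfail" else "failed"
    return "xpassed" if expectation == "xfail" else "passed"

# A zeros-based new_empty can match torch's uninitialized memory by
# accident, so under xfail the outcome flips between "xfailed" and
# "xpassed"; under skip it is always "skipped".
print(run_with_expectation(lambda: None, "xfail"))  # xpassed
print(run_with_expectation(lambda: None, "skip"))   # skipped
```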
Review conversation on the dtype branching:

> dtype needs to be the same for both branches. This op may need to be trace only.

> I don't get your point.

> ONNX requires that both branches of an `If` node produce the same type. Since we are casting here, the graph violates this constraint. I think Rama is looking at potential solutions, but for now we will need to mark the function trace-only, or else the graph would be invalid.

> More info in onnx/onnx#4872
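The trace-only workaround the reviewers discuss can be sketched with a hypothetical mini-tracer (`trace_new_empty` below is illustrative, not onnxscript's API): in a trace-only function the Python `if dtype == -1` is evaluated eagerly at export time with the concrete attribute value, so exactly one of the two casts is recorded and no ONNX `If` node, with its matching-branch-type constraint, is ever emitted.

```python
def trace_new_empty(graph, self_dtype, size, dtype=-1):
    """Hypothetical mini-tracer: the Python conditional runs at trace
    time, so only one Cast/CastLike op lands in the recorded graph and
    no If node (whose branches ONNX requires to agree in type) appears."""
    graph.append(("Constant", 0.0))
    graph.append(("Expand", "zero", size))
    if dtype == -1:  # decided now, at trace time, not inside the graph
        graph.append(("CastLike", self_dtype))
    else:
        graph.append(("Cast", dtype))
    return graph

g = trace_new_empty([], self_dtype="float32", size=[2, 3])
print([op for op, *_ in g])  # ['Constant', 'Expand', 'CastLike'] -- no 'If'
```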