fix(atenlib): test split and support Sequence input #528
Conversation
Codecov Report
@@ Coverage Diff @@
## main #528 +/- ##
=======================================
Coverage 73.70% 73.71%
=======================================
Files 109 109
Lines 10863 10873 +10
Branches 1128 1131 +3
=======================================
+ Hits 8007 8015 +8
- Misses 2550 2553 +3
+ Partials 306 305 -1
Leaving one comment; the rest LGTM. Thanks!
Should I make split trace-only instead? That seems suitable given its dynamic output nature, and then we would be able to support the symint. To me that seems OK.
No, because when there is
I will merge to unblock other PRs and create an issue for discussion.
# should be a Sequence input in ONNX.
input_sequence = _create_op_call_in_torch_graph(
    self._torch_graph,
    "onnx::SequenceConstruct",
I think it would make more sense to move this to the PyTorch side, like we did for op.Concat. We should process all the list-related handling in the same place, especially since fx.node values might not always be tensors.
I think this is different because we annotate the type as Sequence[Tensor], whereas for the Concat case we annotated it as INT64. We can chat more.
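Roughly, the annotation difference being discussed looks like this. These are simplified, hypothetical signatures in the style of the atenlib definitions, not copied from the actual code:

```python
from typing import Sequence
from onnxscript.onnx_types import INT64, TensorType

# Annotated as an ONNX Sequence: a Python list of tensors should become
# a SequenceConstruct node in the graph.
def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType:
    ...

# Annotated as INT64: a Python list of sizes/symints is instead
# concatenated into a single 1-D tensor on the exporter side.
def aten_view(self: TensorType, size: INT64) -> TensorType:
    ...
```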
We should probably make sure everything is a tensor on the exporter side first, but I am open to suggestions.
I think I prefer moving it out to PyTorch. It would be more consistent to do it in the function _retrieve_or_adapt_input_to_graph_set, in the sense that we intend to process all Python Sequences there, which would make the behavior obvious for users/developers: for example, a sequence containing a node (symint) gets concatenated, but if it is all tensors we use SequenceConstruct, etc. We would be able to put all the scenarios in the same place. On the other hand, in the symbolic_fn annotations we specifically mean the ONNX type.
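A hypothetical sketch of that dispatch; `SymInt` and the `graph.op` interface below are stand-ins for the real exporter types, not the actual torch.onnx internals:

```python
from dataclasses import dataclass

@dataclass
class SymInt:
    """Placeholder for a symbolic-integer fx node (hypothetical)."""
    name: str

def adapt_sequence_input(graph, values):
    """Decide how a Python sequence argument enters the graph."""
    if not any(isinstance(v, SymInt) for v in values):
        # All tensors: build a proper ONNX Sequence.
        return graph.op("SequenceConstruct", *values)
    # Contains symints: fold everything into one 1-D INT64 tensor.
    scalars = [graph.op("Unsqueeze", v, axes=[0]) for v in values]
    return graph.op("Concat", *scalars, axis=0)
```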
Well, the current type system we use shows that the functions take a (Python) Sequence. If we do SequenceConstruct outside, the input will not be a Sequence. However, we could also extend the return value of SequenceConstruct to be a proper Sequence so that we can do what you suggested.
This PR adjusts the test code to handle the output of torch.split and adds support for Sequence inputs. Note that this is different from the exporter's concat for symints: at the point where the inputs have been processed and we still get a list of Tensors as a single argument, we know it should be an ONNX Sequence tensor in the graph. We also need to keep the addition here rather than in the exporter to match the behavior of onnxscript eager mode: we accept a list of tensors as input when the op takes a Sequence.
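For illustration, a sketch of the eager-mode convention being matched, assuming the standard opset module is imported from onnxscript; exact eager behavior depends on the evaluator backend:

```python
import numpy as np
from onnxscript import opset18 as op

# ConcatFromSequence declares a Sequence-typed input; under the eager
# convention described above, a plain Python list of tensors is
# accepted directly in that position.
out = op.ConcatFromSequence(
    [np.array([1.0, 2.0]), np.array([3.0])], axis=0
)
```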