
fix(atenlib): test split and support Sequence input #528


Merged
merged 5 commits into main · Mar 16, 2023

Conversation

@justinchuby (Collaborator) commented Mar 16, 2023

This PR adjusts the test code to handle the output of torch.split and adds support for sequences as input. Note that this is different from the exporter's concat for symints: at the point when the inputs have been processed and we still get a list of Tensors as a single argument, we know it should be an ONNX Sequence tensor in the graph.

We also need to keep the addition here, and not in the exporter, to match the behavior of onnxscript eager mode: we accept a list of tensors as input when the op takes a Sequence.
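The list-to-Sequence handling described above can be sketched as follows. This is a simplified illustration, not the real TorchScriptGraph API from graph_building.py: `Graph`, `GraphNode`, and `adapt_input` are hypothetical stand-ins showing how a Python list argument gets folded into a single SequenceConstruct value.

```python
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class GraphNode:
    # Hypothetical stand-in for a node in the torch graph.
    op: str
    inputs: List[str]
    output: str

@dataclass
class Graph:
    nodes: List[GraphNode] = field(default_factory=list)

    def add_op_call(self, op: str, inputs: List[str]) -> str:
        output = f"value_{len(self.nodes)}"
        self.nodes.append(GraphNode(op, list(inputs), output))
        return output

def adapt_input(graph: Graph, arg: Any) -> str:
    """If the caller passes a Python list of tensor values where the op
    signature expects an ONNX Sequence, fold them into one
    SequenceConstruct value so the graph sees a single Sequence input."""
    if isinstance(arg, (list, tuple)):
        return graph.add_op_call("onnx::SequenceConstruct", list(arg))
    return arg  # already a single graph value

graph = Graph()
seq_value = adapt_input(graph, ["tensor_a", "tensor_b", "tensor_c"])
print(graph.nodes[0].op)  # onnx::SequenceConstruct
print(seq_value)          # value_0
```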

This was referenced Mar 16, 2023
codecov bot commented Mar 16, 2023

Codecov Report

Merging #528 (241f44f) into main (ac77a19) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main     #528   +/-   ##
=======================================
  Coverage   73.70%   73.71%           
=======================================
  Files         109      109           
  Lines       10863    10873   +10     
  Branches     1128     1131    +3     
=======================================
+ Hits         8007     8015    +8     
- Misses       2550     2553    +3     
+ Partials      306      305    -1     
Impacted Files | Coverage Δ
...xscript/function_libs/torch_aten/graph_building.py | 79.11% <100.00%> (-0.56%) ⬇️
...s/function_libs/torch_aten/ops_correctness_test.py | 90.32% <100.00%> (+0.20%) ⬆️


@xiaowuhu (Contributor) left a comment


Leaving a comment; others LGTM. Thanks!

@justinchuby (Collaborator, Author)

Should I make split trace-only instead? That seems suitable given its dynamic output nature, and then we would be able to support the symint. To me that seems OK.

@justinchuby (Collaborator, Author)

> Should I make split trace-only instead? That seems suitable given its dynamic output nature, and then we would be able to support the symint. To me that seems OK.

No, because when there are sizes the output number is always dynamic. I wonder what the traced graph from fx looks like? @BowenBao
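The dynamic-output-count concern above can be made concrete: how many values torch.split returns depends on its arguments (and, once sizes are symints, on runtime values), which is why a fixed-arity graph function is awkward. A pure-Python illustration of the count, not PyTorch code:

```python
import math

def num_split_outputs(dim_size, split_size_or_sizes):
    """How many chunks torch.split would produce along one dimension
    (illustrative only; mirrors torch.split's documented semantics)."""
    if isinstance(split_size_or_sizes, int):
        # A single chunk size: the count depends on the input's shape.
        return math.ceil(dim_size / split_size_or_sizes)
    # Explicit per-chunk sizes: the count is the length of the list.
    return len(split_size_or_sizes)

print(num_split_outputs(10, 3))       # 4  (chunks of 3, 3, 3, 1)
print(num_split_outputs(10, [2, 8]))  # 2
```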

@justinchuby (Collaborator, Author)

I will merge to unblock other PRs and create an issue for discussion.

# should be a Sequence input in ONNX.
input_sequence = _create_op_call_in_torch_graph(
    self._torch_graph,
    "onnx::SequenceConstruct",
Contributor:

I think it would make more sense to move this to the PyTorch side, like we did for op.Concat. We should process all the list-related stuff in the same place; in particular, fx.node values might not always be tensors.

Collaborator, Author:

I think this is different, because here we annotate the type as Sequence[Tensor], whereas for the Concat case we annotated it as INT64. We can chat more.

Collaborator, Author:

We should probably make sure everything is a tensor on the exporter side first, but I am open to suggestions.

Contributor:

I think I prefer moving it out to PyTorch. It would be more consistent in the function _retrieve_or_adapt_input_to_graph_set, in the sense that we intend to process all Python Sequences there, which would be obvious for users/developers. For example, for a sequence of nodes (symints) we concat it, but if it's all tensors we use SequenceConstruct, etc. We would be able to put all scenarios in the same place. On the other hand, in symbolic_fn annotations we specifically mean onnx_type.
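The dispatch proposed here can be sketched in a few lines. This is a hypothetical simplification: `StubGraph`, `adapt_sequence`, and the `element_kind` flag stand in for the real per-element inspection in _retrieve_or_adapt_input_to_graph_set, to show the symint-vs-tensor branching in one place.

```python
class StubGraph:
    # Hypothetical minimal graph used only for this sketch.
    def __init__(self):
        self.nodes = []

    def add_op_call(self, op, inputs):
        self.nodes.append((op, list(inputs)))
        return f"out_{len(self.nodes) - 1}"

def adapt_sequence(graph, values, element_kind):
    """Route a Python sequence to one graph op depending on its elements:
    symint lists are concatenated into a single INT64 tensor, while
    tensor lists become an ONNX Sequence value."""
    if element_kind == "symint":
        return graph.add_op_call("onnx::Concat", values)
    return graph.add_op_call("onnx::SequenceConstruct", values)

g = StubGraph()
adapt_sequence(g, ["s0", "s1"], "symint")
adapt_sequence(g, ["t0", "t1"], "tensor")
print([op for op, _ in g.nodes])  # ['onnx::Concat', 'onnx::SequenceConstruct']
```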

Collaborator, Author:

Well, the current type system we use shows that the functions take a (Python) Sequence. If we do SequenceConstruct outside, the input will not be a Sequence. However, we could also extend the return value of SequenceConstruct to be a proper Sequence, so that we can do what you suggested.
