fix(atenlib): test split and support Sequence input #528

Merged · 5 commits · Mar 16, 2023
14 changes: 13 additions & 1 deletion onnxscript/function_libs/torch_aten/graph_building.py
@@ -397,7 +397,19 @@ def _add_torchscript_op_call(
        graph_inputs = []
        assert isinstance(unwrapped_inputs, Sequence)
        for input in unwrapped_inputs:
            if isinstance(input, Sequence) and all(
                isinstance(elem, torch.Value) for elem in input
            ):
                # If all elements in the Sequence are torch.Values, we know it
                # should be a Sequence input in ONNX.
                input_sequence = _create_op_call_in_torch_graph(
                    self._torch_graph,
                    "onnx::SequenceConstruct",
Contributor:
I think it would make more sense to move this to the PyTorch side, like we did for op.Concat. We should process all the list-related stuff in the same place, especially since fx.node values might not always be tensors.

Collaborator (Author):
I think this is different, because here we annotate the type as Sequence[Tensor], whereas for the Concat case we annotated it as INT64. We can chat more.

Collaborator (Author):
We should probably make sure everything is a tensor on the exporter side first, but I am open to suggestions.

Contributor:
I think I prefer moving it out to PyTorch. It would be more consistent to handle this in the function _retrieve_or_adapt_input_to_graph_set, in the sense that we intend to process every Python Sequence there, which would make the behavior obvious to users/developers: for example, for a sequence with node(symint) we concat it, but if it is all tensors we use SequenceConstruct, and so on. We would be able to put all scenarios in the same place. On the other hand, in symbolic_fn annotations we specifically mean the onnx_type.

Collaborator (Author):
Well, the current type system we use shows that the functions take a (Python) Sequence. If we do SequenceConstruct outside, the input will not be a Sequence. However, we could also extend the return value of SequenceConstruct to be a proper Sequence, so that we can do what you suggested.

                    inputs=input,
                    attributes={},
                )[0]
                graph_inputs.append(input_sequence)
            elif not isinstance(input, torch.Value):
                graph_inputs.append(self._add_constant_to_graph(input))
            else:
                graph_inputs.append(input)
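The new branch above can be sketched in isolation. This is a hypothetical, stdlib-only mock-up of the dispatch rule — `FakeValue`, `FakeGraph`, and `dispatch_inputs` are illustrative stand-ins, not the real torch or onnxscript APIs: a sequence whose elements are all graph values becomes a `SequenceConstruct` call, any other non-value becomes a constant, and existing values pass through unchanged.

```python
from collections.abc import Sequence


class FakeValue:
    """Stand-in for a torch.Value already present in the graph."""


class FakeGraph:
    """Records op calls so we can inspect what the dispatch produced."""

    def __init__(self):
        self.ops = []  # (op_name, inputs) pairs in call order

    def add_op(self, op_name, inputs):
        self.ops.append((op_name, list(inputs)))
        return FakeValue()


def dispatch_inputs(graph, unwrapped_inputs):
    """Route each input the way the patched loop does."""
    graph_inputs = []
    for inp in unwrapped_inputs:
        if (
            isinstance(inp, Sequence)
            and not isinstance(inp, (str, bytes))
            and inp
            and all(isinstance(elem, FakeValue) for elem in inp)
        ):
            # All elements are graph values: build a Sequence input.
            graph_inputs.append(graph.add_op("onnx::SequenceConstruct", inp))
        elif not isinstance(inp, FakeValue):
            # Plain Python constant: materialize it in the graph.
            graph_inputs.append(graph.add_op("onnx::Constant", [inp]))
        else:
            # Already a graph value: pass it through unchanged.
            graph_inputs.append(inp)
    return graph_inputs


graph = FakeGraph()
v1, v2 = FakeValue(), FakeValue()
outs = dispatch_inputs(graph, [[v1, v2], 3.14, v1])
print([name for name, _ in graph.ops])  # ['onnx::SequenceConstruct', 'onnx::Constant']
```

The empty-sequence guard (`and inp`) is an extra assumption here; the real loop relies on `all(...)` alone, which is vacuously true for an empty list.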
43 changes: 17 additions & 26 deletions onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -557,11 +557,6 @@ def _where_input_wrangler(
reason="fixme: ORT shape inference error",
test_class_name="TestOutputConsistency_FullGraph",
),
xfail(
"cat",
reason="fixme: TorchScriptEvaluator does not support TensorSequence. Enable after #484",
test_class_name="TestOutputConsistency_FullGraph",
),
xfail(
"chunk", reason="fixme: ORT error", test_class_name="TestOutputConsistency_FullGraph"
),
@@ -624,25 +619,6 @@ def _where_input_wrangler(
xfail(
"round", variant_name="decimals_neg_3", reason="The op does not support decimals yet"
),
xfail(
"split",
reason="fixme: split produces a Sequence type but is set incorrectly in this test",
test_class_name="TestOutputConsistency_FullGraph",
),
xfail(
"split",
variant_name="list_args",
reason="fixme: split produces a Sequence type but is set incorrectly in this test",
test_class_name="TestOutputConsistency_FullGraph",
),
xfail(
"split_with_sizes",
reason="fixme: split produces a Sequence type but is set incorrectly in this test",
test_class_name="TestOutputConsistency_FullGraph",
),
xfail(
"stack", reason="enable after #484", test_class_name="TestOutputConsistency_FullGraph"
),
xfail(
"t",
reason="ORT Graph attribute inferencing failed on rank-1 input",
@@ -1013,8 +989,14 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
    symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs)
    if not isinstance(symbolic_outputs, Sequence):
        symbolic_outputs = (symbolic_outputs,)

    # We need to set the size of the output tensors for the ONNX model to be valid
    for output, symbolic_output in zip(outputs, symbolic_outputs):
        if isinstance(output, Sequence):
            # The output is a sequence; set its type to a ListType of tensors.
            symbolic_output.dtype = output[0].dtype
            symbolic_output.symbolic_value().setType(torch.ListType.ofTensors())
            continue
        output = (
            output
            if isinstance(output, torch.Tensor)
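The output-typing rule in the hunk above can be sketched without torch. `StubSymbolicOutput` and the dict-based tensor stand-ins below are hypothetical, not the real torch/onnxscript types: a Sequence output is typed once as a list of tensors, borrowing the dtype of its first element, while an ordinary tensor output is typed individually.

```python
class StubSymbolicOutput:
    """Illustrative stand-in for a traced symbolic output value."""

    def __init__(self):
        self.dtype = None
        self.kind = None


def set_output_types(outputs, symbolic_outputs):
    """Mirror of the loop above: one symbolic value represents a whole
    sequence, so it gets a single list-of-tensors type."""
    for output, symbolic_output in zip(outputs, symbolic_outputs):
        if isinstance(output, (list, tuple)):
            # The whole sequence shares one dtype and one ListType-like tag.
            symbolic_output.dtype = output[0]["dtype"]
            symbolic_output.kind = "ListType[Tensor]"
            continue
        symbolic_output.dtype = output["dtype"]
        symbolic_output.kind = "Tensor"


syms = [StubSymbolicOutput(), StubSymbolicOutput()]
outputs = [
    [{"dtype": "float32"}, {"dtype": "float32"}],  # e.g. what split returns
    {"dtype": "int64"},
]
set_output_types(outputs, syms)
print(syms[0].kind)  # ListType[Tensor]
```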
@@ -1148,11 +1130,20 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
        input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
        torch_output = op(*inputs, **cpu_sample.kwargs)

        flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
        if op.name.startswith("split"):
            # Hack for handling split: split returns a Sequence that should be
            # treated as a single value, so we wrap it into a tuple.
            # TODO(justinchuby): Find a more general solution
            flattened_torch_outputs = (flattened_torch_outputs,)

        function_output = self.function_executor(flattened_torch_outputs)(
            onnx_function, input_onnx, kwargs_onnx
        )
        # Finally, re-flatten everything for comparison.
        # TODO: add pytree structure comparison.
        flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
        flattened_function_outputs, _ = pytree.tree_flatten(function_output)

        assert flattened_torch_outputs
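The split hack above can be illustrated without torch. In this hypothetical sketch, strings stand in for tensors and `tree_flatten` is a tiny stand-in for `pytree.tree_flatten`: flattening split's output yields one leaf per chunk, but the traced ONNX function returns a single Sequence value, so the expected outputs must be collapsed into one entry before being handed to the executor.

```python
def tree_flatten(obj):
    """Tiny stand-in for pytree.tree_flatten over nested lists/tuples."""
    if isinstance(obj, (list, tuple)):
        leaves = []
        for item in obj:
            leaves.extend(tree_flatten(item))
        return leaves
    return [obj]


torch_output = ("chunk0", "chunk1", "chunk2")  # split returns a tuple of chunks
flattened = tree_flatten(torch_output)
assert len(flattened) == 3  # three separate expected outputs...

# ...but the ONNX translation of split produces ONE Sequence value,
# so the test wraps all chunks into a single expected output:
expected_outputs = (flattened,)
assert len(expected_outputs) == 1
```

This mirrors why the wrap is special-cased to ops named `split*` only: for other ops, each flattened leaf corresponds to one ONNX output.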