Skip to content

Simplify model error message in test | test(torchlib) #882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@
# TODO(justinchuby): Build a context manager to handle source information.


def _rename_intermediate_value(name: str) -> str:
if name.isdigit():
return f"_val_{name}"
return name


def _rename_intermediate_constant(name: str) -> str:
if name.isdigit():
return f"_const_{name}"
return name


class TorchScriptTensor(onnxscript_tensor.Tensor):
"""A onnxscript tensor that wraps a torchscript Value."""

Expand Down Expand Up @@ -454,6 +466,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
self._torch_graph, "prim::Constant", inputs=(), attributes={}
)[0]
value.setType(torch.OptionalType.ofTensor())
value.setDebugName(_rename_intermediate_constant(value.debugName()))
return value

if isinstance(constant, bool):
Expand All @@ -475,12 +488,14 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
raise TypeError(
f"Constant input '{constant}' of type '{type(constant)}' is not supported"
)
return _create_op_call_in_torch_graph(
value = _create_op_call_in_torch_graph(
self._torch_graph,
"onnx::Constant",
inputs=(),
attributes=dict(value=constant_tensor),
)[0]
value.setDebugName(_rename_intermediate_constant(value.debugName()))
return value

@runtime_typing.checked
def _add_torchscript_op_call(
Expand Down Expand Up @@ -524,9 +539,15 @@ def _add_torchscript_op_call(
attributes=onnx_attributes,
n_outputs=n_outputs,
)
if len(result) <= 1:
return TorchScriptTensor(result[0])
return tuple(TorchScriptTensor(v) for v in result)
assert result, "Expected at least one output from ONNX op call."
if len(result) == 1:
tensor = TorchScriptTensor(result[0])
tensor.name = _rename_intermediate_value(tensor.name)
return tensor
tensors = tuple(TorchScriptTensor(v) for v in result)
for tensor in tensors:
tensor.name = _rename_intermediate_value(tensor.name)
return tensors

@runtime_typing.checked
def fetch_function_proto_dict(
Expand Down
14 changes: 1 addition & 13 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import unittest
from typing import Any, Callable, Optional, Sequence, Tuple
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
import onnx
Expand Down Expand Up @@ -72,18 +72,6 @@ def _should_skip_xfail_test_sample(
return None, None


def _split_function_and_wrangler(
onnx_function_and_wrangler: Callable[..., Any]
| tuple[Callable[..., Any], Callable[..., Any]]
) -> tuple[Callable[..., Any], Callable[..., Any] | None]:
"""Splits a function with an optional input wrangler into a function and an input wrangler."""
if isinstance(onnx_function_and_wrangler, tuple):
return onnx_function_and_wrangler

assert callable(onnx_function_and_wrangler)
return onnx_function_and_wrangler, None


class TestFunctionValidity(unittest.TestCase):
def test_all_script_functions_are_onnx_functions(self):
for info in ops_test_data.TESTED_TORCHLIB_OPS:
Expand Down
4 changes: 1 addition & 3 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
onnx.checker.check_model(onnx_model, full_check=True)
except onnx.checker.ValidationError as e:
raise AssertionError(
f"ONNX model is invalid: {e}. "
f"Model:\n"
f"{onnxscript.proto2text(onnx_model)}"
f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
) from e

try:
Expand Down