Skip to content

Commit f707360

Browse files
committed
test
1 parent 2d0923a commit f707360

2 files changed

Lines changed: 5 additions & 10 deletions

File tree

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import math
1515
from typing import Any, Optional, Sequence, Tuple, Union
1616

17-
import numpy as np
18-
1917
from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph
2018
from onnxscript.function_libs.torch_lib.registration import torch_op
2119
from onnxscript.function_libs.torch_lib.tensor_typing import (

onnxscript/tests/function_libs/torch_lib/ops_test_common.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
261261
if isinstance(input, (tuple, list)):
262262
if len(input) == 0:
263263
return np.array((), dtype=np.int64)
264-
if isinstance(input[0], torch.Tensor):
264+
if any(isinstance(x, torch.Tensor) for x in input):
265+
# The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
265266
return [convert_tensor_to_numpy(x) for x in input]
266267
if isinstance(input[0], bool):
267268
return np.array(input, dtype=np.bool_)
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:
276277

277278

278279
def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
279-
"""Converts kwargs to be compatible with ONNX Runtime.
280-
281-
ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
282-
"""
280+
"""Converts kwargs to be compatible with ONNX Runtime."""
283281
new_kwargs = {}
284282
for key, value in kwargs.items():
285283
if key == "device":
@@ -515,10 +513,9 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
515513
# Make sure the model is valid
516514
try:
517515
onnx.checker.check_model(onnx_model, full_check=True)
518-
except onnx.checker.ValidationError as e:
516+
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
519517
raise AssertionError(
520-
f"ONNX model is invalid: {e}. "
521-
f"Model:\n"
518+
f"ONNX model is invalid, Model:\n"
522519
f"{onnxscript.proto2text(onnx_model)}"
523520
) from e
524521

0 commit comments

Comments
 (0)