@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
261261 if isinstance (input , (tuple , list )):
262262 if len (input ) == 0 :
263263 return np .array ((), dtype = np .int64 )
264- if isinstance (input [0 ], torch .Tensor ):
264+ if any (isinstance (x , torch .Tensor ) for x in input ):
265+ # The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
265266 return [convert_tensor_to_numpy (x ) for x in input ]
266267 if isinstance (input [0 ], bool ):
267268 return np .array (input , dtype = np .bool_ )
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:
276277
277278
278279def convert_kwargs_for_onnx (kwargs : dict [str , Any ]) -> dict [str , Any ]:
279- """Converts kwargs to be compatible with ONNX Runtime.
280-
281- ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
282- """
280+ """Converts kwargs to be compatible with ONNX Runtime."""
283281 new_kwargs = {}
284282 for key , value in kwargs .items ():
285283 if key == "device" :
@@ -515,10 +513,9 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
515513 # Make sure the model is valid
516514 try :
517515 onnx .checker .check_model (onnx_model , full_check = True )
518- except onnx .checker .ValidationError as e :
516+ except ( onnx .checker .ValidationError , onnx . shape_inference . InferenceError ) as e :
519517 raise AssertionError (
520- f"ONNX model is invalid: { e } . "
521- f"Model:\n "
518+ f"ONNX model is invalid, Model:\n "
522519 f"{ onnxscript .proto2text (onnx_model )} "
523520 ) from e
524521
0 commit comments