Skip to content

Commit e55a1c6

Browse files
authored
[IR] Fix sequence handling in tensor function (#2252)
(Copilot) Fix bug in `tensor()` function to handle empty sequences and require `dtype` when value is an empty sequence. * Add a check to ensure the sequence is non-empty before performing type checks in the `tensor()` function in `onnxscript/ir/_convenience/_constructors.py`. * Raise a `ValueError` if `dtype` is `None` and `value` is an empty sequence in the `tensor()` function. * Update the `tensor()` function to handle the case when a sequence is empty explicitly. * Add a test case to check if `tensor()` raises a `ValueError` when `dtype` is `None` and `value` is an empty sequence in `onnxscript/ir/_convenience/_constructors_test.py`. * Add a test case to check if `tensor()` handles the case when a sequence is empty explicitly. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/microsoft/onnxscript/pull/2252?shareId=2ab5ada5-c6bd-4bc8-be2d-e9357dcbaa7b).
1 parent a78bf43 commit e55a1c6

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

onnxscript/ir/_convenience/_constructors.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,24 +92,31 @@ def tensor(
9292
return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type]
9393
elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
9494
return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
95-
# Plain Python object
95+
96+
# Plain (numerical) Python object. Determine the numpy dtype and use np.array to construct the tensor
9697
if dtype is not None:
98+
if not isinstance(dtype, _enums.DataType):
99+
raise TypeError(f"dtype must be an instance of DataType. dtype={dtype}")
97100
numpy_dtype = dtype.numpy()
101+
elif isinstance(value, Sequence) and not value:
102+
raise ValueError("dtype must be specified when value is an empty sequence.")
98103
elif isinstance(value, int) and not isinstance(value, bool):
99104
# Specify int64 for ints because on Windows this may be int32
100105
numpy_dtype = np.dtype(np.int64)
101106
elif isinstance(value, float):
102107
# If the value is a single float, we use np.float32 as the default dtype
103108
numpy_dtype = np.dtype(np.float32)
104-
elif isinstance(value, Sequence) and all(
105-
(isinstance(elem, int) and not isinstance(value, bool)) for elem in value
106-
):
107-
numpy_dtype = np.dtype(np.int64)
108-
elif isinstance(value, Sequence) and all(isinstance(elem, float) for elem in value):
109-
# If the value is a sequence of floats, we use np.float32 as the default dtype
110-
numpy_dtype = np.dtype(np.float32)
109+
elif isinstance(value, Sequence) and value:
110+
if all((isinstance(elem, int) and not isinstance(elem, bool)) for elem in value):
111+
numpy_dtype = np.dtype(np.int64)
112+
elif all(isinstance(elem, float) for elem in value):
113+
# If the value is a sequence of floats, we use np.float32 as the default dtype
114+
numpy_dtype = np.dtype(np.float32)
115+
else:
116+
numpy_dtype = None
111117
else:
112118
numpy_dtype = None
119+
113120
array = np.array(value, dtype=numpy_dtype)
114121

115122
# Handle string tensors by encoding them

onnxscript/ir/_convenience/_constructors_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88

9+
from onnxscript import ir
910
from onnxscript.ir._convenience import _constructors
1011

1112

@@ -17,6 +18,14 @@ def test_tensor_accepts_torch_tensor(self):
1718
tensor = _constructors.tensor(torch_tensor)
1819
np.testing.assert_array_equal(tensor, torch_tensor.numpy())
1920

21+
def test_tensor_raises_value_error_for_empty_sequence_without_dtype(self):
22+
with self.assertRaises(ValueError):
23+
_constructors.tensor([])
24+
25+
def test_tensor_handles_empty_sequence_with_dtype(self):
26+
tensor = _constructors.tensor([], dtype=ir.DataType.FLOAT)
27+
np.testing.assert_array_equal(tensor.numpy(), np.array([], dtype=np.float32))
28+
2029

2130
if __name__ == "__main__":
2231
unittest.main()

0 commit comments

Comments
 (0)