3
3
# Licensed under the MIT License.
4
4
# --------------------------------------------------------------------------
5
5
6
- from typing import Optional , Sequence , TypeVar , Union
7
6
import unittest
7
+ from typing import Any , Optional , Sequence , TypeVar , Union
8
8
9
9
import parameterized
10
10
11
- import onnxscript .testing
12
11
import onnxscript
13
- from onnxscript import script
12
+ import onnxscript .testing
13
+ from onnxscript import FLOAT , INT64 , script , type_annotation
14
14
from onnxscript .onnx_opset import opset15 as op
15
- from onnxscript import FLOAT , INT64
16
15
from onnxscript .tests .common import testutils
17
16
18
- from onnxscript import type_annotation
19
-
20
17
21
18
class TypeAnnotationTest (testutils .TestBase ):
22
19
def test_type_annotation (self ):
@@ -99,7 +96,7 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]:
99
96
_TestTypeVarTwoBound = TypeVar ("_TestTypeVarTwoBound" , bound = Union [INT64 , FLOAT ])
100
97
101
98
102
- class UtilityFunctionsTest (unittest .TestCase ):
99
+ class TypeConversionFunctionsTest (unittest .TestCase ):
103
100
@parameterized .parameterized .expand (
104
101
[
105
102
(
@@ -121,10 +118,12 @@ class UtilityFunctionsTest(unittest.TestCase):
121
118
(
122
119
"optional_tensor_type_all" ,
123
120
Optional [onnxscript .onnx_types .TensorType ],
124
- type_annotation .ALL_TYPE_STRINGS
125
- + [
126
- f"optional({ tensor_type } )"
127
- for tensor_type in type_annotation .ALL_TYPE_STRINGS
121
+ [
122
+ * type_annotation .ALL_TYPE_STRINGS ,
123
+ * [
124
+ f"optional({ tensor_type } )"
125
+ for tensor_type in type_annotation .ALL_TYPE_STRINGS
126
+ ],
128
127
],
129
128
),
130
129
(
@@ -214,7 +213,7 @@ class UtilityFunctionsTest(unittest.TestCase):
214
213
),
215
214
]
216
215
)
217
- def test_pytype_to_input_strings (self , _ , pytype : Any , expected ):
216
+ def test_pytype_to_input_strings (self , _ , pytype : Any , expected : list [ str ] ):
218
217
self .assertEqual (type_annotation .pytype_to_input_strings (pytype ), expected )
219
218
220
219
@parameterized .parameterized .expand (
@@ -231,15 +230,15 @@ def test_pytype_to_input_strings(self, _, pytype: Any, expected):
231
230
Sequence [_TestTypeVarOneBound ],
232
231
"Sequence_TestTypeVarOneBound" ,
233
232
),
234
- ("normal_type" , INT64 , " None" ),
233
+ ("normal_type" , INT64 , None ),
235
234
("union_type" , Union [INT64 , FLOAT ], None ),
236
235
("optional_type" , Optional [INT64 ], None ),
237
236
("sequence_type" , Sequence [INT64 ], None ),
238
237
("optional_sequence_type" , Optional [Sequence [INT64 ]], None ),
239
238
("optional_union_type" , Optional [Union [INT64 , FLOAT ]], None ),
240
239
]
241
240
)
242
- def get_type_constraint_name (self , _ : str , pytype , expected ):
241
+ def get_type_constraint_name (self , _ : str , pytype : Any , expected : Optional [ str ] ):
243
242
self .assertEqual (type_annotation .get_type_constraint_name (pytype ), expected )
244
243
245
244
0 commit comments