|
3 | 3 | # Licensed under the MIT License.
|
4 | 4 | # --------------------------------------------------------------------------
|
5 | 5 |
|
6 |
| -from typing import TypeVar, Union |
| 6 | +from typing import Optional, Sequence, TypeVar, Union |
7 | 7 | import unittest
|
8 | 8 |
|
9 | 9 | import parameterized
|
|
12 | 12 | import onnxscript
|
13 | 13 | from onnxscript import script
|
14 | 14 | from onnxscript.onnx_opset import opset15 as op
|
15 |
| -from onnxscript.onnx_types import FLOAT |
| 15 | +from onnxscript import FLOAT, INT64 |
16 | 16 | from onnxscript.tests.common import testutils
|
17 | 17 |
|
18 | 18 | from onnxscript import type_annotation
|
@@ -93,28 +93,155 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]:
|
93 | 93 | bool_type_for_attribute, bool_type_for_attribute_txt
|
94 | 94 | )
|
95 | 95 |
|
96 |
| -_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", onnxscript.INT64, onnxscript.FLOAT) |
97 |
| -_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=onnxscript.INT64) |
98 |
| -_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[onnxscript.INT64, onnxscript.FLOAT]) |
| 96 | + |
| 97 | +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) |
| 98 | +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) |
| 99 | +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) |
99 | 100 |
|
100 | 101 |
|
101 | 102 | class UtilityFunctionsTest(unittest.TestCase):
|
102 | 103 | @parameterized.parameterized.expand(
|
103 | 104 | [
|
104 |
| - ("tensor_type", onnxscript.onnx_types.TensorType, type_annotation.ALL_TYPE_STRINGS), |
105 |
| - ("tensor_type", onnxscript.INT64, ["tensor(int64)"]), |
106 |
| - ("tensor_type_variadic_shape", onnxscript.INT64[...], ["tensor(int64)"]), |
107 |
| - ("tensor_type_shape", onnxscript.INT64[10], ["tensor(int64)"]), |
108 |
| - ("type_var_constraints", _TestTypeVarConstraints, ["tensor(int64)"]), |
| 105 | + ( |
| 106 | + "tensor_type_all", |
| 107 | + onnxscript.onnx_types.TensorType, |
| 108 | + type_annotation.ALL_TYPE_STRINGS, |
| 109 | + ), |
| 110 | + ("tensor_type", INT64, ["tensor(int64)"]), |
| 111 | + ("tensor_type_union", Union[INT64, FLOAT], ["tensor(int64)", "tensor(float)"]), |
| 112 | + ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), |
| 113 | + ("tensor_type_shape", INT64[10], ["tensor(int64)"]), |
| 114 | + ( |
| 115 | + "type_var_constraints", |
| 116 | + _TestTypeVarConstraints, |
| 117 | + ["tensor(int64)", "tensor(float)"], |
| 118 | + ), |
109 | 119 | ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]),
|
110 | 120 | ("type_bound_two", _TestTypeVarTwoBound, ["tensor(int64)", "tensor(float)"]),
|
| 121 | + ( |
| 122 | + "optional_tensor_type_all", |
| 123 | + Optional[onnxscript.onnx_types.TensorType], |
| 124 | + type_annotation.ALL_TYPE_STRINGS |
| 125 | + + [ |
| 126 | + f"optional({tensor_type})" |
| 127 | + for tensor_type in type_annotation.ALL_TYPE_STRINGS |
| 128 | + ], |
| 129 | + ), |
| 130 | + ( |
| 131 | + "optional_tensor_type", |
| 132 | + Optional[INT64], |
| 133 | + ["tensor(int64)", "optional(tensor(int64))"], |
| 134 | + ), |
| 135 | + ( |
| 136 | + "optional_tensor_type_union", |
| 137 | + Optional[Union[INT64, FLOAT]], |
| 138 | + [ |
| 139 | + "tensor(int64)", |
| 140 | + "tensor(float)", |
| 141 | + "optional(tensor(int64))", |
| 142 | + "optional(tensor(float))", |
| 143 | + ], |
| 144 | + ), |
| 145 | + ( |
| 146 | + "optional_tensor_type_variadic_shape", |
| 147 | + Optional[INT64[...]], |
| 148 | + ["tensor(int64)", "optional(tensor(int64))"], |
| 149 | + ), |
| 150 | + ( |
| 151 | + "optional_tensor_type_shape", |
| 152 | + Optional[INT64[10]], |
| 153 | + ["tensor(int64)", "optional(tensor(int64))"], |
| 154 | + ), |
| 155 | + ( |
| 156 | + "optional_type_var_constraints", |
| 157 | + Optional[_TestTypeVarConstraints], |
| 158 | + [ |
| 159 | + "tensor(int64)", |
| 160 | + "tensor(float)", |
| 161 | + "optional(tensor(int64))", |
| 162 | + "optional(tensor(float))", |
| 163 | + ], |
| 164 | + ), |
| 165 | + ( |
| 166 | + "optional_type_bound_one", |
| 167 | + Optional[_TestTypeVarOneBound], |
| 168 | + ["tensor(int64)", "optional(tensor(int64))"], |
| 169 | + ), |
| 170 | + ( |
| 171 | + "optional_type_bound_two", |
| 172 | + Optional[_TestTypeVarTwoBound], |
| 173 | + [ |
| 174 | + "tensor(int64)", |
| 175 | + "tensor(float)", |
| 176 | + "optional(tensor(int64))", |
| 177 | + "optional(tensor(float))", |
| 178 | + ], |
| 179 | + ), |
| 180 | + ( |
| 181 | + "sequence_type_all", |
| 182 | + Sequence[onnxscript.onnx_types.TensorType], |
| 183 | + [ |
| 184 | + f"sequence({tensor_type})" |
| 185 | + for tensor_type in type_annotation.ALL_TYPE_STRINGS |
| 186 | + ], |
| 187 | + ), |
| 188 | + ("sequence_type", Sequence[INT64], ["sequence(tensor(int64))"]), |
| 189 | + ( |
| 190 | + "union_sequence_type", |
| 191 | + Union[Sequence[INT64], Sequence[FLOAT]], |
| 192 | + ["sequence(tensor(int64))", "sequence(tensor(float))"], |
| 193 | + ), |
| 194 | + ( |
| 195 | + "sequence_type_variadic_shape", |
| 196 | + Sequence[INT64[...]], |
| 197 | + ["sequence(tensor(int64))"], |
| 198 | + ), |
| 199 | + ("sequence_type_shape", Sequence[INT64[10]], ["sequence(tensor(int64))"]), |
| 200 | + ( |
| 201 | + "sequence_type_var_constraints", |
| 202 | + Sequence[_TestTypeVarConstraints], |
| 203 | + ["sequence(tensor(int64))", "sequence(tensor(float))"], |
| 204 | + ), |
| 205 | + ( |
| 206 | + "sequence_type_bound_one", |
| 207 | + Sequence[_TestTypeVarOneBound], |
| 208 | + ["sequence(tensor(int64))"], |
| 209 | + ), |
| 210 | + ( |
| 211 | + "sequence_type_bound_two", |
| 212 | + Sequence[_TestTypeVarTwoBound], |
| 213 | + ["sequence(tensor(int64))", "sequence(tensor(float))"], |
| 214 | + ), |
| 215 | + ] |
| 216 | + ) |
| 217 | + def test_pytype_to_input_strings(self, _, pytype: Any, expected): |
| 218 | + self.assertEqual(type_annotation.pytype_to_input_strings(pytype), expected) |
| 219 | + |
| 220 | + @parameterized.parameterized.expand( |
| 221 | + [ |
| 222 | + ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), |
| 223 | + ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), |
| 224 | + ( |
| 225 | + "optional_type_var", |
| 226 | + Optional[_TestTypeVarOneBound], |
| 227 | + "Optional_TestTypeVarOneBound", |
| 228 | + ), |
| 229 | + ( |
| 230 | + "sequence_type_var", |
| 231 | + Sequence[_TestTypeVarOneBound], |
| 232 | + "Sequence_TestTypeVarOneBound", |
| 233 | + ), |
| 234 | + ("normal_type", INT64, "None"), |
| 235 | + ("union_type", Union[INT64, FLOAT], None), |
| 236 | + ("optional_type", Optional[INT64], None), |
| 237 | + ("sequence_type", Sequence[INT64], None), |
| 238 | + ("optional_sequence_type", Optional[Sequence[INT64]], None), |
| 239 | + ("optional_union_type", Optional[Union[INT64, FLOAT]], None), |
111 | 240 | ]
|
112 | 241 | )
|
113 |
| - def test_pytype_to_input_strings(self, _, pytype: Any, expected) |
114 |
| - pass |
| 242 | + def get_type_constraint_name(self, _: str, pytype, expected): |
| 243 | + self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) |
115 | 244 |
|
116 |
| - def get_type_constraint_name(self, _: str, typevar, expected): |
117 |
| - pass |
118 | 245 |
|
119 | 246 | if __name__ == "__main__":
|
120 | 247 | unittest.main()
|
0 commit comments