|
3 | 3 | # Licensed under the MIT License.
|
4 | 4 | # --------------------------------------------------------------------------
|
5 | 5 |
|
| 6 | +from typing import TypeVar, Union |
6 | 7 | import unittest
|
7 | 8 |
|
| 9 | +import parameterized |
| 10 | + |
8 | 11 | import onnxscript.testing
|
| 12 | +import onnxscript |
9 | 13 | from onnxscript import script
|
10 | 14 | from onnxscript.onnx_opset import opset15 as op
|
11 | 15 | from onnxscript.onnx_types import FLOAT
|
12 | 16 | from onnxscript.tests.common import testutils
|
13 | 17 |
|
| 18 | +from onnxscript import type_annotation |
| 19 | + |
14 | 20 |
|
15 | 21 | class TypeAnnotationTest(testutils.TestBase):
|
16 | 22 | def test_type_annotation(self):
|
@@ -87,9 +93,27 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]:
|
87 | 93 | bool_type_for_attribute, bool_type_for_attribute_txt
|
88 | 94 | )
|
89 | 95 |
|
| 96 | +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", onnxscript.INT64, onnxscript.FLOAT) |
| 97 | +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=onnxscript.INT64) |
| 98 | +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[onnxscript.INT64, onnxscript.FLOAT]) |
| 99 | + |
90 | 100 |
|
91 | 101 | class UtilityFunctionsTest(unittest.TestCase):
|
92 |
| - def test_pytype_to_input_strings(self): |
| 102 | + @parameterized.parameterized.expand( |
| 103 | + [ |
| 104 | + ("tensor_type", onnxscript.onnx_types.TensorType, type_annotation.ALL_TYPE_STRINGS), |
| 105 | + ("tensor_type", onnxscript.INT64, ["tensor(int64)"]), |
| 106 | + ("tensor_type_variadic_shape", onnxscript.INT64[...], ["tensor(int64)"]), |
| 107 | + ("tensor_type_shape", onnxscript.INT64[10], ["tensor(int64)"]), |
| 108 | + ("type_var_constraints", _TestTypeVarConstraints, ["tensor(int64)"]), |
| 109 | + ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), |
| 110 | + ("type_bound_two", _TestTypeVarTwoBound, ["tensor(int64)", "tensor(float)"]), |
| 111 | + ] |
| 112 | + ) |
| 113 | + def test_pytype_to_input_strings(self, _, pytype: Any, expected) |
| 114 | + pass |
| 115 | + |
| 116 | + def get_type_constraint_name(self, _: str, typevar, expected): |
93 | 117 | pass
|
94 | 118 |
|
95 | 119 | if __name__ == "__main__":
|
|
0 commit comments