2
2
# Copyright (c) Microsoft Corporation. All rights reserved.
3
3
# Licensed under the MIT License.
4
4
# --------------------------------------------------------------------------
5
+
5
6
from __future__ import annotations
6
7
7
- from typing import Optional , Tuple , Union
8
+ from abc import ABC
9
+ from typing import ClassVar , Optional , Tuple , Union
8
10
9
11
import onnx
10
12
import onnx .helper
11
13
12
- # Representations of ONNX types in ONNX Script.
13
- # Currently restricted to tensor types.
14
- # Example type annotations in ONNX Script.
15
- # x : FLOAT (a scalar-tensor of rank 0)
16
- # x : FLOAT[...] (a tensor of unknown rank)
17
- # x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
18
- # x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
14
+ DType = int
15
+
19
16
20
17
DimType = Union [int , str , type (None )]
21
18
@@ -36,131 +33,144 @@ def check_shape(shape):
36
33
check_dim (shape )
37
34
38
35
39
- class TensorType :
40
- """ONNX Script representation of a tensor type."""
36
+ tensor_type_registry : dict [DType , "TensorType" ] = {}
37
+ tensor_type_shape_cache : dict [DType , "TensorType" ] = {}
38
+
39
+
40
+ class TensorType (ABC ):
41
+ """ONNX Script representation of a tensor type supporting shape annotations.
42
+
43
+ A scalar-tensor of rank 0:
44
+ ::
45
+
46
+ tensor: FLOAT
47
+
48
+ A tensor of unknown rank:
49
+ ::
50
+
51
+ tensor: FLOAT[...]
52
+
53
+ A tensor of rank 2 of unknown dimensions, with symbolic names:
54
+ ::
55
+
56
+ tensor: FLOAT['M', 'N']
57
+
58
+ A tensor of rank 2 of known dimensions:
59
+ ::
41
60
42
- default_instance : Optional ["TensorType" ] = None
61
+ tensor: FLOAT[128, 1024]
62
+ """
63
+
64
+ dtype : ClassVar [DType ]
65
+ shape : ClassVar [Optional [ShapeType ]]
66
+
67
+ def __new__ (cls ):
68
+ raise NotImplementedError ("TensorTypes cannot be instantiated" )
43
69
44
- def __init__ (self , dtype , shape : Optional [ShapeType ] = None ) -> None :
45
- self .dtype = dtype
46
- self .shape = shape
47
- if shape is not None :
70
+ def __init_subclass__ (cls , dtype : DType , shape : Optional [ShapeType ] = None ):
71
+ cls .dtype = dtype
72
+ cls .shape = shape
73
+ if shape is None :
74
+ existing_cls = tensor_type_registry .get (dtype )
75
+ if existing_cls is not None :
76
+ raise ValueError (
77
+ f"Invalid usage: subclass { existing_cls !r} "
78
+ f"already defined for dtype={ dtype } "
79
+ )
80
+ tensor_type_registry [dtype ] = cls
81
+ else :
48
82
check_shape (shape )
49
83
50
84
def __getitem__ (self , shape : Optional [ShapeType ]):
51
- if self .shape is not None :
85
+ raise NotImplementedError ("should not be reached" )
86
+
87
+ def __class_getitem__ (cls , shape : Optional [ShapeType ]) -> TensorType :
88
+ if cls .shape is not None :
52
89
raise ValueError ("Invalid usage: shape already specified." )
53
90
if shape is None :
54
91
# Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
55
92
shape = (None ,)
56
- return TensorType (self .dtype , shape )
57
-
58
- def __class_getitem__ (cls , shape : Optional [ShapeType ]):
59
- if cls .default_instance is None :
60
- raise TypeError (f"{ cls } does not specify a default_instance." )
61
- # pylint erroneously flags with unsubscriptable-object if
62
- # using subscript notation (cls.default_instance[shape]):
63
- return cls .default_instance .__getitem__ (shape )
64
-
65
- def to_type_proto (self ) -> onnx .TypeProto :
66
- if self .shape is None :
93
+ key = (cls .dtype , shape )
94
+ shaped_type = tensor_type_shape_cache .get (key )
95
+ if shaped_type is None :
96
+ shaped_type = type (cls .__name__ , (TensorType ,), {}, dtype = cls .dtype , shape = shape )
97
+ tensor_type_shape_cache [key ] = shaped_type
98
+ return shaped_type
99
+
100
+ @classmethod
101
+ def to_type_proto (cls ) -> onnx .TypeProto :
102
+ if cls .shape is None :
67
103
shape = () # "FLOAT" is treated as a scalar
68
- elif self .shape is Ellipsis :
104
+ elif cls .shape is Ellipsis :
69
105
shape = None # "FLOAT[...]" is a tensor of unknown rank
70
- elif isinstance (self .shape , tuple ):
71
- shape = self .shape # example: "FLOAT[10,20]"
106
+ elif isinstance (cls .shape , tuple ):
107
+ shape = cls .shape # example: "FLOAT[10,20]"
72
108
else :
73
- shape = [self .shape ] # example: "FLOAT[10]"
74
- return onnx .helper .make_tensor_type_proto (self .dtype , shape )
75
-
76
-
77
- class _BuiltinTensorType :
78
- def __init__ (self , tensor_proto : onnx .TensorProto ):
79
- self .tensor_proto = tensor_proto
80
-
81
- def __call__ (self , cls ):
82
- cls .default_instance = TensorType (self .tensor_proto )
83
- cls .to_type_proto = cls .default_instance .to_type_proto
84
- return cls
109
+ shape = [cls .shape ] # example: "FLOAT[10]"
110
+ return onnx .helper .make_tensor_type_proto (cls .dtype , shape )
85
111
86
112
87
- @_BuiltinTensorType (onnx .TensorProto .FLOAT )
88
- class FLOAT (TensorType ):
113
+ class FLOAT (TensorType , dtype = onnx .TensorProto .FLOAT ):
89
114
pass
90
115
91
116
92
- @_BuiltinTensorType (onnx .TensorProto .UINT8 )
93
- class UINT8 (TensorType ):
117
+ class UINT8 (TensorType , dtype = onnx .TensorProto .UINT8 ):
94
118
pass
95
119
96
120
97
- @_BuiltinTensorType (onnx .TensorProto .INT8 )
98
- class INT8 (TensorType ):
121
+ class INT8 (TensorType , dtype = onnx .TensorProto .INT8 ):
99
122
pass
100
123
101
124
102
- @_BuiltinTensorType (onnx .TensorProto .UINT16 )
103
- class UINT16 (TensorType ):
125
+ class UINT16 (TensorType , dtype = onnx .TensorProto .UINT16 ):
104
126
pass
105
127
106
128
107
- @_BuiltinTensorType (onnx .TensorProto .INT16 )
108
- class INT16 (TensorType ):
129
+ class INT16 (TensorType , dtype = onnx .TensorProto .INT16 ):
109
130
pass
110
131
111
132
112
- @_BuiltinTensorType (onnx .TensorProto .INT32 )
113
- class INT32 (TensorType ):
133
+ class INT32 (TensorType , dtype = onnx .TensorProto .INT32 ):
114
134
pass
115
135
116
136
117
- @_BuiltinTensorType (onnx .TensorProto .INT64 )
118
- class INT64 (TensorType ):
137
+ class INT64 (TensorType , dtype = onnx .TensorProto .INT64 ):
119
138
pass
120
139
121
140
122
- @_BuiltinTensorType (onnx .TensorProto .STRING )
123
- class STRING (TensorType ):
141
+ class STRING (TensorType , dtype = onnx .TensorProto .STRING ):
124
142
pass
125
143
126
144
127
- @_BuiltinTensorType (onnx .TensorProto .BOOL )
128
- class BOOL (TensorType ):
145
+ class BOOL (TensorType , dtype = onnx .TensorProto .BOOL ):
129
146
pass
130
147
131
148
132
- @_BuiltinTensorType (onnx .TensorProto .FLOAT16 )
133
- class FLOAT16 (TensorType ):
149
+ class FLOAT16 (TensorType , dtype = onnx .TensorProto .FLOAT16 ):
134
150
pass
135
151
136
152
137
- @_BuiltinTensorType (onnx .TensorProto .DOUBLE )
138
- class DOUBLE (TensorType ):
153
+ class DOUBLE (TensorType , dtype = onnx .TensorProto .DOUBLE ):
139
154
pass
140
155
141
156
142
- @_BuiltinTensorType (onnx .TensorProto .UINT32 )
143
- class UINT32 (TensorType ):
157
+ class UINT32 (TensorType , dtype = onnx .TensorProto .UINT32 ):
144
158
pass
145
159
146
160
147
- @_BuiltinTensorType (onnx .TensorProto .UINT64 )
148
- class UINT64 (TensorType ):
161
+ class UINT64 (TensorType , dtype = onnx .TensorProto .UINT64 ):
149
162
pass
150
163
151
164
152
- @_BuiltinTensorType (onnx .TensorProto .COMPLEX64 )
153
- class COMPLEX64 (TensorType ):
165
+ class COMPLEX64 (TensorType , dtype = onnx .TensorProto .COMPLEX64 ):
154
166
pass
155
167
156
168
157
- @_BuiltinTensorType (onnx .TensorProto .COMPLEX128 )
158
- class COMPLEX128 (TensorType ):
169
+ class COMPLEX128 (TensorType , dtype = onnx .TensorProto .COMPLEX128 ):
159
170
pass
160
171
161
172
162
- @_BuiltinTensorType (onnx .TensorProto .BFLOAT16 )
163
- class BFLOAT16 (TensorType ):
173
+ class BFLOAT16 (TensorType , dtype = onnx .TensorProto .BFLOAT16 ):
164
174
pass
165
175
166
176
0 commit comments