2
2
# Copyright (c) Microsoft Corporation. All rights reserved.
3
3
# Licensed under the MIT License.
4
4
# --------------------------------------------------------------------------
5
+
5
6
from __future__ import annotations
6
7
7
- from typing import Optional , Tuple , Union
8
+ from abc import ABC
9
+ from typing import ClassVar , Optional , Union
8
10
9
11
import onnx
10
12
import onnx .helper
11
13
12
- # Representations of ONNX types in ONNX Script.
13
- # Currently restricted to tensor types.
14
- # Example type annotations in ONNX Script.
15
- # x : FLOAT (a scalar-tensor of rank 0)
16
- # x : FLOAT[...] (a tensor of unknown rank)
17
- # x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
18
- # x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
14
+ DType = onnx .TensorProto .DataType
19
15
20
16
DimType = Union [int , str , type (None )]
21
17
@@ -25,7 +21,7 @@ def check_dim(dim):
25
21
raise TypeError (f"Invalid dimension { dim } " )
26
22
27
23
28
- ShapeType = Union [Tuple [DimType , ...], DimType , type (Ellipsis )]
24
+ ShapeType = Union [tuple [DimType , ...], DimType , type (... )]
29
25
30
26
31
27
def check_shape (shape ):
@@ -36,131 +32,141 @@ def check_shape(shape):
36
32
check_dim (shape )
37
33
38
34
39
- class TensorType :
40
- """ONNX Script representation of a tensor type."""
35
+ tensor_type_registry : dict [DType , TensorType ] = {}
36
+ _tensor_type_shape_cache : dict [DType , TensorType ] = {}
37
+
38
+
39
+ class TensorType (ABC ):
40
+ """ONNX Script representation of a tensor type supporting shape annotations.
41
+
42
+ A scalar-tensor of rank 0:
43
+ ::
44
+
45
+ tensor: FLOAT
46
+
47
+ A tensor of unknown rank:
48
+ ::
49
+
50
+ tensor: FLOAT[...]
51
+
52
+ A tensor of rank 2 of unknown dimensions, with symbolic names:
53
+ ::
54
+
55
+ tensor: FLOAT['M', 'N']
56
+
57
+ A tensor of rank 2 of known dimensions:
58
+ ::
41
59
42
- default_instance : Optional ["TensorType" ] = None
60
+ tensor: FLOAT[128, 1024]
61
+ """
62
+
63
+ dtype : ClassVar [DType ]
64
+ shape : ClassVar [Optional [ShapeType ]]
65
+
66
+ def __new__ (cls ):
67
+ raise NotImplementedError ("TensorTypes cannot be instantiated" )
43
68
44
- def __init__ (self , dtype , shape : Optional [ShapeType ] = None ) -> None :
45
- self .dtype = dtype
46
- self .shape = shape
47
- if shape is not None :
69
+ def __init_subclass__ (cls , dtype : DType , shape : Optional [ShapeType ] = None ):
70
+ cls .dtype = dtype
71
+ cls .shape = shape
72
+ if shape is None :
73
+ existing_cls = tensor_type_registry .get (dtype )
74
+ if existing_cls is not None :
75
+ raise ValueError (
76
+ f"Invalid usage: subclass { existing_cls !r} "
77
+ f"already defined for dtype={ dtype } "
78
+ )
79
+ tensor_type_registry [dtype ] = cls
80
+ else :
48
81
check_shape (shape )
49
82
50
- def __getitem__ ( self , shape : Optional [ShapeType ]):
51
- if self .shape is not None :
83
+ def __class_getitem__ ( cls , shape : Optional [ShapeType ]) -> type [ TensorType ] :
84
+ if cls .shape is not None :
52
85
raise ValueError ("Invalid usage: shape already specified." )
53
86
if shape is None :
54
87
# Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
55
88
shape = (None ,)
56
- return TensorType (self .dtype , shape )
57
-
58
- def __class_getitem__ (cls , shape : Optional [ShapeType ]):
59
- if cls .default_instance is None :
60
- raise TypeError (f"{ cls } does not specify a default_instance." )
61
- # pylint erroneously flags with unsubscriptable-object if
62
- # using subscript notation (cls.default_instance[shape]):
63
- return cls .default_instance .__getitem__ (shape )
64
-
65
- def to_type_proto (self ) -> onnx .TypeProto :
66
- if self .shape is None :
89
+ key = (cls .dtype , shape )
90
+ shaped_type = _tensor_type_shape_cache .get (key )
91
+ if shaped_type is None :
92
+ shaped_type = type (cls .__name__ , (TensorType ,), {}, dtype = cls .dtype , shape = shape )
93
+ _tensor_type_shape_cache [key ] = shaped_type
94
+ return shaped_type
95
+
96
+ @classmethod
97
+ def to_type_proto (cls ) -> onnx .TypeProto :
98
+ if cls .shape is None :
67
99
shape = () # "FLOAT" is treated as a scalar
68
- elif self .shape is Ellipsis :
100
+ elif cls .shape is Ellipsis :
69
101
shape = None # "FLOAT[...]" is a tensor of unknown rank
70
- elif isinstance (self .shape , tuple ):
71
- shape = self .shape # example: "FLOAT[10,20]"
102
+ elif isinstance (cls .shape , tuple ):
103
+ shape = cls .shape # example: "FLOAT[10,20]"
72
104
else :
73
- shape = [self .shape ] # example: "FLOAT[10]"
74
- return onnx .helper .make_tensor_type_proto (self .dtype , shape )
75
-
76
-
77
- class _BuiltinTensorType :
78
- def __init__ (self , tensor_proto : onnx .TensorProto ):
79
- self .tensor_proto = tensor_proto
80
-
81
- def __call__ (self , cls ):
82
- cls .default_instance = TensorType (self .tensor_proto )
83
- cls .to_type_proto = cls .default_instance .to_type_proto
84
- return cls
105
+ shape = [cls .shape ] # example: "FLOAT[10]"
106
+ return onnx .helper .make_tensor_type_proto (cls .dtype , shape )
85
107
86
108
87
- @_BuiltinTensorType (onnx .TensorProto .FLOAT )
88
- class FLOAT (TensorType ):
109
+ class FLOAT (TensorType , dtype = onnx .TensorProto .FLOAT ):
89
110
pass
90
111
91
112
92
- @_BuiltinTensorType (onnx .TensorProto .UINT8 )
93
- class UINT8 (TensorType ):
113
+ class UINT8 (TensorType , dtype = onnx .TensorProto .UINT8 ):
94
114
pass
95
115
96
116
97
- @_BuiltinTensorType (onnx .TensorProto .INT8 )
98
- class INT8 (TensorType ):
117
+ class INT8 (TensorType , dtype = onnx .TensorProto .INT8 ):
99
118
pass
100
119
101
120
102
- @_BuiltinTensorType (onnx .TensorProto .UINT16 )
103
- class UINT16 (TensorType ):
121
+ class UINT16 (TensorType , dtype = onnx .TensorProto .UINT16 ):
104
122
pass
105
123
106
124
107
- @_BuiltinTensorType (onnx .TensorProto .INT16 )
108
- class INT16 (TensorType ):
125
+ class INT16 (TensorType , dtype = onnx .TensorProto .INT16 ):
109
126
pass
110
127
111
128
112
- @_BuiltinTensorType (onnx .TensorProto .INT32 )
113
- class INT32 (TensorType ):
129
+ class INT32 (TensorType , dtype = onnx .TensorProto .INT32 ):
114
130
pass
115
131
116
132
117
- @_BuiltinTensorType (onnx .TensorProto .INT64 )
118
- class INT64 (TensorType ):
133
+ class INT64 (TensorType , dtype = onnx .TensorProto .INT64 ):
119
134
pass
120
135
121
136
122
- @_BuiltinTensorType (onnx .TensorProto .STRING )
123
- class STRING (TensorType ):
137
+ class STRING (TensorType , dtype = onnx .TensorProto .STRING ):
124
138
pass
125
139
126
140
127
- @_BuiltinTensorType (onnx .TensorProto .BOOL )
128
- class BOOL (TensorType ):
141
+ class BOOL (TensorType , dtype = onnx .TensorProto .BOOL ):
129
142
pass
130
143
131
144
132
- @_BuiltinTensorType (onnx .TensorProto .FLOAT16 )
133
- class FLOAT16 (TensorType ):
145
+ class FLOAT16 (TensorType , dtype = onnx .TensorProto .FLOAT16 ):
134
146
pass
135
147
136
148
137
- @_BuiltinTensorType (onnx .TensorProto .DOUBLE )
138
- class DOUBLE (TensorType ):
149
+ class DOUBLE (TensorType , dtype = onnx .TensorProto .DOUBLE ):
139
150
pass
140
151
141
152
142
- @_BuiltinTensorType (onnx .TensorProto .UINT32 )
143
- class UINT32 (TensorType ):
153
+ class UINT32 (TensorType , dtype = onnx .TensorProto .UINT32 ):
144
154
pass
145
155
146
156
147
- @_BuiltinTensorType (onnx .TensorProto .UINT64 )
148
- class UINT64 (TensorType ):
157
+ class UINT64 (TensorType , dtype = onnx .TensorProto .UINT64 ):
149
158
pass
150
159
151
160
152
- @_BuiltinTensorType (onnx .TensorProto .COMPLEX64 )
153
- class COMPLEX64 (TensorType ):
161
+ class COMPLEX64 (TensorType , dtype = onnx .TensorProto .COMPLEX64 ):
154
162
pass
155
163
156
164
157
- @_BuiltinTensorType (onnx .TensorProto .COMPLEX128 )
158
- class COMPLEX128 (TensorType ):
165
+ class COMPLEX128 (TensorType , dtype = onnx .TensorProto .COMPLEX128 ):
159
166
pass
160
167
161
168
162
- @_BuiltinTensorType (onnx .TensorProto .BFLOAT16 )
163
- class BFLOAT16 (TensorType ):
169
+ class BFLOAT16 (TensorType , dtype = onnx .TensorProto .BFLOAT16 ):
164
170
pass
165
171
166
172
0 commit comments