22# Copyright (c) Microsoft Corporation. All rights reserved.
33# Licensed under the MIT License.
44# --------------------------------------------------------------------------
5+
56from __future__ import annotations
67
7- from typing import Optional , Tuple , Union
8+ import abc
9+ from typing import ClassVar , Optional , Tuple , Union
810
911import onnx
1012import onnx .helper
1113
12- # Representations of ONNX types in ONNX Script.
13- # Currently restricted to tensor types.
14- # Example type annotations in ONNX Script.
15- # x : FLOAT (a scalar-tensor of rank 0)
16- # x : FLOAT[...] (a tensor of unknown rank)
17- # x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
18- # x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
14+ DType = onnx .TensorProto .DataType
1915
2016DimType = Union [int , str , type (None )]
2117
@@ -36,131 +32,141 @@ def check_shape(shape):
3632 check_dim (shape )
3733
3834
39- class TensorType :
40- """ONNX Script representation of a tensor type."""
35+ tensor_type_registry : dict [DType , TensorType ] = {}
36+ _tensor_type_shape_cache : dict [DType , TensorType ] = {}
37+
38+
39+ class TensorType (abc .ABC ):
40+ """ONNX Script representation of a tensor type supporting shape annotations.
41+
42+ A scalar-tensor of rank 0:
43+ ::
44+
45+ tensor: FLOAT
46+
47+ A tensor of unknown rank:
48+ ::
49+
50+ tensor: FLOAT[...]
51+
52+ A tensor of rank 2 of unknown dimensions, with symbolic names:
53+ ::
54+
55+ tensor: FLOAT['M', 'N']
56+
57+ A tensor of rank 2 of known dimensions:
58+ ::
4159
42- default_instance : Optional ["TensorType" ] = None
60+ tensor: FLOAT[128, 1024]
61+ """
62+
63+ dtype : ClassVar [DType ]
64+ shape : ClassVar [Optional [ShapeType ]]
65+
66+ def __new__ (cls ):
67+ raise NotImplementedError ("TensorTypes cannot be instantiated" )
4368
44- def __init__ (self , dtype , shape : Optional [ShapeType ] = None ) -> None :
45- self .dtype = dtype
46- self .shape = shape
47- if shape is not None :
69+ def __init_subclass__ (cls , dtype : DType , shape : Optional [ShapeType ] = None ):
70+ cls .dtype = dtype
71+ cls .shape = shape
72+ if shape is None :
73+ existing_cls = tensor_type_registry .get (dtype )
74+ if existing_cls is not None :
75+ raise ValueError (
76+ f"Invalid usage: subclass { existing_cls !r} "
77+ f"already defined for dtype={ dtype } "
78+ )
79+ tensor_type_registry [dtype ] = cls
80+ else :
4881 check_shape (shape )
4982
50- def __getitem__ ( self , shape : Optional [ShapeType ]):
51- if self .shape is not None :
83+ def __class_getitem__ ( cls , shape : Optional [ShapeType ]) -> type [ TensorType ] :
84+ if cls .shape is not None :
5285 raise ValueError ("Invalid usage: shape already specified." )
5386 if shape is None :
5487 # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
5588 shape = (None ,)
56- return TensorType (self .dtype , shape )
57-
58- def __class_getitem__ (cls , shape : Optional [ShapeType ]):
59- if cls .default_instance is None :
60- raise TypeError (f"{ cls } does not specify a default_instance." )
61- # pylint erroneously flags with unsubscriptable-object if
62- # using subscript notation (cls.default_instance[shape]):
63- return cls .default_instance .__getitem__ (shape )
64-
65- def to_type_proto (self ) -> onnx .TypeProto :
66- if self .shape is None :
89+ key = (cls .dtype , shape )
90+ shaped_type = _tensor_type_shape_cache .get (key )
91+ if shaped_type is None :
92+ shaped_type = type (cls .__name__ , (TensorType ,), {}, dtype = cls .dtype , shape = shape )
93+ _tensor_type_shape_cache [key ] = shaped_type
94+ return shaped_type
95+
96+ @classmethod
97+ def to_type_proto (cls ) -> onnx .TypeProto :
98+ if cls .shape is None :
6799 shape = () # "FLOAT" is treated as a scalar
68- elif self .shape is Ellipsis :
100+ elif cls .shape is Ellipsis :
69101 shape = None # "FLOAT[...]" is a tensor of unknown rank
70- elif isinstance (self .shape , tuple ):
71- shape = self .shape # example: "FLOAT[10,20]"
102+ elif isinstance (cls .shape , tuple ):
103+ shape = cls .shape # example: "FLOAT[10,20]"
72104 else :
73- shape = [self .shape ] # example: "FLOAT[10]"
74- return onnx .helper .make_tensor_type_proto (self .dtype , shape )
75-
76-
77- class _BuiltinTensorType :
78- def __init__ (self , tensor_proto : onnx .TensorProto ):
79- self .tensor_proto = tensor_proto
80-
81- def __call__ (self , cls ):
82- cls .default_instance = TensorType (self .tensor_proto )
83- cls .to_type_proto = cls .default_instance .to_type_proto
84- return cls
105+ shape = [cls .shape ] # example: "FLOAT[10]"
106+ return onnx .helper .make_tensor_type_proto (cls .dtype , shape )
85107
86108
87- @_BuiltinTensorType (onnx .TensorProto .FLOAT )
88- class FLOAT (TensorType ):
109+ class FLOAT (TensorType , dtype = onnx .TensorProto .FLOAT ):
89110 pass
90111
91112
92- @_BuiltinTensorType (onnx .TensorProto .UINT8 )
93- class UINT8 (TensorType ):
113+ class UINT8 (TensorType , dtype = onnx .TensorProto .UINT8 ):
94114 pass
95115
96116
97- @_BuiltinTensorType (onnx .TensorProto .INT8 )
98- class INT8 (TensorType ):
117+ class INT8 (TensorType , dtype = onnx .TensorProto .INT8 ):
99118 pass
100119
101120
102- @_BuiltinTensorType (onnx .TensorProto .UINT16 )
103- class UINT16 (TensorType ):
121+ class UINT16 (TensorType , dtype = onnx .TensorProto .UINT16 ):
104122 pass
105123
106124
107- @_BuiltinTensorType (onnx .TensorProto .INT16 )
108- class INT16 (TensorType ):
125+ class INT16 (TensorType , dtype = onnx .TensorProto .INT16 ):
109126 pass
110127
111128
112- @_BuiltinTensorType (onnx .TensorProto .INT32 )
113- class INT32 (TensorType ):
129+ class INT32 (TensorType , dtype = onnx .TensorProto .INT32 ):
114130 pass
115131
116132
117- @_BuiltinTensorType (onnx .TensorProto .INT64 )
118- class INT64 (TensorType ):
133+ class INT64 (TensorType , dtype = onnx .TensorProto .INT64 ):
119134 pass
120135
121136
122- @_BuiltinTensorType (onnx .TensorProto .STRING )
123- class STRING (TensorType ):
137+ class STRING (TensorType , dtype = onnx .TensorProto .STRING ):
124138 pass
125139
126140
127- @_BuiltinTensorType (onnx .TensorProto .BOOL )
128- class BOOL (TensorType ):
141+ class BOOL (TensorType , dtype = onnx .TensorProto .BOOL ):
129142 pass
130143
131144
132- @_BuiltinTensorType (onnx .TensorProto .FLOAT16 )
133- class FLOAT16 (TensorType ):
145+ class FLOAT16 (TensorType , dtype = onnx .TensorProto .FLOAT16 ):
134146 pass
135147
136148
137- @_BuiltinTensorType (onnx .TensorProto .DOUBLE )
138- class DOUBLE (TensorType ):
149+ class DOUBLE (TensorType , dtype = onnx .TensorProto .DOUBLE ):
139150 pass
140151
141152
142- @_BuiltinTensorType (onnx .TensorProto .UINT32 )
143- class UINT32 (TensorType ):
153+ class UINT32 (TensorType , dtype = onnx .TensorProto .UINT32 ):
144154 pass
145155
146156
147- @_BuiltinTensorType (onnx .TensorProto .UINT64 )
148- class UINT64 (TensorType ):
157+ class UINT64 (TensorType , dtype = onnx .TensorProto .UINT64 ):
149158 pass
150159
151160
152- @_BuiltinTensorType (onnx .TensorProto .COMPLEX64 )
153- class COMPLEX64 (TensorType ):
161+ class COMPLEX64 (TensorType , dtype = onnx .TensorProto .COMPLEX64 ):
154162 pass
155163
156164
157- @_BuiltinTensorType (onnx .TensorProto .COMPLEX128 )
158- class COMPLEX128 (TensorType ):
165+ class COMPLEX128 (TensorType , dtype = onnx .TensorProto .COMPLEX128 ):
159166 pass
160167
161168
162- @_BuiltinTensorType (onnx .TensorProto .BFLOAT16 )
163- class BFLOAT16 (TensorType ):
169+ class BFLOAT16 (TensorType , dtype = onnx .TensorProto .BFLOAT16 ):
164170 pass
165171
166172
0 commit comments