Skip to content

Commit 87c8432

Browse files
committed
Fix TensorType to return new TensorTypes for shapes instead of TensorType instances
This ensures that TensorType is never actually instantiated and instead new TensorType subclasses are created (and cached) for each shape. This is a stepping stone to a fully TensorType where the dtype and shape are modeled into the type itself, but it appears that won't be easily realized until we can depend on Python 3.11. Importantly, this adds unit tests to ensure if we can move to a proper generic that behavior will remain the same. Fixes #219
1 parent edaf1ea commit 87c8432

File tree

2 files changed

+157
-74
lines changed

2 files changed

+157
-74
lines changed

onnxscript/onnx_types.py

+84-74
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,17 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5+
56
from __future__ import annotations
67

7-
from typing import Optional, Tuple, Union
8+
from abc import ABC
9+
from typing import ClassVar, Optional, Tuple, Union
810

911
import onnx
1012
import onnx.helper
1113

12-
# Representations of ONNX types in ONNX Script.
13-
# Currently restricted to tensor types.
14-
# Example type annotations in ONNX Script.
15-
# x : FLOAT (a scalar-tensor of rank 0)
16-
# x : FLOAT[...] (a tensor of unknown rank)
17-
# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
18-
# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
14+
DType = int
15+
1916

2017
DimType = Union[int, str, type(None)]
2118

@@ -36,131 +33,144 @@ def check_shape(shape):
3633
check_dim(shape)
3734

3835

39-
class TensorType:
40-
"""ONNX Script representation of a tensor type."""
36+
tensor_type_registry: dict[DType, "TensorType"] = {}
37+
tensor_type_shape_cache: dict[DType, "TensorType"] = {}
38+
39+
40+
class TensorType(ABC):
41+
"""ONNX Script representation of a tensor type supporting shape annotations.
42+
43+
A scalar-tensor of rank 0:
44+
::
45+
46+
tensor: FLOAT
47+
48+
A tensor of unknown rank:
49+
::
50+
51+
tensor: FLOAT[...]
52+
53+
A tensor of rank 2 of unknown dimensions, with symbolic names:
54+
::
55+
56+
tensor: FLOAT['M', 'N']
57+
58+
A tensor of rank 2 of known dimensions:
59+
::
4160
42-
default_instance: Optional["TensorType"] = None
61+
tensor: FLOAT[128, 1024]
62+
"""
63+
64+
dtype: ClassVar[DType]
65+
shape: ClassVar[Optional[ShapeType]]
66+
67+
def __new__(cls):
68+
raise NotImplementedError("TensorTypes cannot be instantiated")
4369

44-
def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None:
45-
self.dtype = dtype
46-
self.shape = shape
47-
if shape is not None:
70+
def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
71+
cls.dtype = dtype
72+
cls.shape = shape
73+
if shape is None:
74+
existing_cls = tensor_type_registry.get(dtype)
75+
if existing_cls is not None:
76+
raise ValueError(
77+
f"Invalid usage: subclass {existing_cls!r} "
78+
f"already defined for dtype={dtype}"
79+
)
80+
tensor_type_registry[dtype] = cls
81+
else:
4882
check_shape(shape)
4983

5084
def __getitem__(self, shape: Optional[ShapeType]):
51-
if self.shape is not None:
85+
raise NotImplementedError("should not be reached")
86+
87+
def __class_getitem__(cls, shape: Optional[ShapeType]) -> TensorType:
88+
if cls.shape is not None:
5289
raise ValueError("Invalid usage: shape already specified.")
5390
if shape is None:
5491
# Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
5592
shape = (None,)
56-
return TensorType(self.dtype, shape)
57-
58-
def __class_getitem__(cls, shape: Optional[ShapeType]):
59-
if cls.default_instance is None:
60-
raise TypeError(f"{cls} does not specify a default_instance.")
61-
# pylint erroneously flags with unsubscriptable-object if
62-
# using subscript notation (cls.default_instance[shape]):
63-
return cls.default_instance.__getitem__(shape)
64-
65-
def to_type_proto(self) -> onnx.TypeProto:
66-
if self.shape is None:
93+
key = (cls.dtype, shape)
94+
shaped_type = tensor_type_shape_cache.get(key)
95+
if shaped_type is None:
96+
shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape)
97+
tensor_type_shape_cache[key] = shaped_type
98+
return shaped_type
99+
100+
@classmethod
101+
def to_type_proto(cls) -> onnx.TypeProto:
102+
if cls.shape is None:
67103
shape = () # "FLOAT" is treated as a scalar
68-
elif self.shape is Ellipsis:
104+
elif cls.shape is Ellipsis:
69105
shape = None # "FLOAT[...]" is a tensor of unknown rank
70-
elif isinstance(self.shape, tuple):
71-
shape = self.shape # example: "FLOAT[10,20]"
106+
elif isinstance(cls.shape, tuple):
107+
shape = cls.shape # example: "FLOAT[10,20]"
72108
else:
73-
shape = [self.shape] # example: "FLOAT[10]"
74-
return onnx.helper.make_tensor_type_proto(self.dtype, shape)
75-
76-
77-
class _BuiltinTensorType:
78-
def __init__(self, tensor_proto: onnx.TensorProto):
79-
self.tensor_proto = tensor_proto
80-
81-
def __call__(self, cls):
82-
cls.default_instance = TensorType(self.tensor_proto)
83-
cls.to_type_proto = cls.default_instance.to_type_proto
84-
return cls
109+
shape = [cls.shape] # example: "FLOAT[10]"
110+
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
85111

86112

87-
@_BuiltinTensorType(onnx.TensorProto.FLOAT)
88-
class FLOAT(TensorType):
113+
class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
89114
pass
90115

91116

92-
@_BuiltinTensorType(onnx.TensorProto.UINT8)
93-
class UINT8(TensorType):
117+
class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
94118
pass
95119

96120

97-
@_BuiltinTensorType(onnx.TensorProto.INT8)
98-
class INT8(TensorType):
121+
class INT8(TensorType, dtype=onnx.TensorProto.INT8):
99122
pass
100123

101124

102-
@_BuiltinTensorType(onnx.TensorProto.UINT16)
103-
class UINT16(TensorType):
125+
class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
104126
pass
105127

106128

107-
@_BuiltinTensorType(onnx.TensorProto.INT16)
108-
class INT16(TensorType):
129+
class INT16(TensorType, dtype=onnx.TensorProto.INT16):
109130
pass
110131

111132

112-
@_BuiltinTensorType(onnx.TensorProto.INT32)
113-
class INT32(TensorType):
133+
class INT32(TensorType, dtype=onnx.TensorProto.INT32):
114134
pass
115135

116136

117-
@_BuiltinTensorType(onnx.TensorProto.INT64)
118-
class INT64(TensorType):
137+
class INT64(TensorType, dtype=onnx.TensorProto.INT64):
119138
pass
120139

121140

122-
@_BuiltinTensorType(onnx.TensorProto.STRING)
123-
class STRING(TensorType):
141+
class STRING(TensorType, dtype=onnx.TensorProto.STRING):
124142
pass
125143

126144

127-
@_BuiltinTensorType(onnx.TensorProto.BOOL)
128-
class BOOL(TensorType):
145+
class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
129146
pass
130147

131148

132-
@_BuiltinTensorType(onnx.TensorProto.FLOAT16)
133-
class FLOAT16(TensorType):
149+
class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
134150
pass
135151

136152

137-
@_BuiltinTensorType(onnx.TensorProto.DOUBLE)
138-
class DOUBLE(TensorType):
153+
class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
139154
pass
140155

141156

142-
@_BuiltinTensorType(onnx.TensorProto.UINT32)
143-
class UINT32(TensorType):
157+
class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
144158
pass
145159

146160

147-
@_BuiltinTensorType(onnx.TensorProto.UINT64)
148-
class UINT64(TensorType):
161+
class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
149162
pass
150163

151164

152-
@_BuiltinTensorType(onnx.TensorProto.COMPLEX64)
153-
class COMPLEX64(TensorType):
165+
class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
154166
pass
155167

156168

157-
@_BuiltinTensorType(onnx.TensorProto.COMPLEX128)
158-
class COMPLEX128(TensorType):
169+
class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
159170
pass
160171

161172

162-
@_BuiltinTensorType(onnx.TensorProto.BFLOAT16)
163-
class BFLOAT16(TensorType):
173+
class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
164174
pass
165175

166176

onnxscript/test/onnx_types_test.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
6+
"""Unit tests for the onnx_types module."""
7+
8+
import unittest
9+
10+
from parameterized import parameterized
11+
12+
from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry
13+
14+
15+
class TestOnnxTypes(unittest.TestCase):
16+
def test_instantiation(self):
17+
with self.assertRaises(NotImplementedError):
18+
TensorType()
19+
with self.assertRaises(NotImplementedError):
20+
FLOAT()
21+
with self.assertRaises(NotImplementedError):
22+
FLOAT[...]()
23+
24+
@parameterized.expand(tensor_type_registry.items())
25+
def test_type_properties(self, dtype: DType, tensor_type: TensorType):
26+
self.assertEqual(tensor_type.dtype, dtype)
27+
self.assertIsNone(tensor_type.shape)
28+
self.assertEqual(tensor_type[...].shape, ...)
29+
self.assertEqual(tensor_type[...].dtype, dtype)
30+
self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3))
31+
self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)
32+
33+
@parameterized.expand([(dtype,) for dtype in tensor_type_registry.keys()])
34+
def test_dtype_bound_to_subclass(self, dtype: DType):
35+
with self.assertRaises(ValueError):
36+
type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)
37+
38+
def test_shaped_doesnt_reshape(self):
39+
with self.assertRaises(ValueError):
40+
FLOAT[1][...]
41+
42+
@parameterized.expand(
43+
[
44+
(FLOAT, FLOAT),
45+
(FLOAT[None], FLOAT[None]),
46+
(FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
47+
(FLOAT[1], FLOAT[1]),
48+
(FLOAT[...], FLOAT[Ellipsis]),
49+
(FLOAT["M"], FLOAT["M"]),
50+
(FLOAT["M", "N"], FLOAT["M", "N"]),
51+
(FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
52+
]
53+
)
54+
def test_shapes_are_same_type(self, a: TensorType, b: TensorType):
55+
self.assertIs(a, b)
56+
57+
@parameterized.expand(
58+
[
59+
(FLOAT[0], FLOAT[None]),
60+
(FLOAT[1, 2], FLOAT[3, 4]),
61+
(FLOAT[2, 1], FLOAT[1, 2]),
62+
(FLOAT["M", "N"], FLOAT["N", "M"]),
63+
(FLOAT, DOUBLE),
64+
(FLOAT[1], DOUBLE[1]),
65+
(FLOAT["X"], DOUBLE["X"]),
66+
]
67+
)
68+
def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType):
69+
self.assertIsNot(a, b)
70+
71+
72+
if __name__ == "__main__":
73+
unittest.main()

0 commit comments

Comments
 (0)