Skip to content

Commit d4ef07a

Browse files
committed
Fix TensorType to return new TensorTypes for shapes instead of TensorType instances
This ensures that TensorType is never actually instantiated; instead, new TensorType subclasses are created (and cached) for each shape. This is a stepping stone to a fully generic TensorType where the dtype and shape are modeled into the type itself, but it appears that won't be easily realized until we can depend on Python 3.11. Importantly, this adds unit tests to ensure that, if we later move to a proper generic, behavior will remain the same. Fixes #219
1 parent edaf1ea commit d4ef07a

File tree

2 files changed

+157
-76
lines changed

2 files changed

+157
-76
lines changed

onnxscript/onnx_types.py

Lines changed: 82 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,16 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5+
56
from __future__ import annotations
67

7-
from typing import Optional, Tuple, Union
8+
from abc import ABC
9+
from typing import ClassVar, Optional, Union
810

911
import onnx
1012
import onnx.helper
1113

12-
# Representations of ONNX types in ONNX Script.
13-
# Currently restricted to tensor types.
14-
# Example type annotations in ONNX Script.
15-
# x : FLOAT (a scalar-tensor of rank 0)
16-
# x : FLOAT[...] (a tensor of unknown rank)
17-
# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
18-
# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
14+
DType = onnx.TensorProto.DataType
1915

2016
DimType = Union[int, str, type(None)]
2117

@@ -25,7 +21,7 @@ def check_dim(dim):
2521
raise TypeError(f"Invalid dimension {dim}")
2622

2723

28-
ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
24+
ShapeType = Union[tuple[DimType, ...], DimType, type(...)]
2925

3026

3127
def check_shape(shape):
@@ -36,131 +32,141 @@ def check_shape(shape):
3632
check_dim(shape)
3733

3834

39-
class TensorType:
40-
"""ONNX Script representation of a tensor type."""
35+
tensor_type_registry: dict[DType, TensorType] = {}
36+
_tensor_type_shape_cache: dict[DType, TensorType] = {}
37+
38+
39+
class TensorType(ABC):
40+
"""ONNX Script representation of a tensor type supporting shape annotations.
41+
42+
A scalar-tensor of rank 0:
43+
::
44+
45+
tensor: FLOAT
46+
47+
A tensor of unknown rank:
48+
::
49+
50+
tensor: FLOAT[...]
51+
52+
A tensor of rank 2 of unknown dimensions, with symbolic names:
53+
::
54+
55+
tensor: FLOAT['M', 'N']
56+
57+
A tensor of rank 2 of known dimensions:
58+
::
4159
42-
default_instance: Optional["TensorType"] = None
60+
tensor: FLOAT[128, 1024]
61+
"""
62+
63+
dtype: ClassVar[DType]
64+
shape: ClassVar[Optional[ShapeType]]
65+
66+
def __new__(cls):
67+
raise NotImplementedError("TensorTypes cannot be instantiated")
4368

44-
def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None:
45-
self.dtype = dtype
46-
self.shape = shape
47-
if shape is not None:
69+
def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
70+
cls.dtype = dtype
71+
cls.shape = shape
72+
if shape is None:
73+
existing_cls = tensor_type_registry.get(dtype)
74+
if existing_cls is not None:
75+
raise ValueError(
76+
f"Invalid usage: subclass {existing_cls!r} "
77+
f"already defined for dtype={dtype}"
78+
)
79+
tensor_type_registry[dtype] = cls
80+
else:
4881
check_shape(shape)
4982

50-
def __getitem__(self, shape: Optional[ShapeType]):
51-
if self.shape is not None:
83+
def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
84+
if cls.shape is not None:
5285
raise ValueError("Invalid usage: shape already specified.")
5386
if shape is None:
5487
# Treat FLOAT[None] as a 1-dimensional tensor with unknown dimension
5588
shape = (None,)
56-
return TensorType(self.dtype, shape)
57-
58-
def __class_getitem__(cls, shape: Optional[ShapeType]):
59-
if cls.default_instance is None:
60-
raise TypeError(f"{cls} does not specify a default_instance.")
61-
# pylint erroneously flags with unsubscriptable-object if
62-
# using subscript notation (cls.default_instance[shape]):
63-
return cls.default_instance.__getitem__(shape)
64-
65-
def to_type_proto(self) -> onnx.TypeProto:
66-
if self.shape is None:
89+
key = (cls.dtype, shape)
90+
shaped_type = _tensor_type_shape_cache.get(key)
91+
if shaped_type is None:
92+
shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape)
93+
_tensor_type_shape_cache[key] = shaped_type
94+
return shaped_type
95+
96+
@classmethod
97+
def to_type_proto(cls) -> onnx.TypeProto:
98+
if cls.shape is None:
6799
shape = () # "FLOAT" is treated as a scalar
68-
elif self.shape is Ellipsis:
100+
elif cls.shape is Ellipsis:
69101
shape = None # "FLOAT[...]" is a tensor of unknown rank
70-
elif isinstance(self.shape, tuple):
71-
shape = self.shape # example: "FLOAT[10,20]"
102+
elif isinstance(cls.shape, tuple):
103+
shape = cls.shape # example: "FLOAT[10,20]"
72104
else:
73-
shape = [self.shape] # example: "FLOAT[10]"
74-
return onnx.helper.make_tensor_type_proto(self.dtype, shape)
75-
76-
77-
class _BuiltinTensorType:
78-
def __init__(self, tensor_proto: onnx.TensorProto):
79-
self.tensor_proto = tensor_proto
80-
81-
def __call__(self, cls):
82-
cls.default_instance = TensorType(self.tensor_proto)
83-
cls.to_type_proto = cls.default_instance.to_type_proto
84-
return cls
105+
shape = [cls.shape] # example: "FLOAT[10]"
106+
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
85107

86108

87-
@_BuiltinTensorType(onnx.TensorProto.FLOAT)
88-
class FLOAT(TensorType):
109+
class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
89110
pass
90111

91112

92-
@_BuiltinTensorType(onnx.TensorProto.UINT8)
93-
class UINT8(TensorType):
113+
class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
94114
pass
95115

96116

97-
@_BuiltinTensorType(onnx.TensorProto.INT8)
98-
class INT8(TensorType):
117+
class INT8(TensorType, dtype=onnx.TensorProto.INT8):
99118
pass
100119

101120

102-
@_BuiltinTensorType(onnx.TensorProto.UINT16)
103-
class UINT16(TensorType):
121+
class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
104122
pass
105123

106124

107-
@_BuiltinTensorType(onnx.TensorProto.INT16)
108-
class INT16(TensorType):
125+
class INT16(TensorType, dtype=onnx.TensorProto.INT16):
109126
pass
110127

111128

112-
@_BuiltinTensorType(onnx.TensorProto.INT32)
113-
class INT32(TensorType):
129+
class INT32(TensorType, dtype=onnx.TensorProto.INT32):
114130
pass
115131

116132

117-
@_BuiltinTensorType(onnx.TensorProto.INT64)
118-
class INT64(TensorType):
133+
class INT64(TensorType, dtype=onnx.TensorProto.INT64):
119134
pass
120135

121136

122-
@_BuiltinTensorType(onnx.TensorProto.STRING)
123-
class STRING(TensorType):
137+
class STRING(TensorType, dtype=onnx.TensorProto.STRING):
124138
pass
125139

126140

127-
@_BuiltinTensorType(onnx.TensorProto.BOOL)
128-
class BOOL(TensorType):
141+
class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
129142
pass
130143

131144

132-
@_BuiltinTensorType(onnx.TensorProto.FLOAT16)
133-
class FLOAT16(TensorType):
145+
class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
134146
pass
135147

136148

137-
@_BuiltinTensorType(onnx.TensorProto.DOUBLE)
138-
class DOUBLE(TensorType):
149+
class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
139150
pass
140151

141152

142-
@_BuiltinTensorType(onnx.TensorProto.UINT32)
143-
class UINT32(TensorType):
153+
class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
144154
pass
145155

146156

147-
@_BuiltinTensorType(onnx.TensorProto.UINT64)
148-
class UINT64(TensorType):
157+
class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
149158
pass
150159

151160

152-
@_BuiltinTensorType(onnx.TensorProto.COMPLEX64)
153-
class COMPLEX64(TensorType):
161+
class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
154162
pass
155163

156164

157-
@_BuiltinTensorType(onnx.TensorProto.COMPLEX128)
158-
class COMPLEX128(TensorType):
165+
class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
159166
pass
160167

161168

162-
@_BuiltinTensorType(onnx.TensorProto.BFLOAT16)
163-
class BFLOAT16(TensorType):
169+
class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
164170
pass
165171

166172

onnxscript/test/onnx_types_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
#
6+
# mypy: disable-error-code=misc
7+
8+
"""Unit tests for the onnx_types module."""
9+
10+
import unittest
11+
12+
from parameterized import parameterized
13+
14+
from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry
15+
16+
17+
class TestOnnxTypes(unittest.TestCase):
18+
def test_instantiation(self):
19+
with self.assertRaises(NotImplementedError):
20+
TensorType()
21+
with self.assertRaises(NotImplementedError):
22+
FLOAT()
23+
with self.assertRaises(NotImplementedError):
24+
FLOAT[...]()
25+
26+
@parameterized.expand(tensor_type_registry.items())
27+
def test_type_properties(self, dtype: DType, tensor_type: TensorType):
28+
self.assertEqual(tensor_type.dtype, dtype)
29+
self.assertIsNone(tensor_type.shape)
30+
self.assertEqual(tensor_type[...].shape, ...)
31+
self.assertEqual(tensor_type[...].dtype, dtype)
32+
self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3))
33+
self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)
34+
35+
@parameterized.expand([(dtype,) for dtype in tensor_type_registry])
36+
def test_dtype_bound_to_subclass(self, dtype: DType):
37+
with self.assertRaises(ValueError):
38+
type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)
39+
40+
def test_shaped_doesnt_reshape(self):
41+
with self.assertRaises(ValueError):
42+
FLOAT[1][...] # pylint: disable=pointless-statement
43+
44+
@parameterized.expand(
45+
[
46+
(FLOAT, FLOAT),
47+
(FLOAT[None], FLOAT[None]),
48+
(FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
49+
(FLOAT[1], FLOAT[1]),
50+
(FLOAT[...], FLOAT[Ellipsis]),
51+
(FLOAT["M"], FLOAT["M"]),
52+
(FLOAT["M", "N"], FLOAT["M", "N"]),
53+
(FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
54+
]
55+
)
56+
def test_shapes_are_same_type(self, a: TensorType, b: TensorType):
57+
self.assertIs(a, b)
58+
59+
@parameterized.expand(
60+
[
61+
(FLOAT[0], FLOAT[None]),
62+
(FLOAT[1, 2], FLOAT[3, 4]),
63+
(FLOAT[2, 1], FLOAT[1, 2]),
64+
(FLOAT["M", "N"], FLOAT["N", "M"]),
65+
(FLOAT, DOUBLE),
66+
(FLOAT[1], DOUBLE[1]),
67+
(FLOAT["X"], DOUBLE["X"]),
68+
]
69+
)
70+
def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType):
71+
self.assertIsNot(a, b)
72+
73+
74+
if __name__ == "__main__":
75+
unittest.main()

0 commit comments

Comments
 (0)