From 9755f8a13798039e985be6b2885cf68ee961b4bf Mon Sep 17 00:00:00 2001 From: Aaron Bockover Date: Thu, 1 Dec 2022 22:07:06 -0500 Subject: [PATCH 1/3] Fix TensorType to return new types for shapes instead of type instances This ensures that TensorType is never actually instantiated and instead new TensorType subclasses are created (and cached) for each shape. This is a stepping stone to a fully generic TensorType where the dtype and shape are modeled into the type itself, but it appears that won't be easily realized until we can depend on Python 3.11. Importantly, this adds unit tests to ensure if we can move to a proper generic that behavior will remain the same. Fixes #219 --- onnxscript/onnx_types.py | 156 +++++++++++++++-------------- onnxscript/test/onnx_types_test.py | 75 ++++++++++++++ 2 files changed, 156 insertions(+), 75 deletions(-) create mode 100644 onnxscript/test/onnx_types_test.py diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 8c8db3b17b..1958a2d284 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -2,20 +2,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- + from __future__ import annotations -from typing import Optional, Tuple, Union +from abc import ABC +from typing import ClassVar, Optional, Tuple, Union import onnx import onnx.helper -# Representations of ONNX types in ONNX Script. -# Currently restricted to tensor types. -# Example type annotations in ONNX Script. -# x : FLOAT (a scalar-tensor of rank 0) -# x : FLOAT[...] (a tensor of unknown rank) -# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names) -# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions) +DType = onnx.TensorProto.DataType DimType = Union[int, str, type(None)] @@ -36,131 +32,141 @@ def check_shape(shape): check_dim(shape) -class TensorType: - """ONNX Script representation of a tensor type.""" +tensor_type_registry: dict[DType, TensorType] = {} +_tensor_type_shape_cache: dict[DType, TensorType] = {} + + +class TensorType(ABC): + """ONNX Script representation of a tensor type supporting shape annotations. + + A scalar-tensor of rank 0: + :: + + tensor: FLOAT + + A tensor of unknown rank: + :: + + tensor: FLOAT[...] + + A tensor of rank 2 of unknown dimensions, with symbolic names: + :: + + tensor: FLOAT['M', 'N'] + + A tensor of rank 2 of known dimensions: + :: - default_instance: Optional["TensorType"] = None + tensor: FLOAT[128, 1024] + """ + + dtype: ClassVar[DType] + shape: ClassVar[Optional[ShapeType]] + + def __new__(cls): + raise NotImplementedError("TensorTypes cannot be instantiated") - def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None: - self.dtype = dtype - self.shape = shape - if shape is not None: + def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None): + cls.dtype = dtype + cls.shape = shape + if shape is None: + existing_cls = tensor_type_registry.get(dtype) + if existing_cls is not None: + raise ValueError( + f"Invalid usage: subclass {existing_cls!r} " + f"already defined for dtype={dtype}" + ) + tensor_type_registry[dtype] = cls + else: check_shape(shape) - def __getitem__(self, shape: Optional[ShapeType]): - if self.shape is not None: + def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]: + if cls.shape is not None: raise ValueError("Invalid usage: shape already specified.") if shape is None: # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension shape = (None,) - return TensorType(self.dtype, shape) - - def __class_getitem__(cls, shape: Optional[ShapeType]): - if cls.default_instance is None: - raise TypeError(f"{cls} does not specify a default_instance.") - # pylint erroneously flags with unsubscriptable-object if - # using subscript notation (cls.default_instance[shape]): - return cls.default_instance.__getitem__(shape) - - def to_type_proto(self) -> onnx.TypeProto: - if self.shape is None: + key = (cls.dtype, shape) + shaped_type = _tensor_type_shape_cache.get(key) + if shaped_type is None: + shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape) + _tensor_type_shape_cache[key] = shaped_type + return shaped_type + + @classmethod + def to_type_proto(cls) -> onnx.TypeProto: + if cls.shape is None: shape = () # "FLOAT" is treated as a scalar - elif self.shape is Ellipsis: + elif cls.shape is Ellipsis: shape = None # "FLOAT[...]" is a tensor of unknown rank - elif isinstance(self.shape, tuple): - shape = self.shape # example: "FLOAT[10,20]" + elif isinstance(cls.shape, tuple): + shape = cls.shape # example: "FLOAT[10,20]" else: - shape = [self.shape] # example: "FLOAT[10]" - return onnx.helper.make_tensor_type_proto(self.dtype, shape) - - -class _BuiltinTensorType: - def __init__(self, tensor_proto: onnx.TensorProto): - self.tensor_proto = tensor_proto - - def __call__(self, cls): - cls.default_instance = TensorType(self.tensor_proto) - cls.to_type_proto = cls.default_instance.to_type_proto - return cls + shape = [cls.shape] # example: "FLOAT[10]" + return onnx.helper.make_tensor_type_proto(cls.dtype, shape) -@_BuiltinTensorType(onnx.TensorProto.FLOAT) -class FLOAT(TensorType): +class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): pass -@_BuiltinTensorType(onnx.TensorProto.UINT8) -class UINT8(TensorType): +class UINT8(TensorType, dtype=onnx.TensorProto.UINT8): pass -@_BuiltinTensorType(onnx.TensorProto.INT8) -class INT8(TensorType): +class INT8(TensorType, dtype=onnx.TensorProto.INT8): pass -@_BuiltinTensorType(onnx.TensorProto.UINT16) -class UINT16(TensorType): +class UINT16(TensorType, dtype=onnx.TensorProto.UINT16): pass -@_BuiltinTensorType(onnx.TensorProto.INT16) -class INT16(TensorType): +class INT16(TensorType, dtype=onnx.TensorProto.INT16): pass -@_BuiltinTensorType(onnx.TensorProto.INT32) -class INT32(TensorType): +class INT32(TensorType, dtype=onnx.TensorProto.INT32): pass -@_BuiltinTensorType(onnx.TensorProto.INT64) -class INT64(TensorType): +class INT64(TensorType, dtype=onnx.TensorProto.INT64): pass -@_BuiltinTensorType(onnx.TensorProto.STRING) -class STRING(TensorType): +class STRING(TensorType, dtype=onnx.TensorProto.STRING): pass -@_BuiltinTensorType(onnx.TensorProto.BOOL) -class BOOL(TensorType): +class BOOL(TensorType, dtype=onnx.TensorProto.BOOL): pass -@_BuiltinTensorType(onnx.TensorProto.FLOAT16) -class FLOAT16(TensorType): +class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16): pass -@_BuiltinTensorType(onnx.TensorProto.DOUBLE) -class DOUBLE(TensorType): +class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE): pass -@_BuiltinTensorType(onnx.TensorProto.UINT32) -class UINT32(TensorType): +class UINT32(TensorType, dtype=onnx.TensorProto.UINT32): pass -@_BuiltinTensorType(onnx.TensorProto.UINT64) -class UINT64(TensorType): +class UINT64(TensorType, dtype=onnx.TensorProto.UINT64): pass -@_BuiltinTensorType(onnx.TensorProto.COMPLEX64) -class COMPLEX64(TensorType): +class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64): pass -@_BuiltinTensorType(onnx.TensorProto.COMPLEX128) -class COMPLEX128(TensorType): +class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128): pass -@_BuiltinTensorType(onnx.TensorProto.BFLOAT16) -class BFLOAT16(TensorType): +class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16): pass diff --git a/onnxscript/test/onnx_types_test.py b/onnxscript/test/onnx_types_test.py new file mode 100644 index 0000000000..03b5502c1d --- /dev/null +++ b/onnxscript/test/onnx_types_test.py @@ -0,0 +1,75 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# mypy: disable-error-code=misc + +"""Unit tests for the onnx_types module.""" + +import unittest + +from parameterized import parameterized + +from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry + + +class TestOnnxTypes(unittest.TestCase): + def test_instantiation(self): + with self.assertRaises(NotImplementedError): + TensorType() + with self.assertRaises(NotImplementedError): + FLOAT() + with self.assertRaises(NotImplementedError): + FLOAT[...]() + + @parameterized.expand(tensor_type_registry.items()) + def test_type_properties(self, dtype: DType, tensor_type: TensorType): + self.assertEqual(tensor_type.dtype, dtype) + self.assertIsNone(tensor_type.shape) + self.assertEqual(tensor_type[...].shape, ...) + self.assertEqual(tensor_type[...].dtype, dtype) + self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3)) + self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) + + @parameterized.expand([(dtype,) for dtype in tensor_type_registry]) + def test_dtype_bound_to_subclass(self, dtype: DType): + with self.assertRaises(ValueError): + type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype) + + def test_shaped_doesnt_reshape(self): + with self.assertRaises(ValueError): + FLOAT[1][...] # pylint: disable=pointless-statement + + @parameterized.expand( + [ + (FLOAT, FLOAT), + (FLOAT[None], FLOAT[None]), + (FLOAT[1, 2, 3], FLOAT[1, 2, 3]), + (FLOAT[1], FLOAT[1]), + (FLOAT[...], FLOAT[Ellipsis]), + (FLOAT["M"], FLOAT["M"]), + (FLOAT["M", "N"], FLOAT["M", "N"]), + (FLOAT["M", 3, 4], FLOAT["M", 3, 4]), + ] + ) + def test_shapes_are_same_type(self, a: TensorType, b: TensorType): + self.assertIs(a, b) + + @parameterized.expand( + [ + (FLOAT[0], FLOAT[None]), + (FLOAT[1, 2], FLOAT[3, 4]), + (FLOAT[2, 1], FLOAT[1, 2]), + (FLOAT["M", "N"], FLOAT["N", "M"]), + (FLOAT, DOUBLE), + (FLOAT[1], DOUBLE[1]), + (FLOAT["X"], DOUBLE["X"]), + ] + ) + def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType): + self.assertIsNot(a, b) + + +if __name__ == "__main__": + unittest.main() From 76ab4ed5f6cd2d808ffb66ef7c06c3a25b0ff66d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 7 Dec 2022 16:28:12 +0000 Subject: [PATCH 2/3] fix --- onnxscript/onnx_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 1958a2d284..da9ad8642d 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -5,7 +5,7 @@ from __future__ import annotations -from abc import ABC +import abc from typing import ClassVar, Optional, Tuple, Union import onnx @@ -36,7 +36,7 @@ def check_shape(shape): _tensor_type_shape_cache: dict[DType, TensorType] = {} -class TensorType(ABC): +class TensorType(abc.ABC): """ONNX Script representation of a tensor type supporting shape annotations. A scalar-tensor of rank 0: From fcd664e6529c91e29a8ef532d0c42ed6c92875e9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 7 Dec 2022 16:32:10 +0000 Subject: [PATCH 3/3] Fix lint --- onnxscript/test/onnx_types_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxscript/test/onnx_types_test.py b/onnxscript/test/onnx_types_test.py index 03b5502c1d..3293fa0fac 100644 --- a/onnxscript/test/onnx_types_test.py +++ b/onnxscript/test/onnx_types_test.py @@ -6,6 +6,7 @@ # mypy: disable-error-code=misc """Unit tests for the onnx_types module.""" +from __future__ import annotations import unittest @@ -24,13 +25,13 @@ def test_instantiation(self): FLOAT[...]() @parameterized.expand(tensor_type_registry.items()) - def test_type_properties(self, dtype: DType, tensor_type: TensorType): + def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]): self.assertEqual(tensor_type.dtype, dtype) self.assertIsNone(tensor_type.shape) - self.assertEqual(tensor_type[...].shape, ...) - self.assertEqual(tensor_type[...].dtype, dtype) - self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3)) - self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) + self.assertEqual(tensor_type[...].shape, ...) # type: ignore[index] + self.assertEqual(tensor_type[...].dtype, dtype) # type: ignore[index] + self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3)) # type: ignore[index] + self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) # type: ignore[index] @parameterized.expand([(dtype,) for dtype in tensor_type_registry]) def test_dtype_bound_to_subclass(self, dtype: DType):