diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 8c8db3b17b..da9ad8642d 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -2,20 +2,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- + from __future__ import annotations -from typing import Optional, Tuple, Union +import abc +from typing import ClassVar, Optional, Tuple, Union import onnx import onnx.helper -# Representations of ONNX types in ONNX Script. -# Currently restricted to tensor types. -# Example type annotations in ONNX Script. -# x : FLOAT (a scalar-tensor of rank 0) -# x : FLOAT[...] (a tensor of unknown rank) -# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names) -# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions) +DType = onnx.TensorProto.DataType DimType = Union[int, str, type(None)] @@ -36,131 +32,141 @@ def check_shape(shape): check_dim(shape) -class TensorType: - """ONNX Script representation of a tensor type.""" +tensor_type_registry: dict[DType, TensorType] = {} +_tensor_type_shape_cache: dict[DType, TensorType] = {} + + +class TensorType(abc.ABC): + """ONNX Script representation of a tensor type supporting shape annotations. + + A scalar-tensor of rank 0: + :: + + tensor: FLOAT + + A tensor of unknown rank: + :: + + tensor: FLOAT[...] + + A tensor of rank 2 of unknown dimensions, with symbolic names: + :: + + tensor: FLOAT['M', 'N'] + + A tensor of rank 2 of known dimensions: + :: - default_instance: Optional["TensorType"] = None + tensor: FLOAT[128, 1024] + """ + + dtype: ClassVar[DType] + shape: ClassVar[Optional[ShapeType]] + + def __new__(cls): + raise NotImplementedError("TensorTypes cannot be instantiated") - def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None: - self.dtype = dtype - self.shape = shape - if shape is not None: + def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None): + cls.dtype = dtype + cls.shape = shape + if shape is None: + existing_cls = tensor_type_registry.get(dtype) + if existing_cls is not None: + raise ValueError( + f"Invalid usage: subclass {existing_cls!r} " + f"already defined for dtype={dtype}" + ) + tensor_type_registry[dtype] = cls + else: check_shape(shape) - def __getitem__(self, shape: Optional[ShapeType]): - if self.shape is not None: + def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]: + if cls.shape is not None: raise ValueError("Invalid usage: shape already specified.") if shape is None: # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension shape = (None,) - return TensorType(self.dtype, shape) - - def __class_getitem__(cls, shape: Optional[ShapeType]): - if cls.default_instance is None: - raise TypeError(f"{cls} does not specify a default_instance.") - # pylint erroneously flags with unsubscriptable-object if - # using subscript notation (cls.default_instance[shape]): - return cls.default_instance.__getitem__(shape) - - def to_type_proto(self) -> onnx.TypeProto: - if self.shape is None: + key = (cls.dtype, shape) + shaped_type = _tensor_type_shape_cache.get(key) + if shaped_type is None: + shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape) + _tensor_type_shape_cache[key] = shaped_type + return shaped_type + + @classmethod + def to_type_proto(cls) -> onnx.TypeProto: + if cls.shape is None: shape = () # "FLOAT" is treated as a scalar - elif self.shape is Ellipsis: + elif cls.shape is Ellipsis: shape = None # "FLOAT[...]" is a tensor of unknown rank - elif isinstance(self.shape, tuple): - shape = self.shape # example: "FLOAT[10,20]" + elif isinstance(cls.shape, tuple): + shape = cls.shape # example: "FLOAT[10,20]" else: - shape = [self.shape] # example: "FLOAT[10]" - return onnx.helper.make_tensor_type_proto(self.dtype, shape) - - -class _BuiltinTensorType: - def __init__(self, tensor_proto: onnx.TensorProto): - self.tensor_proto = tensor_proto - - def __call__(self, cls): - cls.default_instance = TensorType(self.tensor_proto) - cls.to_type_proto = cls.default_instance.to_type_proto - return cls + shape = [cls.shape] # example: "FLOAT[10]" + return onnx.helper.make_tensor_type_proto(cls.dtype, shape) -@_BuiltinTensorType(onnx.TensorProto.FLOAT) -class FLOAT(TensorType): +class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): pass -@_BuiltinTensorType(onnx.TensorProto.UINT8) -class UINT8(TensorType): +class UINT8(TensorType, dtype=onnx.TensorProto.UINT8): pass -@_BuiltinTensorType(onnx.TensorProto.INT8) -class INT8(TensorType): +class INT8(TensorType, dtype=onnx.TensorProto.INT8): pass -@_BuiltinTensorType(onnx.TensorProto.UINT16) -class UINT16(TensorType): +class UINT16(TensorType, dtype=onnx.TensorProto.UINT16): pass -@_BuiltinTensorType(onnx.TensorProto.INT16) -class INT16(TensorType): +class INT16(TensorType, dtype=onnx.TensorProto.INT16): pass -@_BuiltinTensorType(onnx.TensorProto.INT32) -class INT32(TensorType): +class INT32(TensorType, dtype=onnx.TensorProto.INT32): pass -@_BuiltinTensorType(onnx.TensorProto.INT64) -class INT64(TensorType): +class INT64(TensorType, dtype=onnx.TensorProto.INT64): pass -@_BuiltinTensorType(onnx.TensorProto.STRING) -class STRING(TensorType): +class STRING(TensorType, dtype=onnx.TensorProto.STRING): pass -@_BuiltinTensorType(onnx.TensorProto.BOOL) -class BOOL(TensorType): +class BOOL(TensorType, dtype=onnx.TensorProto.BOOL): pass -@_BuiltinTensorType(onnx.TensorProto.FLOAT16) -class FLOAT16(TensorType): +class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16): pass -@_BuiltinTensorType(onnx.TensorProto.DOUBLE) -class DOUBLE(TensorType): +class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE): pass -@_BuiltinTensorType(onnx.TensorProto.UINT32) -class UINT32(TensorType): +class UINT32(TensorType, dtype=onnx.TensorProto.UINT32): pass -@_BuiltinTensorType(onnx.TensorProto.UINT64) -class UINT64(TensorType): +class UINT64(TensorType, dtype=onnx.TensorProto.UINT64): pass -@_BuiltinTensorType(onnx.TensorProto.COMPLEX64) -class COMPLEX64(TensorType): +class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64): pass -@_BuiltinTensorType(onnx.TensorProto.COMPLEX128) -class COMPLEX128(TensorType): +class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128): pass -@_BuiltinTensorType(onnx.TensorProto.BFLOAT16) -class BFLOAT16(TensorType): +class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16): pass diff --git a/onnxscript/test/onnx_types_test.py b/onnxscript/test/onnx_types_test.py new file mode 100644 index 0000000000..3293fa0fac --- /dev/null +++ b/onnxscript/test/onnx_types_test.py @@ -0,0 +1,76 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# mypy: disable-error-code=misc + +"""Unit tests for the onnx_types module.""" +from __future__ import annotations + +import unittest + +from parameterized import parameterized + +from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry + + +class TestOnnxTypes(unittest.TestCase): + def test_instantiation(self): + with self.assertRaises(NotImplementedError): + TensorType() + with self.assertRaises(NotImplementedError): + FLOAT() + with self.assertRaises(NotImplementedError): + FLOAT[...]() + + @parameterized.expand(tensor_type_registry.items()) + def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]): + self.assertEqual(tensor_type.dtype, dtype) + self.assertIsNone(tensor_type.shape) + self.assertEqual(tensor_type[...].shape, ...) # type: ignore[index] + self.assertEqual(tensor_type[...].dtype, dtype) # type: ignore[index] + self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3)) # type: ignore[index] + self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) # type: ignore[index] + + @parameterized.expand([(dtype,) for dtype in tensor_type_registry]) + def test_dtype_bound_to_subclass(self, dtype: DType): + with self.assertRaises(ValueError): + type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype) + + def test_shaped_doesnt_reshape(self): + with self.assertRaises(ValueError): + FLOAT[1][...] # pylint: disable=pointless-statement + + @parameterized.expand( + [ + (FLOAT, FLOAT), + (FLOAT[None], FLOAT[None]), + (FLOAT[1, 2, 3], FLOAT[1, 2, 3]), + (FLOAT[1], FLOAT[1]), + (FLOAT[...], FLOAT[Ellipsis]), + (FLOAT["M"], FLOAT["M"]), + (FLOAT["M", "N"], FLOAT["M", "N"]), + (FLOAT["M", 3, 4], FLOAT["M", 3, 4]), + ] + ) + def test_shapes_are_same_type(self, a: TensorType, b: TensorType): + self.assertIs(a, b) + + @parameterized.expand( + [ + (FLOAT[0], FLOAT[None]), + (FLOAT[1, 2], FLOAT[3, 4]), + (FLOAT[2, 1], FLOAT[1, 2]), + (FLOAT["M", "N"], FLOAT["N", "M"]), + (FLOAT, DOUBLE), + (FLOAT[1], DOUBLE[1]), + (FLOAT["X"], DOUBLE["X"]), + ] + ) + def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType): + self.assertIsNot(a, b) + + +if __name__ == "__main__": + unittest.main()