Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

* onnxscript/irbuilder.py: annotate the `typeinfo` parameter of `Var.__init__`
  as `ONNXType`.
* onnxscript/onnx_types.py: rename `check_dim` / `check_shape` to the private
  `_check_dim` / `_check_shape`; add the `DType` alias; replace the
  instance-based `TensorType` and the `_BuiltinTensorType` decorator with a
  `TensorType` metaclass plus a `_WithOnnxType` base whose `to_type_proto` is a
  classmethod; subscripting (e.g. `FLOAT[1, 2]`) now returns a class cached per
  `(dtype, shape)`; add `_tensor_type_registry` mapping each dtype to its class.
* onnxscript/test/onnx_types_test.py: new unit tests for the above.
* onnxscript/test/type_annotation_test.py: rename the test class; subscripting
  an already-shaped type (`FLOAT[10][20]`) is now expected to raise `TypeError`
  instead of `ValueError`.

Review note: `_tensor_type_shape_cache` was annotated `dict[DType, TensorType]`
although it is keyed by the `(dtype, shape)` tuple and stores the classes
returned by `TensorType.__getitem__`; the annotation on that added line is
corrected below. Whitespace-only context lines and intra-line spacing were
reconstructed to agree with the hunk headers and should be confirmed against
the original files before applying.

diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index 96c840e307..15e17c0ffc 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -62,7 +62,7 @@ def __repr__(self) -> str:
 
 
 class Var:
-    def __init__(self, varname, typeinfo, info) -> None:
+    def __init__(self, varname, typeinfo: ONNXType, info) -> None:
         if not isinstance(varname, str):
             raise ValueError(f"varname must be a string not {type(varname)!r}.")
         self.name = varname
diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index 8c8db3b17b..7cffdaf764 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -2,25 +2,20 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+
 from __future__ import annotations
 
-from typing import Optional, Tuple, Union
+from typing import ClassVar, Optional, Tuple, Union
 
 import onnx
 import onnx.helper
 
-# Representations of ONNX types in ONNX Script.
-# Currently restricted to tensor types.
-# Example type annotations in ONNX Script.
-# x : FLOAT (a scalar-tensor of rank 0)
-# x : FLOAT[...] (a tensor of unknown rank)
-# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
-# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
+DType = onnx.TensorProto.DataType
 
 DimType = Union[int, str, type(None)]
 
 
-def check_dim(dim):
+def _check_dim(dim):
     if not isinstance(dim, (int, str, type(None))):
         raise TypeError(f"Invalid dimension {dim}")
 
@@ -28,140 +23,184 @@ def check_dim(dim):
 ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
 
 
-def check_shape(shape):
+def _check_shape(shape):
     if isinstance(shape, tuple):
         for dim in shape:
-            check_dim(dim)
+            _check_dim(dim)
     elif shape != Ellipsis:
-        check_dim(shape)
-
+        _check_dim(shape)
 
-class TensorType:
-    """ONNX Script representation of a tensor type."""
 
-    default_instance: Optional["TensorType"] = None
+_tensor_type_shape_cache: dict[tuple[DType, ShapeType], type[TensorType]] = {}
 
-    def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None:
-        self.dtype = dtype
-        self.shape = shape
-        if shape is not None:
-            check_shape(shape)
 
-    def __getitem__(self, shape: Optional[ShapeType]):
-        if self.shape is not None:
-            raise ValueError("Invalid usage: shape already specified.")
-        if shape is None:
-            # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
-            shape = (None,)
-        return TensorType(self.dtype, shape)
+class _WithOnnxType:
+    """Class that implements to_type_proto."""
 
-    def __class_getitem__(cls, shape: Optional[ShapeType]):
-        if cls.default_instance is None:
-            raise TypeError(f"{cls} does not specify a default_instance.")
-        # pylint erroneously flags with unsubscriptable-object if
-        # using subscript notation (cls.default_instance[shape]):
-        return cls.default_instance.__getitem__(shape)
+    dtype: ClassVar[DType]
+    shape: ClassVar[Optional[ShapeType]] = None
 
-    def to_type_proto(self) -> onnx.TypeProto:
-        if self.shape is None:
+    @classmethod
+    def to_type_proto(cls) -> onnx.TypeProto:
+        if cls.shape is None:
             shape = ()  # "FLOAT" is treated as a scalar
-        elif self.shape is Ellipsis:
+        elif cls.shape is Ellipsis:
             shape = None  # "FLOAT[...]" is a tensor of unknown rank
-        elif isinstance(self.shape, tuple):
-            shape = self.shape  # example: "FLOAT[10,20]"
+        elif isinstance(cls.shape, tuple):
+            shape = cls.shape  # example: "FLOAT[10,20]"
         else:
-            shape = [self.shape]  # example: "FLOAT[10]"
-        return onnx.helper.make_tensor_type_proto(self.dtype, shape)
+            shape = [cls.shape]  # example: "FLOAT[10]"
+        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
+
+
+class TensorType(_WithOnnxType, type):
+    """ONNX Script representation of a tensor type supporting shape annotations.
+
+    A scalar-tensor of rank 0:
+    ::
+
+        tensor: FLOAT
+
+    A tensor of unknown rank:
+    ::
+
+        tensor: FLOAT[...]
+
+    A tensor of rank 2 of unknown dimensions, with symbolic names:
+    ::
+
+        tensor: FLOAT['M', 'N']
+
+    A tensor of rank 2 of known dimensions:
+    ::
+
+        tensor: FLOAT[128, 1024]
+    """
+
+    dtype: ClassVar[DType]
+    shape: ClassVar[Optional[ShapeType]] = None
+
+    def __getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
+        if cls.shape is not None:
+            raise ValueError("Invalid usage: shape already specified.")
+        if shape is None:
+            # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
+            shape = (None,)
+        _check_shape(shape)
+        key = (cls.dtype, shape)
+        shaped_type = _tensor_type_shape_cache.get(key)
+        if shaped_type is None:
+            # This calls __init_subclass__
+            shaped_type = type(
+                cls.__name__,
+                (type(cls),),
+                dict(dtype=cls.dtype, shape=shape),
+            )
+            _tensor_type_shape_cache[key] = shaped_type
+        return shaped_type
 
 
-class _BuiltinTensorType:
-    def __init__(self, tensor_proto: onnx.TensorProto):
-        self.tensor_proto = tensor_proto
+# pylint: disable=abstract-method,too-many-function-args
+class FLOAT(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.FLOAT
+    shape = None
 
-    def __call__(self, cls):
-        cls.default_instance = TensorType(self.tensor_proto)
-        cls.to_type_proto = cls.default_instance.to_type_proto
-        return cls
 
+class UINT8(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.UINT8
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.FLOAT)
-class FLOAT(TensorType):
-    pass
 
+class INT8(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.INT8
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.UINT8)
-class UINT8(TensorType):
-    pass
 
+class UINT16(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.UINT16
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.INT8)
-class INT8(TensorType):
-    pass
 
+class INT16(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.INT16
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.UINT16)
-class UINT16(TensorType):
-    pass
 
+class INT32(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.INT32
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.INT16)
-class INT16(TensorType):
-    pass
 
+class INT64(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.INT64
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.INT32)
-class INT32(TensorType):
-    pass
 
+class STRING(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.STRING
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.INT64)
-class INT64(TensorType):
-    pass
 
+class BOOL(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.BOOL
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.STRING)
-class STRING(TensorType):
-    pass
 
+class FLOAT16(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.FLOAT16
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.BOOL)
-class BOOL(TensorType):
-    pass
 
+class DOUBLE(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.DOUBLE
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.FLOAT16)
-class FLOAT16(TensorType):
-    pass
 
+class UINT32(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.UINT32
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.DOUBLE)
-class DOUBLE(TensorType):
-    pass
 
+class UINT64(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.UINT64
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.UINT32)
-class UINT32(TensorType):
-    pass
 
+class COMPLEX64(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.COMPLEX64
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.UINT64)
-class UINT64(TensorType):
-    pass
 
+class COMPLEX128(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.COMPLEX128
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.COMPLEX64)
-class COMPLEX64(TensorType):
-    pass
 
+class BFLOAT16(_WithOnnxType, metaclass=TensorType):
+    dtype = onnx.TensorProto.BFLOAT16
+    shape = None
 
-@_BuiltinTensorType(onnx.TensorProto.COMPLEX128)
-class COMPLEX128(TensorType):
-    pass
 
+# pylint: enable=abstract-method,too-many-function-args
 
-@_BuiltinTensorType(onnx.TensorProto.BFLOAT16)
-class BFLOAT16(TensorType):
-    pass
+_tensor_type_registry: dict[DType, TensorType] = {
+    onnx.TensorProto.FLOAT: FLOAT,
+    onnx.TensorProto.UINT8: UINT8,
+    onnx.TensorProto.INT8: INT8,
+    onnx.TensorProto.UINT16: UINT16,
+    onnx.TensorProto.INT16: INT16,
+    onnx.TensorProto.INT32: INT32,
+    onnx.TensorProto.INT64: INT64,
+    onnx.TensorProto.STRING: STRING,
+    onnx.TensorProto.BOOL: BOOL,
+    onnx.TensorProto.FLOAT16: FLOAT16,
+    onnx.TensorProto.DOUBLE: DOUBLE,
+    onnx.TensorProto.UINT32: UINT32,
+    onnx.TensorProto.UINT64: UINT64,
+    onnx.TensorProto.COMPLEX64: COMPLEX64,
+    onnx.TensorProto.COMPLEX128: COMPLEX128,
+    onnx.TensorProto.BFLOAT16: BFLOAT16,
+}
 
 
 def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
diff --git a/onnxscript/test/onnx_types_test.py b/onnxscript/test/onnx_types_test.py
new file mode 100644
index 0000000000..a51cd39c98
--- /dev/null
+++ b/onnxscript/test/onnx_types_test.py
@@ -0,0 +1,81 @@
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+#
+# mypy: disable-error-code=misc
+
+"""Unit tests for the onnx_types module."""
+
+import unittest
+
+from parameterized import parameterized
+
+from onnxscript.onnx_types import (
+    DOUBLE,
+    FLOAT,
+    DType,
+    TensorType,
+    _tensor_type_registry,
+)
+
+
+class TestOnnxTypes(unittest.TestCase):
+    # def test_instantiation(self):
+    #     with self.assertRaises(NotImplementedError):
+    #         TensorType()
+    #     with self.assertRaises(NotImplementedError):
+    #         FLOAT()
+    #     with self.assertRaises(NotImplementedError):
+    #         FLOAT[...]()
+
+    @parameterized.expand(_tensor_type_registry.items())
+    def test_type_properties(self, dtype: DType, tensor_type: TensorType):
+        self.assertEqual(tensor_type.dtype, dtype)
+        self.assertIsNone(tensor_type.shape)
+        self.assertEqual(tensor_type[...].shape, ...)
+        self.assertEqual(tensor_type[...].dtype, dtype)
+        self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3))
+        self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)
+
+    # @parameterized.expand([(dtype,) for dtype in _tensor_type_registry])
+    # def test_dtype_bound_to_subclass(self, dtype: DType):
+    #     with self.assertRaises(ValueError):
+    #         type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)
+
+    def test_shaped_doesnt_reshape(self):
+        with self.assertRaises(TypeError):
+            FLOAT[1][...]  # pylint: disable=pointless-statement
+
+    @parameterized.expand(
+        [
+            (FLOAT, FLOAT),
+            (FLOAT[None], FLOAT[None]),
+            (FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
+            (FLOAT[1], FLOAT[1]),
+            (FLOAT[...], FLOAT[Ellipsis]),
+            (FLOAT["M"], FLOAT["M"]),
+            (FLOAT["M", "N"], FLOAT["M", "N"]),
+            (FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
+        ]
+    )
+    def test_shapes_are_same_type(self, a: TensorType, b: TensorType):
+        self.assertIs(a, b)
+
+    @parameterized.expand(
+        [
+            (FLOAT[0], FLOAT[None]),
+            (FLOAT[1, 2], FLOAT[3, 4]),
+            (FLOAT[2, 1], FLOAT[1, 2]),
+            (FLOAT["M", "N"], FLOAT["N", "M"]),
+            (FLOAT, DOUBLE),
+            (FLOAT[1], DOUBLE[1]),
+            (FLOAT["X"], DOUBLE["X"]),
+        ]
+    )
+    def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType):
+        self.assertIsNot(a, b)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/test/type_annotation_test.py b/onnxscript/test/type_annotation_test.py
index 4372e86a48..b01d59a79d 100644
--- a/onnxscript/test/type_annotation_test.py
+++ b/onnxscript/test/type_annotation_test.py
@@ -11,7 +11,7 @@ from onnxscript.test.common import testutils
 
 
 
-class TypeAnnotationTester(testutils.TestBase):
+class TypeAnnotationTest(testutils.TestBase):
     def test_type_annotation(self):
         """Test type annotations."""
 
@@ -63,7 +63,7 @@ def unknown_rank(A: FLOAT[...], B: FLOAT[...]) -> FLOAT[...]:
         """
         self.assertSameGraph(unknown_rank, unknown_rank_txt)
 
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
             FLOAT[10][20]  # Invalid usage. pylint: disable=pointless-statement
 
 