Skip to content

Fix TensorType to return new TensorTypes for shapes instead of TensorType instances #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 81 additions & 75 deletions onnxscript/onnx_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from __future__ import annotations

from typing import Optional, Tuple, Union
import abc
from typing import ClassVar, Optional, Tuple, Union

import onnx
import onnx.helper

# Representations of ONNX types in ONNX Script.
# Currently restricted to tensor types.
# Example type annotations in ONNX Script.
# x : FLOAT (a scalar-tensor of rank 0)
# x : FLOAT[...] (a tensor of unknown rank)
# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
DType = onnx.TensorProto.DataType

DimType = Union[int, str, type(None)]

Expand All @@ -36,131 +32,141 @@ def check_shape(shape):
check_dim(shape)


class TensorType:
"""ONNX Script representation of a tensor type."""
tensor_type_registry: dict[DType, TensorType] = {}
_tensor_type_shape_cache: dict[DType, TensorType] = {}


class TensorType(abc.ABC):
"""ONNX Script representation of a tensor type supporting shape annotations.

A scalar-tensor of rank 0:
::

tensor: FLOAT

A tensor of unknown rank:
::

tensor: FLOAT[...]

A tensor of rank 2 of unknown dimensions, with symbolic names:
::

tensor: FLOAT['M', 'N']

A tensor of rank 2 of known dimensions:
::

default_instance: Optional["TensorType"] = None
tensor: FLOAT[128, 1024]
"""

dtype: ClassVar[DType]
shape: ClassVar[Optional[ShapeType]]

def __new__(cls):
raise NotImplementedError("TensorTypes cannot be instantiated")

def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None:
self.dtype = dtype
self.shape = shape
if shape is not None:
def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
cls.dtype = dtype
cls.shape = shape
if shape is None:
existing_cls = tensor_type_registry.get(dtype)
if existing_cls is not None:
raise ValueError(
f"Invalid usage: subclass {existing_cls!r} "
f"already defined for dtype={dtype}"
)
tensor_type_registry[dtype] = cls
else:
check_shape(shape)

def __getitem__(self, shape: Optional[ShapeType]):
if self.shape is not None:
def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
if cls.shape is not None:
raise ValueError("Invalid usage: shape already specified.")
if shape is None:
# Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
shape = (None,)
return TensorType(self.dtype, shape)

def __class_getitem__(cls, shape: Optional[ShapeType]):
if cls.default_instance is None:
raise TypeError(f"{cls} does not specify a default_instance.")
# pylint erroneously flags with unsubscriptable-object if
# using subscript notation (cls.default_instance[shape]):
return cls.default_instance.__getitem__(shape)

def to_type_proto(self) -> onnx.TypeProto:
if self.shape is None:
key = (cls.dtype, shape)
shaped_type = _tensor_type_shape_cache.get(key)
if shaped_type is None:
shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape)
_tensor_type_shape_cache[key] = shaped_type
return shaped_type

@classmethod
def to_type_proto(cls) -> onnx.TypeProto:
if cls.shape is None:
shape = () # "FLOAT" is treated as a scalar
elif self.shape is Ellipsis:
elif cls.shape is Ellipsis:
shape = None # "FLOAT[...]" is a tensor of unknown rank
elif isinstance(self.shape, tuple):
shape = self.shape # example: "FLOAT[10,20]"
elif isinstance(cls.shape, tuple):
shape = cls.shape # example: "FLOAT[10,20]"
else:
shape = [self.shape] # example: "FLOAT[10]"
return onnx.helper.make_tensor_type_proto(self.dtype, shape)


class _BuiltinTensorType:
def __init__(self, tensor_proto: onnx.TensorProto):
self.tensor_proto = tensor_proto

def __call__(self, cls):
cls.default_instance = TensorType(self.tensor_proto)
cls.to_type_proto = cls.default_instance.to_type_proto
return cls
shape = [cls.shape] # example: "FLOAT[10]"
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)


@_BuiltinTensorType(onnx.TensorProto.FLOAT)
class FLOAT(TensorType):
class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
pass


@_BuiltinTensorType(onnx.TensorProto.UINT8)
class UINT8(TensorType):
class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
pass


@_BuiltinTensorType(onnx.TensorProto.INT8)
class INT8(TensorType):
class INT8(TensorType, dtype=onnx.TensorProto.INT8):
pass


@_BuiltinTensorType(onnx.TensorProto.UINT16)
class UINT16(TensorType):
class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
pass


@_BuiltinTensorType(onnx.TensorProto.INT16)
class INT16(TensorType):
class INT16(TensorType, dtype=onnx.TensorProto.INT16):
pass


@_BuiltinTensorType(onnx.TensorProto.INT32)
class INT32(TensorType):
class INT32(TensorType, dtype=onnx.TensorProto.INT32):
pass


@_BuiltinTensorType(onnx.TensorProto.INT64)
class INT64(TensorType):
class INT64(TensorType, dtype=onnx.TensorProto.INT64):
pass


@_BuiltinTensorType(onnx.TensorProto.STRING)
class STRING(TensorType):
class STRING(TensorType, dtype=onnx.TensorProto.STRING):
pass


@_BuiltinTensorType(onnx.TensorProto.BOOL)
class BOOL(TensorType):
class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
pass


@_BuiltinTensorType(onnx.TensorProto.FLOAT16)
class FLOAT16(TensorType):
class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
pass


@_BuiltinTensorType(onnx.TensorProto.DOUBLE)
class DOUBLE(TensorType):
class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
pass


@_BuiltinTensorType(onnx.TensorProto.UINT32)
class UINT32(TensorType):
class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
pass


@_BuiltinTensorType(onnx.TensorProto.UINT64)
class UINT64(TensorType):
class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
pass


@_BuiltinTensorType(onnx.TensorProto.COMPLEX64)
class COMPLEX64(TensorType):
class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
pass


@_BuiltinTensorType(onnx.TensorProto.COMPLEX128)
class COMPLEX128(TensorType):
class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
pass


@_BuiltinTensorType(onnx.TensorProto.BFLOAT16)
class BFLOAT16(TensorType):
class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
pass


Expand Down
76 changes: 76 additions & 0 deletions onnxscript/test/onnx_types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# mypy: disable-error-code=misc

"""Unit tests for the onnx_types module."""
from __future__ import annotations

import unittest

from parameterized import parameterized

from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry


class TestOnnxTypes(unittest.TestCase):
def test_instantiation(self):
with self.assertRaises(NotImplementedError):
TensorType()
with self.assertRaises(NotImplementedError):
FLOAT()
with self.assertRaises(NotImplementedError):
FLOAT[...]()

@parameterized.expand(tensor_type_registry.items())
def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
self.assertEqual(tensor_type.dtype, dtype)
self.assertIsNone(tensor_type.shape)
self.assertEqual(tensor_type[...].shape, ...) # type: ignore[index]
self.assertEqual(tensor_type[...].dtype, dtype) # type: ignore[index]
self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3)) # type: ignore[index]
self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) # type: ignore[index]

@parameterized.expand([(dtype,) for dtype in tensor_type_registry])
def test_dtype_bound_to_subclass(self, dtype: DType):
with self.assertRaises(ValueError):
type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)

def test_shaped_doesnt_reshape(self):
with self.assertRaises(ValueError):
FLOAT[1][...] # pylint: disable=pointless-statement

@parameterized.expand(
[
(FLOAT, FLOAT),
(FLOAT[None], FLOAT[None]),
(FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
(FLOAT[1], FLOAT[1]),
(FLOAT[...], FLOAT[Ellipsis]),
(FLOAT["M"], FLOAT["M"]),
(FLOAT["M", "N"], FLOAT["M", "N"]),
(FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
]
)
def test_shapes_are_same_type(self, a: TensorType, b: TensorType):
self.assertIs(a, b)

@parameterized.expand(
[
(FLOAT[0], FLOAT[None]),
(FLOAT[1, 2], FLOAT[3, 4]),
(FLOAT[2, 1], FLOAT[1, 2]),
(FLOAT["M", "N"], FLOAT["N", "M"]),
(FLOAT, DOUBLE),
(FLOAT[1], DOUBLE[1]),
(FLOAT["X"], DOUBLE["X"]),
]
)
def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType):
self.assertIsNot(a, b)


if __name__ == "__main__":
unittest.main()