Skip to content

Patch for tensor types #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __repr__(self) -> str:


class Var:
def __init__(self, varname, typeinfo, info) -> None:
def __init__(self, varname, typeinfo: ONNXType, info) -> None:
if not isinstance(varname, str):
raise ValueError(f"varname must be a string not {type(varname)!r}.")
self.name = varname
Expand Down
231 changes: 135 additions & 96 deletions onnxscript/onnx_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,166 +2,205 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from __future__ import annotations

from typing import Optional, Tuple, Union
from typing import ClassVar, Optional, Tuple, Union

import onnx
import onnx.helper

# Representations of ONNX types in ONNX Script.
# Currently restricted to tensor types.
# Example type annotations in ONNX Script.
# x : FLOAT (a scalar-tensor of rank 0)
# x : FLOAT[...] (a tensor of unknown rank)
# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
# Alias for the ONNX element-type enum (e.g. onnx.TensorProto.FLOAT is one value).
DType = onnx.TensorProto.DataType

# A single dimension annotation: a fixed int, a symbolic str name, or None (unknown).
DimType = Union[int, str, type(None)]


def check_dim(dim):
def _check_dim(dim):
if not isinstance(dim, (int, str, type(None))):
raise TypeError(f"Invalid dimension {dim}")


# A full shape annotation: a tuple of dims, a single dim, or Ellipsis (unknown rank).
ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]


def _check_shape(shape):
    """Validate a shape annotation.

    Accepts a tuple of dimensions (each validated by ``_check_dim``),
    Ellipsis (tensor of unknown rank), or a single dimension.
    Raises TypeError (from ``_check_dim``) on any invalid dimension.

    Note: the diff residue interleaved calls to the old ``check_dim`` name
    with the new ``_check_dim``; only the renamed private form is kept.
    """
    if isinstance(shape, tuple):
        for dim in shape:
            _check_dim(dim)
    elif shape != Ellipsis:
        _check_dim(shape)

# NOTE(review): this block is the PRE-PATCH, instance-based implementation of
# TensorType that the surrounding diff removes; it appears here interleaved
# with the replacement metaclass-based implementation further below.
class TensorType:
"""ONNX Script representation of a tensor type."""

# Per-subclass singleton used by __class_getitem__ to delegate subscripting
# (e.g. FLOAT[...] forwards to default_instance.__getitem__).
default_instance: Optional["TensorType"] = None
# NOTE(review): this line is ADDED by the patch; it looks like the
# (dtype, shape) -> type cache consumed by the new metaclass's __getitem__ —
# confirm its final location (class attribute vs module level) in the merged file.
_tensor_type_shape_cache: dict[DType, TensorType] = {}

def __init__(self, dtype, shape: Optional[ShapeType] = None) -> None:
self.dtype = dtype
self.shape = shape
if shape is not None:
check_shape(shape)

def __getitem__(self, shape: Optional[ShapeType]):
# A shape may be attached only once: FLOAT[1][...] is rejected.
if self.shape is not None:
raise ValueError("Invalid usage: shape already specified.")
if shape is None:
# Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
shape = (None,)
return TensorType(self.dtype, shape)
class _WithOnnxType:
    """Mixin that converts a (dtype, shape) class pair to an ``onnx.TypeProto``.

    Subclasses (the builtin tensor types such as ``FLOAT``) set the two class
    variables; ``to_type_proto`` is a classmethod because the types are never
    instantiated — the class object itself carries the annotation.

    Note: the diff residue interleaved the old instance-based ``self.shape``
    lines and a stray ``__class_getitem__`` fragment with this new
    classmethod-based version; only the new version is kept.
    """

    # Element type, e.g. onnx.TensorProto.FLOAT.
    dtype: ClassVar[DType]
    # Shape annotation; None means "scalar tensor of rank 0".
    shape: ClassVar[Optional[ShapeType]] = None

    @classmethod
    def to_type_proto(cls) -> onnx.TypeProto:
        """Build the onnx TypeProto describing this dtype/shape combination."""
        if cls.shape is None:
            shape = ()  # "FLOAT" is treated as a scalar
        elif cls.shape is Ellipsis:
            shape = None  # "FLOAT[...]" is a tensor of unknown rank
        elif isinstance(cls.shape, tuple):
            shape = cls.shape  # example: "FLOAT[10,20]"
        else:
            shape = [cls.shape]  # example: "FLOAT[10]"
        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)


# Metaclass for the builtin tensor types: because it defines __getitem__, any
# class whose metaclass is TensorType supports subscripting (FLOAT[...], etc.).
class TensorType(_WithOnnxType, type):
"""ONNX Script representation of a tensor type supporting shape annotations.

A scalar-tensor of rank 0:
::

tensor: FLOAT

A tensor of unknown rank:
::

tensor: FLOAT[...]

A tensor of rank 2 of unknown dimensions, with symbolic names:
::

tensor: FLOAT['M', 'N']

A tensor of rank 2 of known dimensions:
::

tensor: FLOAT[128, 1024]
"""

dtype: ClassVar[DType]
shape: ClassVar[Optional[ShapeType]] = None

# Metaclass __getitem__: `cls` is the tensor-type class being subscripted
# (e.g. FLOAT), and the result is a NEW class carrying the same dtype plus
# the given shape. Identical subscripts return the identical class object
# thanks to the cache, so FLOAT[1, 2] is FLOAT[1, 2].
def __getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
# A shape may be attached only once: FLOAT[1][...] is rejected.
if cls.shape is not None:
raise ValueError("Invalid usage: shape already specified.")
if shape is None:
# Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
shape = (None,)
_check_shape(shape)
key = (cls.dtype, shape)
# NOTE(review): _tensor_type_shape_cache is referenced as a bare name here;
# in this diff view its declaration appears inside the removed class above —
# confirm it is module-level in the merged file.
shaped_type = _tensor_type_shape_cache.get(key)
if shaped_type is None:
# This calls __init_subclass__
shaped_type = type(
cls.__name__,
(type(cls),),
dict(dtype=cls.dtype, shape=shape),
)
_tensor_type_shape_cache[key] = shaped_type
return shaped_type


# Builtin tensor types, one per ONNX element type. Each is a class (never
# instantiated) whose metaclass TensorType provides shape subscripting, and
# whose _WithOnnxType base provides to_type_proto().
#
# Note: the diff residue interleaved the removed `_BuiltinTensorType`
# decorator helper and the old decorated classes with these new declarations,
# defining every type name twice; only the new metaclass-based form is kept.

# pylint: disable=abstract-method,too-many-function-args
class FLOAT(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.FLOAT."""

    dtype = onnx.TensorProto.FLOAT
    shape = None


class UINT8(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.UINT8."""

    dtype = onnx.TensorProto.UINT8
    shape = None


class INT8(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.INT8."""

    dtype = onnx.TensorProto.INT8
    shape = None


class UINT16(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.UINT16."""

    dtype = onnx.TensorProto.UINT16
    shape = None


class INT16(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.INT16."""

    dtype = onnx.TensorProto.INT16
    shape = None


class INT32(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.INT32."""

    dtype = onnx.TensorProto.INT32
    shape = None


class INT64(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.INT64."""

    dtype = onnx.TensorProto.INT64
    shape = None


class STRING(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.STRING."""

    dtype = onnx.TensorProto.STRING
    shape = None


class BOOL(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.BOOL."""

    dtype = onnx.TensorProto.BOOL
    shape = None


class FLOAT16(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.FLOAT16."""

    dtype = onnx.TensorProto.FLOAT16
    shape = None


class DOUBLE(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.DOUBLE."""

    dtype = onnx.TensorProto.DOUBLE
    shape = None


class UINT32(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.UINT32."""

    dtype = onnx.TensorProto.UINT32
    shape = None


class UINT64(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.UINT64."""

    dtype = onnx.TensorProto.UINT64
    shape = None


class COMPLEX64(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.COMPLEX64."""

    dtype = onnx.TensorProto.COMPLEX64
    shape = None


class COMPLEX128(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.COMPLEX128."""

    dtype = onnx.TensorProto.COMPLEX128
    shape = None


class BFLOAT16(_WithOnnxType, metaclass=TensorType):
    """Tensor with element type onnx.TensorProto.BFLOAT16."""

    dtype = onnx.TensorProto.BFLOAT16
    shape = None

# pylint: enable=abstract-method,too-many-function-args
# Maps each ONNX element type to the ONNX Script tensor-type class declared
# above; each class's own `dtype` attribute supplies the key, so the mapping
# cannot drift out of sync with the class definitions.
_tensor_type_registry: dict[DType, TensorType] = {
    tensor_type.dtype: tensor_type
    for tensor_type in (
        FLOAT,
        UINT8,
        INT8,
        UINT16,
        INT16,
        INT32,
        INT64,
        STRING,
        BOOL,
        FLOAT16,
        DOUBLE,
        UINT32,
        UINT64,
        COMPLEX64,
        COMPLEX128,
        BFLOAT16,
    )
}


def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
Expand Down
81 changes: 81 additions & 0 deletions onnxscript/test/onnx_types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# mypy: disable-error-code=misc

"""Unit tests for the onnx_types module."""

import unittest

from parameterized import parameterized

from onnxscript.onnx_types import (
DOUBLE,
FLOAT,
DType,
TensorType,
_tensor_type_registry,
)


class TestOnnxTypes(unittest.TestCase):
    """Unit tests for the shape-annotated tensor types in onnxscript.onnx_types."""

    # Previously commented-out: kept visible as a skipped test instead of dead code.
    @unittest.skip("instantiation is not rejected yet; enable once TensorType raises NotImplementedError")
    def test_instantiation(self):
        """Tensor-type classes are annotations only and must not be instantiable."""
        with self.assertRaises(NotImplementedError):
            TensorType()
        with self.assertRaises(NotImplementedError):
            FLOAT()
        with self.assertRaises(NotImplementedError):
            FLOAT[...]()

    @parameterized.expand(_tensor_type_registry.items())
    def test_type_properties(self, dtype: DType, tensor_type: TensorType):
        """Every registered type exposes its dtype and accepts shape subscripts."""
        self.assertEqual(tensor_type.dtype, dtype)
        self.assertIsNone(tensor_type.shape)
        self.assertEqual(tensor_type[...].shape, ...)
        self.assertEqual(tensor_type[...].dtype, dtype)
        self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3))
        self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)

    # Previously commented-out: kept visible as a skipped test instead of dead code.
    @unittest.skip("subclassing with an explicit dtype is not rejected yet")
    def test_dtype_bound_to_subclass(self):
        """Manually subclassing TensorType with a dtype keyword must be rejected."""
        for dtype in _tensor_type_registry:
            with self.assertRaises(ValueError):
                type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)

    def test_shaped_doesnt_reshape(self):
        """A shape may be attached only once: FLOAT[1][...] is an error."""
        with self.assertRaises(TypeError):
            FLOAT[1][...]  # pylint: disable=pointless-statement

    @parameterized.expand(
        [
            (FLOAT, FLOAT),
            (FLOAT[None], FLOAT[None]),
            (FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
            (FLOAT[1], FLOAT[1]),
            (FLOAT[...], FLOAT[Ellipsis]),
            (FLOAT["M"], FLOAT["M"]),
            (FLOAT["M", "N"], FLOAT["M", "N"]),
            (FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
        ]
    )
    def test_shapes_are_same_type(self, a: TensorType, b: TensorType):
        """Equal dtype/shape subscripts must return the identical class object."""
        self.assertIs(a, b)

    @parameterized.expand(
        [
            (FLOAT[0], FLOAT[None]),
            (FLOAT[1, 2], FLOAT[3, 4]),
            (FLOAT[2, 1], FLOAT[1, 2]),
            (FLOAT["M", "N"], FLOAT["N", "M"]),
            (FLOAT, DOUBLE),
            (FLOAT[1], DOUBLE[1]),
            (FLOAT["X"], DOUBLE["X"]),
        ]
    )
    def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType):
        """Different dtype/shape subscripts must yield distinct class objects."""
        self.assertIsNot(a, b)


# Allow running this test module directly: `python onnx_types_test.py`.
if __name__ == "__main__":
unittest.main()
Loading