
Patch for tensor types #231


Closed
wants to merge 9 commits into from

Conversation

@justinchuby (Collaborator) commented Dec 5, 2022

Now mypy can correctly check types:

[screenshot: mypy type-check output]

It still doesn't like that they are not generics, but I think we can disable type-arg and name-defined for now.

  Error (MYPY) type-arg
    "FLOAT" expects no type arguments, but 1 given

        28  |        self.assertSameGraph(static_shape, static_shape_txt)
        29  |
        30  |        @script()
    >>> 31  |        def symbolic_shape(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:  # noqa: F821
        32  |            C = op.Add(A, B)
        33  |            return C
        34  |

  Error (MYPY) name-defined
    Name "N" is not defined

        28  |        self.assertSameGraph(static_shape, static_shape_txt)
        29  |
        30  |        @script()
    >>> 31  |        def symbolic_shape(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:  # noqa: F821
        32  |            C = op.Add(A, B)
        33  |            return C
        34  |
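
For reference, a minimal sketch of how those two error codes could be suppressed with mypy's standard error-code ignores. This is illustrative only, not a change made in this PR; script, FLOAT, and op are assumed to be the same names used in the test excerpt above.

# Option 1 (sketch): per-line ignores scoped to the specific error codes.
@script()
def symbolic_shape(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:  # type: ignore[type-arg, name-defined]  # noqa: F821
    C = op.Add(A, B)
    return C

# Option 2 (sketch): mypy's inline configuration comment, placed at the very
# top of the affected module, disables the codes file-wide:
# mypy: disable-error-code="type-arg, name-defined"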

abock and others added 3 commits December 2, 2022 16:40
This ensures that TensorType is never actually instantiated and instead
new TensorType subclasses are created (and cached) for each shape.

This is a stepping stone to a fully generic TensorType where the dtype
and shape are modeled into the type itself, but it appears that won't
be easily realized until we can depend on Python 3.11.

Importantly, this adds unit tests to ensure that, if we do move to a proper
generic, the behavior will remain the same.

Fixes #219
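
For readers of this thread, here is a rough sketch of the scheme the commit message describes: TensorType itself is never instantiated, and subscripting a tensor type with a shape produces a dynamically created subclass that is cached per shape. The names below (in particular _tensor_type_cache) are illustrative, not the actual onnxscript implementation.

from typing import Optional, Tuple, Union

DimType = Union[int, str, type(None)]
ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]

# Hypothetical cache: (tensor type, shape) -> dynamically created subclass.
_tensor_type_cache = {}


class TensorType:
    shape: Optional[ShapeType] = None

    def __new__(cls):
        # The class hierarchy is used purely for annotations.
        raise NotImplementedError("TensorType subclasses are not instantiated")

    def __class_getitem__(cls, shape: ShapeType):
        key = (cls, shape)
        if key not in _tensor_type_cache:
            # Create and cache a subclass carrying the shape, so that
            # FLOAT["N"] always returns the same class object.
            _tensor_type_cache[key] = type(
                f"{cls.__name__}[{shape!r}]", (cls,), {"shape": shape}
            )
        return _tensor_type_cache[key]


class FLOAT(TensorType):
    pass


assert FLOAT["N"] is FLOAT["N"]          # cached, so identical class objects
assert issubclass(FLOAT["N"], TensorType)
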
codecov bot commented Dec 5, 2022

Codecov Report

Merging #231 (240f226) into abock/uninstantiated-tensor-types (08dc10d) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@                          Coverage Diff                          @@
##           abock/uninstantiated-tensor-types     #231      +/-   ##
=====================================================================
- Coverage                              75.62%   75.61%   -0.02%     
=====================================================================
  Files                                     90       90              
  Lines                                   7241     7238       -3     
=====================================================================
- Hits                                    5476     5473       -3     
  Misses                                  1765     1765              
Impacted Files Coverage Δ
onnxscript/irbuilder.py 78.96% <100.00%> (ø)
onnxscript/onnx_types.py 95.53% <100.00%> (-0.62%) ⬇️
onnxscript/test/onnx_types_test.py 95.83% <100.00%> (-1.31%) ⬇️
onnxscript/test/type_annotation_test.py 74.28% <100.00%> (ø)
onnxscript/test/converter_test.py 87.55% <0.00%> (+0.23%) ⬆️


@justinchuby requested a review from abock December 5, 2022 21:48
@justinchuby marked this pull request as draft December 5, 2022 22:49
@justinchuby (Collaborator, Author) commented:

I don't think this works because FLOAT is now an instance of TensorType, not a subclass.
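
A toy illustration of the distinction being raised here (hypothetical names, unrelated to the actual onnxscript classes): an object created as an instance cannot play the role of a class in annotations or subclass checks.

class TensorType:
    pass

FLOAT_as_subclass = type("FLOAT", (TensorType,), {})  # a class: usable as a base and in issubclass checks
FLOAT_as_instance = TensorType()                       # an instance: not a class at all

print(issubclass(FLOAT_as_subclass, TensorType))  # True
# issubclass(FLOAT_as_instance, TensorType) raises TypeError, because the
# first argument must be a class, not an instance.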

@justinchuby (Collaborator, Author) commented:

Alternatively, we could use this and run mypy with Python 3.11. That way mypy is happy, and we can still pass literals as types on Python < 3.11.

dtype and shape will still need to be fixed.

from typing import Any, Generic, TypeVar, Union, Tuple, Literal
import typing

from typing_extensions import TypeVarTuple, Unpack

DimType = Union[int, str, type(None)]
ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]

TDType = TypeVar("TDType")
TShape = TypeVarTuple("TShape")


class _DTYPE_FLOAT:
    dtype = 100


class TensorType(Generic[TDType, Unpack[TShape]]):
    def __class_getitem__(cls, *items: Any):
        # Wrap bare ints, strings, Ellipsis, and None in Literal[...] so that
        # dimensions written as plain literals are acceptable type arguments.
        wrapped = []
        for item in items[0]:
            if type(item) in {int, str, type(...), type(None)}:
                wrapped.append(Literal[item])
            else:
                wrapped.append(item)
        return super().__class_getitem__(tuple(wrapped))

    def dtype(self) -> TDType:
        # Recover the dtype type argument from the parameterized alias.
        return typing.get_args(self.__orig_class__)[0]

    def shape(self) -> TShape:
        # Recover the shape type arguments from the parameterized alias.
        return typing.get_args(self.__orig_class__)[1:]


class FLOAT(TensorType[_DTYPE_FLOAT, Unpack[TShape]]):
    pass


print(TensorType[int, 1, 2])
print(TensorType[int, 1, 2]().dtype())
print(TensorType[int, 1, 2]().shape())

M = TypeVar("M", bound=int)
print(FLOAT[1, 2, ..., "N", M])

@justinchuby force-pushed the abock/uninstantiated-tensor-types branch from 08dc10d to fcd664e December 7, 2022 16:32
Base automatically changed from abock/uninstantiated-tensor-types to main December 7, 2022 16:50
@justinchuby added the topic: discussion label Jan 19, 2023
@justinchuby removed the request for review from abock January 19, 2023 15:01
@justinchuby deleted the justinchu/meta-type branch January 27, 2025 18:13
Labels
topic: discussion (For discussion)

2 participants