-
Notifications
You must be signed in to change notification settings - Fork 64
Patch for tensor types #231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This ensures that TensorType is never actually instantiated and instead new TensorType subclasses are created (and cached) for each shape. This is a stepping stone to a fully generic TensorType where the dtype and shape are modeled into the type itself, but it appears that won't be easily realized until we can depend on Python 3.11. Importantly, this adds unit tests to ensure if we can move to a proper generic that behavior will remain the same. Fixes #219
04053e7
to
a6d5295
Compare
Codecov Report
@@ Coverage Diff @@
## abock/uninstantiated-tensor-types #231 +/- ##
=====================================================================
- Coverage 75.62% 75.61% -0.02%
=====================================================================
Files 90 90
Lines 7241 7238 -3
=====================================================================
- Hits 5476 5473 -3
Misses 1765 1765
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
I don't think this works because FLOAT is now an instance of TensorType, not a subclass. |
Will wait for |
Alternatively, we will use this and run mypy with python3.11. So mypy is happy and we can pass in literals as types even with python <3.11. Will need to fix dtype and shape. from typing import Any, Generic, TypeVar, Union, Tuple, Literal
import typing
from typing_extensions import TypeVarTuple, Unpack
DimType = Union[int, str, type(None)]
ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
TDType = TypeVar("TDType")
TShape = TypeVarTuple("TShape")
class _DTYPE_FLOAT:
dtype = 100
class TensorType(Generic[TDType, Unpack[TShape]]):
def __class_getitem__(cls, *items: Any):
wrapped = []
for item in items[0]:
if type(item) in {int, str, type(...), type(None)}:
wrapped.append(Literal[item])
else:
wrapped.append(item)
return super().__class_getitem__(
tuple(wrapped)
)
def dtype(self) -> TDType:
return typing.get_args(self.__orig_class__)[0]
def shape(self) -> TShape:
return typing.get_args(self.__orig_class__)[1:]
class FLOAT(TensorType[_DTYPE_FLOAT, Unpack[TShape]]):
pass
print(TensorType[int, 1, 2])
print(TensorType[int, 1, 2]().dtype())
print(TensorType[int, 1, 2]().shape())
M = TypeVar("M", bound=int)
print(FLOAT[1, 2, ..., "N", M]) |
08dc10d
to
fcd664e
Compare
Now mypy can correctly check types:
It still doesn't like that that they are not generics, but I think we can disable
type-arg
andname-defined
for now.