# RFC 0001 — `__torch_function__` for methods of the `torch.Tensor` class

## Abstract

This RFC describes the changes necessary to allow `__torch_function__` to be used by methods of `torch.Tensor`, making subclassing more accessible to users of the class. This entails making an API for subclass views public, and changing the signature of `__torch_function__`.

## Motivation and Scope

Quoting [[1]], [[2]] and [[3]], the potential goals of this proposal are:
1. Support subclassing `torch.Tensor` in Python.
2. Preserve `Tensor` subclasses when calling `torch` functions on them.
3. Preserve `Tensor` subclasses when calling `numpy` functions on them.
4. Use the NumPy API with PyTorch tensors (i.e. NumPy API calls dispatch to `torch` functions).
5. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `Tensor` subclasses.
6. Reuse NumPy ufunc implementations directly from PyTorch.
7. Allow operations on mixed array types, e.g. `tensor + ndarray`.
8. Preserve `Tensor` subclasses when calling `Tensor` methods.
9. Propagate subclass instances correctly through operators and through views/slices/indexing/etc.
10. Preserve subclass attributes when using methods or views/slices/indexing.
11. A way to insert code that operates on both functions and methods uniformly (so a single function can override all operators).

We propose to solve this problem with the following changes to PyTorch:

1. Make methods and operators of `torch.Tensor` go through the `__torch_function__` machinery.
2. Add a `types` argument to `__torch_function__`, making it match NumPy's `__array_function__`.
3. Make `torch.Tensor._make_subclass` a public API.
4. Give `torch.Tensor` a generic implementation of `__torch_function__`.

## Usage and Impact

Once this proposal is implemented, users of `torch.Tensor` subclasses will have a much more streamlined experience. Namely, the following code example will work as-is, without any further modification:

```python
class SubTensor(torch.Tensor):
    a = 1

t = SubTensor([1])
s = t.sum()
isinstance(s, SubTensor)  # True
s.a  # 1
i = t[0]
isinstance(i, SubTensor)  # True
i.a  # 1

s2 = t + torch.Tensor(1)
isinstance(s2, SubTensor)  # True
s2.a  # 1

s3 = torch.Tensor(1) + t
isinstance(s3, SubTensor)  # True
s3.a  # 1
```

Additionally, it will provide subclass authors with hooks that run whenever methods or operators are called, allowing them to adapt results to their specific use case, perform logging, or otherwise change the result or the behavior of the method. A minimal sketch of such a hook follows.
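
As an illustration, here is a minimal sketch of such a hook, using the `__torch_function__` signature proposed in the Detailed Description below. The `LoggingTensor` name is hypothetical and not part of this proposal:

```python
import torch

class LoggingTensor(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        # Runs for every function, method, and operator call that
        # involves a LoggingTensor, e.g. t.sum() or t + other.
        print(f"calling {func.__name__}")
        # Remove our own type from `types` and defer to the next
        # most specific implementation.
        return super().__torch_function__(
            func,
            tuple(t for t in types if not issubclass(t, LoggingTensor)),
            args,
            kwargs,
        )
```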

## Detailed Description

We propose the following signature change to `__torch_function__`, to make it match NumPy's `__array_function__` [[4]]:
```python
class SubTensor(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        # Implementation here
        ...
```
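
For context, `types` in NEP 18 is the collection of distinct argument types that provide an override, which an implementation can inspect to decide whether it can handle the call. A simplified, illustrative sketch of how a dispatcher might assemble it (the helper name is hypothetical, and NumPy's real machinery additionally orders subclasses before their base classes):

```python
def _collect_overloaded_types(relevant_args):
    # Hypothetical helper: gather the distinct types of all relevant
    # arguments that define __torch_function__.
    overloaded = []
    for arg in relevant_args:
        t = type(arg)
        if t not in overloaded and hasattr(t, "__torch_function__"):
            overloaded.append(t)
    return tuple(overloaded)
```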

This change is necessitated by the need for `super()`. If we require `super().__torch_function__` to work properly, we need to provide an easy way for users to signal to `__torch_function__` that they are deferring to the next most specific implementation. We propose to handle this the same way it is handled in NumPy, albeit there not in the context of overriding methods, but rather in the context of subclasses of `numpy.ndarray` or other classes that implement `__array_function__`.

To access the super implementation, one would do the following:

```python
class SubTensor(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        # Pre-processing here
        val = super().__torch_function__(
            func,
            tuple(t for t in types if not issubclass(t, SubTensor)),
            args,
            kwargs,
        )
        # Post-processing here
        return val
```

This way, `__torch_function__` knows the list of types to dispatch to, and in this example it will _not_ dispatch to `SubTensor` again.

We will also recommend that all `Tensor` subclasses route their own methods through `__torch_function__` via a decorator, `@torch_function_dispatch`. However, this comes with a disclaimer: subclass authors _must_ accept that their methods are subject to the same processing as any other `torch.Tensor` methods, namely, that all the processing _will necessarily go through `__torch_function__`, even if through superclasses first_. A sketch of such a decorated method follows.
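
As mentioned above, a hedged sketch of a decorated method, assuming `@torch_function_dispatch` mirrors NumPy's `array_function_dispatch` from NEP 18 (it takes a dispatcher that yields the arguments to inspect for overrides; the import location and the method itself are illustrative, not part of this proposal):

```python
import torch

# Hypothetical import; the RFC does not specify where the
# decorator would live.
from torch import torch_function_dispatch

def _scale_dispatcher(self, factor):
    # Tells the dispatcher which arguments may carry overrides.
    return (self,)

class SubTensor(torch.Tensor):
    @torch_function_dispatch(_scale_dispatcher)
    def scale(self, factor):
        # Decorated methods go through __torch_function__ just like
        # built-in torch.Tensor methods do.
        return self * factor
```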

### Making `torch.Tensor._make_subclass` public API

`torch.Tensor._make_subclass` will be renamed to `torch.Tensor.make_subclass` and will become public API.
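
A brief usage sketch, assuming the argument convention used by the generic implementation shown below (the tensor to wrap first, the target subclass second):

```python
import torch

class SubTensor(torch.Tensor):
    pass

t = torch.tensor([1.0, 2.0])
# Create a SubTensor view of an existing Tensor, sharing its data.
s = torch.Tensor.make_subclass(t, SubTensor)
isinstance(s, SubTensor)  # True
```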

### Generic implementation of `__torch_function__`

`torch.Tensor` will gain a generic `__torch_function__` of the following form:

```python
class Tensor:
    def __torch_function__(self, func, types, args, kwargs):
        if not all(issubclass(t, type(self)) for t in types):
            return NotImplemented

        if type(self) is Tensor:
            # Defer to the internal implementation directly
            return func._implementation(*args, **kwargs)

        ret = func._implementation(*args, **kwargs)
        if isinstance(ret, Tensor):
            # Pass the subclass through to the result
            ret = Tensor.make_subclass(ret, type(self))
        return ret
```

This method matches `torch` dispatch rules, so for the most part one can pretend it does not exist. It also has the side effect of passing subclasses through methods and operators (since all operators are methods).

This corresponds exactly to the implementation `numpy.ndarray` gains in [[4]], except that subclasses are passed through via another internal mechanism there, and that we check subclassing against `type(self)` instead of `Tensor`. This has the side effect of ensuring that unrelated class trees are not merged, avoiding an inconsistency in NumPy's own design.
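
Concretely, with the `type(self)` check above, two unrelated subclasses would each return `NotImplemented` for a mixed operation, so neither silently absorbs the other. A hypothetical sketch of the expected behavior under this proposal:

```python
import torch

class A(torch.Tensor):
    pass

class B(torch.Tensor):
    pass

a = A([1.0])
b = B([1.0])

# types is (A, B) here. A.__torch_function__ sees B, which is not a
# subclass of A, and returns NotImplemented; B does likewise for A,
# so the operation fails instead of merging unrelated class trees.
try:
    a + b
except TypeError:
    print("no implementation found for mixed A/B operands")
```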

[1]: https://github.com/pytorch/pytorch/issues/22402 "GitHub Issue 22402 on pytorch/pytorch"
[2]: https://github.com/pytorch/pytorch/issues/28361#issuecomment-544520934 "Comment on GitHub Issue 28361 on pytorch/pytorch"
[3]: https://github.com/pytorch/pytorch/issues/28361#issuecomment-557285807 "Comment on GitHub Issue 28361 on pytorch/pytorch"
[4]: https://numpy.org/neps/nep-0018-array-function-protocol.html "NEP 18 — A dispatch mechanism for NumPy's high level array functions"