# RFC 0001 — `__torch_function__` for methods of the `torch.Tensor` class
## Abstract
This RFC describes changes necessary to allow `__torch_function__` to be used by methods of `torch.Tensor` in an attempt to make subclassing more accessible to the users of the class. This entails making an API for subclass views public, and a change in the signature of `__torch_function__`.
Additionally, it will provide subclass authors with hooks that run whenever methods or operators are called, allowing them to adapt the result to their specific use case, perform logging, or otherwise change the result or the behavior of the method.

## Backwards Compatibility
### With PyTorch `master` as of writing
PyTorch `master` pointed to commit hash `957a07ffbd13d8a805f4d718e0282efc5d2bff85` at the time of writing. Any class implementing `__torch_function__` based on its usage at that commit will break completely, due to the changed signature of the protocol. However, since no release has yet shipped with `__torch_function__`, this is a minor-impact issue. The change brings the design of `__torch_function__` more in line with NumPy's `__array_function__`, and anyone familiar with NumPy's protocol should be able to transition to PyTorch's take on it without too many surprises, with the caveat that `__torch_function__` can also receive methods rather than only functions.

### With NumPy
Since PyTorch uses a different protocol than NumPy (`__torch_function__` vs. `__array_function__`), there is no change for NumPy users. We propose to defer the question of allowing Torch tensors to be used with NumPy functions to a separate RFC.
## Detailed Description
We propose the following signature change to `__torch_function__`, to make it match NumPy: [[4]]

```python
class SubTensor(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        # Implementation here
```
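For comparison, NumPy's `__array_function__` protocol (NEP 18) already uses this shape. A minimal sketch, with `SubArray` as a hypothetical `ndarray` subclass:

```python
import numpy as np

class SubArray(np.ndarray):
    def __array_function__(self, func, types, args, kwargs):
        # Same (self, func, types, args, kwargs) shape as the torch protocol above
        ...
```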
Adding `types` to the signature is necessitated by the need for `super()`. If we require `super().__torch_function__` to work properly, we need to provide an easy way for users to signal to `__torch_function__` that they are deferring to the next most specific implementation. We propose to handle this the same way it is handled in NumPy, albeit there not in the context of overriding methods, but rather in the context of subclasses of `numpy.ndarray` or other classes that implement `__array_function__`.
To access `super()`, one would do the following:
```python
class SubTensor(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        # Pre-processing here
        val = super().__torch_function__(func, tuple(t for t in types if not issubclass(t, SubTensor)), args, kwargs)
        # Post-processing here
        return val
```
This way `__torch_function__` knows the list of types to dispatch to, and it will _not_ dispatch to `SubTensor` anymore in this example.
We will also recommend that all `Tensor` subclasses make their own methods go through `__torch_function__` via a decorator, `@torch_function_dispatch`. This decorator was added and then removed for performance reasons; however, it will be added back to allow external libraries to interface with the protocol.

This recommendation will come with a disclaimer: subclass authors _must_ accept that their methods are subject to the same processing as any other `torch.Tensor` methods, namely, that all the processing _will necessarily go through `__torch_function__`, even if through superclasses first_. Specifically, processing may pass through a superclass's `__torch_function__` implementation before coming back to a subclass's internal implementation.

We do not propose automatically marking functions with this decorator, due both to the potential backwards-compatibility break it could cause and to the parameters needed to make this happen (namely the dispatcher, which isn't in our control).
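As a rough illustration of the intended usage, a subclass method might be wired through the protocol like this. This is a minimal sketch assuming `@torch_function_dispatch` mirrors NumPy's `array_function_dispatch` and takes a dispatcher callable; the decorator's import location, the `_scale_dispatcher` helper, and the `scale` method are all hypothetical:

```python
import torch
# `torch_function_dispatch` is the decorator proposed above; its eventual
# import location is not specified in this RFC, so it is assumed importable here.

def _scale_dispatcher(self, factor):
    # The dispatcher returns the arguments whose types should be collected
    # into the `types` argument of __torch_function__.
    return (self,)

class SubTensor(torch.Tensor):
    @torch_function_dispatch(_scale_dispatcher)
    def scale(self, factor):
        # Because the method is decorated, calls to it are routed through
        # __torch_function__ just like built-in torch.Tensor methods.
        return self * factor
```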
### Making `torch.Tensor._make_subclass` public API
`torch.Tensor._make_subclass` will be renamed to `torch.Tensor.make_subclass` and become public API. This will allow `torch.Tensor.make_subclass` to correspond to `numpy.ndarray.view`, with the difference that the latter can also handle viewing as a different `dtype`, not just as a different subclass.

The reason for not choosing the name `view` is that it already exists on `torch.Tensor` in an unrelated context. The semantics of this function will be the same as creating a shallow copy of the object and then changing its `__class__`.
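For illustration, this is how the existing private API behaves today; the proposal essentially renames it. The `SubTensor` class here is a hypothetical example:

```python
import torch

class SubTensor(torch.Tensor):
    pass

t = torch.arange(4.0)
# View an existing tensor as a subclass: the storage is shared, and only
# the Python class of the returned object differs.
s = torch.Tensor._make_subclass(SubTensor, t)  # proposed public spelling: torch.Tensor.make_subclass
assert type(s) is SubTensor
assert s.data_ptr() == t.data_ptr()  # same underlying storage
```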
### Generic implementation of `__torch_function__`
`torch.Tensor` will gain a generic `__torch_function__` of the following form:
```python
class Tensor:
    def __torch_function__(self, func, types, args, kwargs):
        if not all(issubclass(type(self), t) for t in types):
            return NotImplemented

        # Defer to internal implementation
        ret = func._implementation(*args, **kwargs)
        if type(self) is not Tensor:
            ret = Tensor.make_subclass(ret, type(self))
        return ret
```
This method matches `torch` dispatch rules, so for the most part it's possible to pretend it doesn't exist. It also has the side-effect of passing subclasses through methods and operators (since all operators are methods).
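A sketch of the intended effect, reusing the hypothetical `SubTensor` from above; the behavior shown is what the proposal aims for, not necessarily current behavior:

```python
import torch

class SubTensor(torch.Tensor):
    pass

s = torch.Tensor._make_subclass(SubTensor, torch.ones(3))

# With the generic __torch_function__ above in place, methods and operators
# preserve the subclass instead of decaying to plain torch.Tensor.
r = s + s            # operators are methods, so they dispatch too
u = s.reshape(3, 1)
assert type(r) is SubTensor
assert type(u) is SubTensor
```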
This corresponds exactly to the implementation `numpy.ndarray` gains in [[4]], except that subclasses are passed through via another internal mechanism there (namely the `__array_finalize__` protocol), and that we check subclassing against `type(self)` instead of `Tensor`. The latter has the side-effect of ensuring unrelated class trees are not merged, which addresses an inconsistency in NumPy's own design.
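To illustrate the class-tree behavior, consider two unrelated subclasses. This is a hypothetical sketch of the semantics the `type(self)` check above is meant to produce; `SubTensorA` and `SubTensorB` are made-up names:

```python
import torch

class SubTensorA(torch.Tensor):
    pass

class SubTensorB(torch.Tensor):
    pass

a = torch.Tensor._make_subclass(SubTensorA, torch.ones(2))
b = torch.Tensor._make_subclass(SubTensorB, torch.ones(2))

# Under the generic implementation above, each __torch_function__ call sees
# types == (SubTensorA, SubTensorB). Neither class is a subclass of the
# other, so both return NotImplemented and the operation raises TypeError
# rather than silently coercing to one of the unrelated subclasses.
# torch.add(a, b)  # would raise TypeError under this proposal
```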
[1]: https://github.com/pytorch/pytorch/issues/22402 "GitHub Issue 22402 on pytorch/pytorch"