diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index 1420c7a7aecdb2..8e1f19ebf6ce6f 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -2,7 +2,7 @@ Extending PyTorch
 =================
 
 In this note we'll cover ways of extending :mod:`torch.nn`,
-:mod:`torch.autograd`, and writing custom C extensions utilizing our C
+:mod:`torch.autograd`, :mod:`torch`, and writing custom C extensions utilizing our C
 libraries.
 
 Extending :mod:`torch.autograd`
@@ -204,6 +204,285 @@ This is how a ``Linear`` module can be implemented::
             self.in_features, self.out_features, self.bias is not None
         )
 
+Extending :mod:`torch`
+----------------------
+
+You can create custom types that emulate :class:`Tensor` by defining a custom
+class with methods that match :class:`Tensor`. But what if you want to be able
+to pass these types to functions like :func:`torch.add` in the top-level
+:mod:`torch` namespace that accept :class:`Tensor` operands?
+
+If your custom Python type defines a method named ``__torch_function__``, PyTorch
+will invoke your ``__torch_function__`` implementation when an instance of your
+custom class is passed to a function in the :mod:`torch` namespace. This makes
+it possible to define custom implementations for any of the functions in the
+:mod:`torch` namespace that your ``__torch_function__`` implementation can handle,
+allowing your users to make use of your custom type with existing PyTorch
+workflows that they have already written for :class:`Tensor`. This works with
+"duck" types that are unrelated to :class:`Tensor` as well as user-defined
+subclasses of :class:`Tensor`.
+
+Extending :mod:`torch` with a :class:`Tensor`-like type
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. note:: This functionality is inspired by the NumPy ``__array_function__``
+   protocol. See `the NumPy documentation
+   <https://numpy.org/devdocs/user/basics.dispatch.html>`_
+   and `NEP-0018
+   <https://www.numpy.org/neps/nep-0018-array-function-protocol.html>`_ for
+   more details.
+
+To make this concrete, let's begin with a simple example that illustrates the
+API dispatch mechanism. We'll create a custom type that represents a 2D scalar
+tensor, parametrized by the order ``N`` and the value along the diagonal
+entries, ``value``::
+
+    class ScalarTensor(object):
+        def __init__(self, N, value):
+            self._N = N
+            self._value = value
+
+        def __repr__(self):
+            return "ScalarTensor(N={}, value={})".format(self._N, self._value)
+
+        def tensor(self):
+            return self._value * torch.eye(self._N)
+
+This first iteration of the design isn't very useful. The main functionality of
+``ScalarTensor`` is to provide a more compact string representation of a scalar
+tensor than in the base tensor class::
+
+    >>> d = ScalarTensor(5, 2)
+    >>> d
+    ScalarTensor(N=5, value=2)
+    >>> d.tensor()
+    tensor([[2., 0., 0., 0., 0.],
+            [0., 2., 0., 0., 0.],
+            [0., 0., 2., 0., 0.],
+            [0., 0., 0., 2., 0.],
+            [0., 0., 0., 0., 2.]])
+
+If we try to use this object with the :mod:`torch` API, we will run
+into issues::
+
+    >>> import torch
+    >>> torch.mean(d)
+    TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
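+
+Without dispatch support, the only recourse is to convert by hand at every call
+site (a quick illustrative sketch of the manual workaround, not part of the
+design we are building towards)::
+
+    >>> torch.mean(d.tensor())
+    tensor(0.4000)
+
+Doing this everywhere defeats the purpose of having a compact tensor-like type,
+which is exactly the problem the dispatch mechanism described next solves.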
+
+Adding a ``__torch_function__`` implementation to ``ScalarTensor`` makes it
+possible for the above operation to succeed. Let's redo our implementation,
+this time adding a ``__torch_function__`` implementation::
+
+    HANDLED_FUNCTIONS = {}
+    class ScalarTensor(object):
+        def __init__(self, N, value):
+            self._N = N
+            self._value = value
+
+        def __repr__(self):
+            return "ScalarTensor(N={}, value={})".format(self._N, self._value)
+
+        def tensor(self):
+            return self._value * torch.eye(self._N)
+
+        def __torch_function__(self, func, args=(), kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+            if func not in HANDLED_FUNCTIONS:
+                return NotImplemented
+            return HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+The ``__torch_function__`` method takes three arguments: ``func``, a reference to
+the torch API function that is being overridden, ``args``, the tuple of arguments
+passed to the function, and ``kwargs``, the dict of keyword arguments passed to
+the function. It uses a global dispatch table named ``HANDLED_FUNCTIONS`` to
+store custom implementations. The keys of this dictionary are functions in the
+``torch`` namespace and the values are implementations for ``ScalarTensor``.
+
+.. note:: Using a global dispatch table is not a mandated part of the
+   ``__torch_function__`` API; it is just a useful design pattern for
+   structuring your override implementations.
+
+This class definition isn't quite enough to make ``torch.mean`` do the right
+thing when we pass it a ``ScalarTensor`` -- we also need to define an
+implementation for ``torch.mean`` for ``ScalarTensor`` operands and add the
+implementation to the ``HANDLED_FUNCTIONS`` dispatch table dictionary. One way
+of doing this is to define a decorator::
+
+    import functools
+    def implements(torch_function):
+        """Register a torch function override for ScalarTensor"""
+        @functools.wraps(torch_function)
+        def decorator(func):
+            HANDLED_FUNCTIONS[torch_function] = func
+            return func
+        return decorator
+
+which can be applied to the implementation of our override::
+
+    @implements(torch.mean)
+    def mean(input):
+        return float(input._value) / input._N
+
+With this change we can now use ``torch.mean`` with ``ScalarTensor``::
+
+    >>> d = ScalarTensor(5, 2)
+    >>> torch.mean(d)
+    0.4
+
+Of course ``torch.mean`` is an example of the simplest kind of function to
+override since it only takes one operand. We can use the same machinery to
+override a function that takes more than one operand, any one of which might be
+a tensor or tensor-like that defines ``__torch_function__``, for example for
+:func:`torch.add`::
+
+    def ensure_tensor(data):
+        if isinstance(data, ScalarTensor):
+            return data.tensor()
+        return torch.as_tensor(data)
+
+    @implements(torch.add)
+    def add(input, other):
+        try:
+            if input._N == other._N:
+                return ScalarTensor(input._N, input._value + other._value)
+            else:
+                raise ValueError("Shape mismatch!")
+        except AttributeError:
+            return torch.add(ensure_tensor(input), ensure_tensor(other))
+
+This version has a fast path for when both operands are ``ScalarTensor``
+instances and also a slower path that degrades to converting the data to
+tensors when either operand is not a ``ScalarTensor``. This makes the override
+work correctly when either operand is a ``ScalarTensor`` or a regular
+:class:`Tensor`::
+
+    >>> s = ScalarTensor(2, 2)
+    >>> torch.add(s, s)
+    ScalarTensor(N=2, value=4)
+    >>> t = torch.tensor([[1, 1], [1, 1]])
+    >>> torch.add(s, t)
+    tensor([[3., 1.],
+            [1., 3.]])
+
+Note that our implementation of ``add`` does not take ``alpha`` or ``out`` as
+keyword arguments like :func:`torch.add` does::
+
+    >>> torch.add(s, s, alpha=2)
+    TypeError: add() got an unexpected keyword argument 'alpha'
+
+For speed and flexibility the ``__torch_function__`` dispatch mechanism does not
+check that the signature of an override function matches the signature of the
+function being overridden in the :mod:`torch` API. For some applications ignoring
+optional arguments would be fine but to ensure full compatibility with
+:class:`Tensor`, user implementations of torch API functions should take care to
+exactly emulate the API of the function that is being overridden.
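+
+As a purely illustrative sketch of what a more faithful override might look
+like, the version below reuses the ``implements`` decorator and the
+``ensure_tensor`` helper defined above and mirrors the ``alpha`` and ``out``
+arguments of :func:`torch.add` (registering it replaces the earlier entry in
+``HANDLED_FUNCTIONS``)::
+
+    @implements(torch.add)
+    def add(input, other, alpha=1, out=None):
+        # Fast path: both operands are ScalarTensor; ``alpha`` scales the
+        # second operand's diagonal value before the addition.
+        try:
+            if input._N == other._N:
+                result = ScalarTensor(input._N, input._value + alpha * other._value)
+            else:
+                raise ValueError("Shape mismatch!")
+        except AttributeError:
+            # Mixed operands: fall back to dense tensors.
+            result = torch.add(ensure_tensor(input), ensure_tensor(other), alpha=alpha)
+        if out is not None:
+            # What writing into ``out`` should mean for a ScalarTensor result
+            # is a design decision; here we only support dense ``out`` tensors.
+            out.copy_(ensure_tensor(result))
+            return out
+        return result
+
+Whether an ``out`` argument even makes sense for a given tensor-like type is a
+design decision that is up to the author of that type.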
+
+Functions in the :mod:`torch` API that do not have explicit overrides will
+return ``NotImplemented`` from ``__torch_function__``. If all operands with
+``__torch_function__`` defined on them return ``NotImplemented``, PyTorch will
+raise a ``TypeError``. This means that most of the time operations that do not
+have explicit overrides for a type will raise a ``TypeError`` when an instance
+of such a type is passed::
+
+    >>> torch.mul(s, 3)
+    TypeError: no implementation found for 'torch.mul' on types that
+    implement __torch_function__: [ScalarTensor]
+
+In practice this means that if you would like to implement your overrides using
+a ``__torch_function__`` implementation along these lines, you will need to
+explicitly implement the full :mod:`torch` API or the entire subset of the API
+that you care about for your use case. This may be a tall order as the full
+:mod:`torch` API is quite extensive.
+
+Another option is to not return ``NotImplemented`` for operations that are not
+handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
+function when no override is available. For example, if we change our
+implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
+
+    def __torch_function__(self, func, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func not in HANDLED_FUNCTIONS:
+            args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
+            return func(*args, **kwargs)
+        return HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+Then :func:`torch.mul` will work correctly, although the return type will always
+be a :class:`Tensor` rather than a ``ScalarTensor``, even if both operands
+are ``ScalarTensor`` instances::
+
+    >>> s = ScalarTensor(2, 2)
+    >>> torch.mul(s, s)
+    tensor([[4., 0.],
+            [0., 4.]])
+
+Also see the ``MetadataTensor`` example below for another variation on this
+pattern, one that instead always returns a ``MetadataTensor`` to propagate
+metadata through operations in the :mod:`torch` API.
+
+Extending :mod:`torch` with a :class:`Tensor` wrapper type
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Another useful case is a type that wraps a :class:`Tensor`, either as an
+attribute or via subclassing. Below we implement a special case of this sort of
+type, a ``MetadataTensor`` that attaches a dictionary of metadata to a
+:class:`Tensor` and propagates it through :mod:`torch` operations. Since this
+is a generic sort of wrapping for the full :mod:`torch` API, we do not need to
+individually implement each override, so we can make the ``__torch_function__``
+implementation more permissive about what operations are allowed::
+
+    class MetadataTensor(object):
+        def __init__(self, data, metadata=None, **kwargs):
+            self._t = torch.as_tensor(data, **kwargs)
+            self._metadata = metadata
+
+        def __repr__(self):
+            return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
+
+        def __torch_function__(self, func, args=(), kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+            args = [a._t if hasattr(a, '_t') else a for a in args]
+            ret = func(*args, **kwargs)
+            return MetadataTensor(ret, metadata=self._metadata)
+
+This simple implementation won't necessarily work with every function in the
+:mod:`torch` API but it is good enough to capture most common operations::
+
+    >>> metadata = {'owner': 'Ministry of Silly Walks'}
+    >>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
+    >>> t = torch.tensor([[1, 2], [1, 2]])
+    >>> torch.add(t, m)
+    Metadata:
+    {'owner': 'Ministry of Silly Walks'}
+
+    data:
+    tensor([[2, 4],
+            [4, 6]])
+    >>> torch.mul(t, m)
+    Metadata:
+    {'owner': 'Ministry of Silly Walks'}
+
+    data:
+    tensor([[1, 4],
+            [3, 8]])
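+
+If you need the wrapper to survive more of the API, one possible refinement --
+a sketch, not something the example above depends on -- is to pick up metadata
+from whichever argument carries it and to re-wrap only results that are
+actually tensors, so functions returning scalars or bools do not break::
+
+    def __torch_function__(self, func, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # Collect metadata from any MetadataTensor argument, then unwrap.
+        metadatas = [a._metadata for a in args if hasattr(a, '_metadata')]
+        args = [a._t if hasattr(a, '_t') else a for a in args]
+        ret = func(*args, **kwargs)
+        # Only re-wrap tensor results; pass scalars, bools, etc. through.
+        if isinstance(ret, torch.Tensor):
+            return MetadataTensor(ret, metadata=metadatas[0] if metadatas else None)
+        return ret
+
+Functions that return several tensors at once (for example :func:`torch.sort`)
+would need similar special handling.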
+
+Operations on multiple types that define ``__torch_function__``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It is possible to use the torch API with multiple distinct types that each have
+a ``__torch_function__`` implementation, but special care must be taken. In such
+a case the rules are (illustrated with a short sketch after this list):
+
+* The dispatch operation gathers all distinct implementations of
+  ``__torch_function__`` for each operand and calls them in order: subclasses
+  before superclasses, and otherwise left to right in the operator expression.
+* If any value other than ``NotImplemented`` is returned, that value is
+  returned as the result. Implementations can register that they do not
+  implement an operation by returning ``NotImplemented``.
+* If all of the ``__torch_function__`` implementations return
+  ``NotImplemented``, PyTorch raises a ``TypeError``.
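+
+To make these rules concrete, here is a minimal sketch with hypothetical duck
+types: ``A``, its subclass ``C``, and an unrelated ``B`` that opts out of every
+operation::
+
+    class A(object):
+        def __torch_function__(self, func, args=(), kwargs=None):
+            return "A handled {}".format(func.__name__)
+
+    class C(A):
+        def __torch_function__(self, func, args=(), kwargs=None):
+            return "C handled {}".format(func.__name__)
+
+    class B(object):
+        def __torch_function__(self, func, args=(), kwargs=None):
+            return NotImplemented
+
+Following the rules above we would expect::
+
+    >>> torch.add(B(), A())  # B is consulted first but defers, so A handles it
+    'A handled add'
+    >>> torch.add(A(), C())  # C subclasses A, so C is consulted first
+    'C handled add'
+
+If both operands had been ``B`` instances, every implementation would have
+returned ``NotImplemented`` and a ``TypeError`` would have been raised, as
+described above.
+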
Writing custom C++ extensions ----------------------------- diff --git a/test/onnx/expect/TestOperators.test_frobenius_norm.expect b/test/onnx/expect/TestOperators.test_frobenius_norm.expect index ddd73648cdaf42..4bb058a66bbb50 100644 --- a/test/onnx/expect/TestOperators.test_frobenius_norm.expect +++ b/test/onnx/expect/TestOperators.test_frobenius_norm.expect @@ -3,8 +3,8 @@ producer_name: "pytorch" producer_version: "1.3" graph { node { - input: "x" - input: "x" + input: "0" + input: "0" output: "1" name: "Mul_0" op_type: "Mul" @@ -34,7 +34,7 @@ graph { } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_meshgrid.expect b/test/onnx/expect/TestOperators.test_meshgrid.expect index 2f983d948e1b8d..179ba9c478dc30 100644 --- a/test/onnx/expect/TestOperators.test_meshgrid.expect +++ b/test/onnx/expect/TestOperators.test_meshgrid.expect @@ -17,7 +17,7 @@ graph { } } node { - input: "x" + input: "0" input: "3" output: "4" name: "Reshape_1" @@ -38,7 +38,7 @@ graph { } } node { - input: "y" + input: "1" input: "5" output: "6" name: "Reshape_3" @@ -59,7 +59,7 @@ graph { } } node { - input: "z" + input: "2" input: "7" output: "8" name: "Reshape_5" @@ -221,7 +221,7 @@ graph { } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 1 @@ -234,7 +234,7 @@ graph { } } input { - name: "y" + name: "1" type { tensor_type { elem_type: 1 @@ -247,7 +247,7 @@ graph { } } input { - name: "z" + name: "2" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_unique.expect b/test/onnx/expect/TestOperators.test_unique.expect index 762dcd6c788e30..ecc25d657d3ef8 100644 --- a/test/onnx/expect/TestOperators.test_unique.expect +++ b/test/onnx/expect/TestOperators.test_unique.expect @@ -3,7 +3,7 @@ producer_name: "pytorch" producer_version: "1.3" graph { node { - input: "x" + input: "0" output: "1" output: "2" output: "3" @@ -23,7 +23,7 @@ graph { } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 1 diff --git a/test/run_test.py b/test/run_test.py index 9ab527ebfa0ec1..84ebbf28a7297e 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -62,6 +62,7 @@ 'type_promotion', 'jit_disabled', 'function_schema', + 'overrides', ] # skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn diff --git a/test/test_overrides.py b/test/test_overrides.py new file mode 100644 index 00000000000000..fb25da36a2c538 --- /dev/null +++ b/test/test_overrides.py @@ -0,0 +1,947 @@ +import torch +import numpy as np +import unittest +import inspect +import pprint +import functools + +from common_utils import TestCase +from torch._overrides import torch_function_dispatch + +# The functions below simulate the pure-python torch functions in the +# torch.functional namespace. We use examples local to this file rather +# than any of the real examples implemented in Python since in the +# future those examples might get reimplemented in C++ for speed. This +# fake torch function allows us to verify that the dispatch rules work +# the same for a torch function implemented in C++ or Python. 
+ +def foo_dispatcher(a, b, c=None): + return (a, b, c) + +@torch_function_dispatch(foo_dispatcher) +def foo(a, b, c=None): + """A function multiple arguments and an optional argument""" + if c: + return a + b + c + return a + b + +def bar_dispatcher(a): + return (a,) + +@torch_function_dispatch(bar_dispatcher) +def bar(a): + """A function with one argument""" + return a + +def baz_dispatcher(a, b): + return (a, b) + +@torch_function_dispatch(baz_dispatcher) +def baz(a, b): + """A function with multiple arguments""" + return a + b + +def quux_dispatcher(a): + return (a,) + +@torch_function_dispatch(quux_dispatcher) +def quux(a): + """Used to test that errors raised in user implementations get propagated""" + return a + +# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that +# DiagonalTensor.__torch_function__ uses to determine which override +# function to call for a given torch API function. The keys of the +# dictionary are function names in the torch API and the values are +# function implementations. Implementations are added to +# HANDLED_FUNCTION_DIAGONAL by decorating a python function with +# implements_diagonal. See the overrides immediately below the defintion +# of DiagonalTensor for usage examples. +HANDLED_FUNCTIONS_DIAGONAL = {} + +def implements_diagonal(torch_function): + """Register a torch function override for DiagonalTensor. + + This decorator takes a function in the torch API as a + parameter. Applying this decorator to a function adds that function + as the registered override for the torch function passed as a + parameter to the decorator. See DiagonalTensor.__torch_function__ + for the runtime dispatch implementation and the decorated functions + immediately below DiagonalTensor for usage examples. + """ + @functools.wraps(torch_function) + def decorator(func): + HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func + return func + return decorator + +class DiagonalTensor(object): + """A class with __torch_function__ and a specific diagonal representation + + This class has limited utility and is mostly useful for verifying that the + dispatch mechanism works as expected. It is based on the `DiagonalArray + example`_ in the NumPy documentation. + + Note that this class does *not* inherit from ``torch.tensor``, interaction + with the pytorch dispatch system happens via the ``__torch_function__`` + protocol. + + ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has + diagonal entries set to *value* and all other entries set to zero. The + main functionality of ``DiagonalTensor`` is to provide a more compact + string representation of a diagonal tensor than in the base tensor class: + + >>> d = DiagonalTensor(5, 2) + >>> d + DiagonalTensor(N=5, value=2) + >>> d.tensor() + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 2., 0., 0.], + [0., 0., 0., 2., 0.], + [0., 0., 0., 0., 2.]]) + + Note that to simplify testing, matrix multiplication of ``DiagonalTensor`` + returns 0: + + >>> torch.mm(d, d) + 0 + + .. _DiagonalArray example: + https://numpy.org/devdocs/user/basics.dispatch.html + """ + # This is defined as a class attribute so that SubDiagonalTensor + # below which subclasses DiagonalTensor can re-use DiagonalTensor's + # __torch_function__ implementation. 
+ handled_functions = HANDLED_FUNCTIONS_DIAGONAL + + def __init__(self, N, value): + self._N = N + self._i = value + + def __repr__(self): + return "DiagonalTensor(N={}, value={})".format(self._N, self._i) + + def __array__(self): + return self._i * np.eye(self._N) + + def tensor(self): + return self._i * torch.eye(self._N) + + def __torch_function__(self, func, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func not in self.handled_functions: + return NotImplemented + return self.handled_functions[func](*args, **kwargs) + + def __eq__(self, other): + if type(other) is type(self): + if self._N == other._N and self._i == other._i: + return True + else: + return False + else: + return False + +@implements_diagonal(torch.mean) +def mean(mat): + return float(mat._i) / mat._N + +@implements_diagonal(torch.mm) +def diagonal_mm(mat1, mat2): + return 0 + +@implements_diagonal(torch.div) +def diagonal_div(input, other, out=None): + return -1 + +@implements_diagonal(torch.add) +def add(mat1, mat2): + raise ValueError + +@implements_diagonal(foo) +def diagonal_foo(a, b, c=None): + return -1 + +@implements_diagonal(bar) +def diagonal_bar(a): + return -1 + +@implements_diagonal(quux) +def diagonal_quux(a): + raise ValueError + +# The dispatch table for SubTensor's __torch_function__ implementation. +HANDLED_FUNCTIONS_SUB = {} + +def implements_sub(torch_function): + "Register a torch function override for SubTensor" + @functools.wraps(torch_function) + def decorator(func): + HANDLED_FUNCTIONS_SUB[torch_function] = func + return func + return decorator + +class SubTensor(torch.Tensor): + """A subclass of torch.Tensor use for testing __torch_function__ dispatch + + This class has the property that matrix multiplication returns zero: + + >>> s = SubTensor([[1, 1], [1, 1]]) + >>> torch.mm(s, s) + 0 + >>> t = torch.tensor([[1, 1], [1, 1]]) + >>> torch.mm(s, t) + 0 + >>> torch.mm(t, s) + 0 + >>> torch.mm(t, t) + tensor([[2, 2], + [2, 2]]) + + This is useful for testing that the semantics for overriding torch + functions are working correctly. + """ + def __torch_function__(self, func, args=(), kwargs=None): + if(kwargs is None): + kwargs = {} + + if func not in HANDLED_FUNCTIONS_SUB: + return NotImplemented + return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs) + +@implements_sub(torch.mean) +def sub_mean(mat): + return 0 + +@implements_sub(torch.mm) +def sub_mm(mat1, mat2): + return -1 + +@implements_sub(torch.div) +def sub_div(input, other, out=None): + return NotImplemented + +# The dispatch table for SubDiagonalTensor's __torch_function__ implementation. +HANDLED_FUNCTIONS_SUB_DIAGONAL = {} + +def implements_sub_diagonal(torch_function): + "Register a torch function override for SubDiagonalTensor" + @functools.wraps(torch_function) + def decorator(func): + HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func + return func + return decorator + +class SubDiagonalTensor(DiagonalTensor): + """A subclass of ``DiagonalTensor`` to test custom dispatch + + This class tests semantics for defining ``__torch_function__`` on a + subclass of another class that defines ``__torch_function__``. The + only difference compared with the superclass is that this class + provides a slightly different repr as well as custom implementations + of ``mean`` and ``mm``, scaling the mean by a factor of 10 and + returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does. 
+ """ + handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL + + def __repr__(self): + return "SubDiagonalTensor(N={}, value={})".format(self._N, self._i) + + +@implements_sub_diagonal(torch.mean) +def sub_diagonal_mean(mat): + return 10 * float(mat._i) / mat._N + +@implements_sub_diagonal(bar) +def sub_diagonal_bar(mat): + return 0 + +@implements_sub_diagonal(torch.mm) +def sub_diagonal_mm(mat1, mat2): + return 1 + +@implements_sub_diagonal(torch.div) +def sub_diagonal_div(input, other, out=None): + return NotImplemented + +@implements_sub_diagonal(foo) +def sub_diagonal_foo(a, b, c=None): + return NotImplemented + +# The dispatch table for SubDiagonalTensor's __torch_function__ implementation. +HANDLED_FUNCTIONS_TENSOR_LIKE = {} + +def implements_tensor_like(torch_function): + "Register a torch function override for TensorLike" + @functools.wraps(torch_function) + def decorator(func): + HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func + return func + return decorator + +# Functions that are publicly available in the torch API but cannot be +# overrided with __torch_function__ (usually because none of their +# arguments are tensors or tensor-likes) need an entry in this tuple. + +IGNORED_TORCH_FUNCTIONS = ( + torch.typename, + torch.is_tensor, + torch.is_storage, + torch.set_default_tensor_type, + torch.set_rng_state, + torch.get_rng_state, + torch.manual_seed, + torch.initial_seed, + torch.seed, + torch.save, + torch.load, + torch.set_printoptions, + torch.fork, + torch.get_default_dtype, + torch.get_num_interop_threads, + torch.get_num_threads, + torch.import_ir_module, + torch.import_ir_module_from_buffer, + torch.is_anomaly_enabled, + torch.is_grad_enabled, + torch.merge_type_from_type_comment, + torch.parse_ir, + torch.parse_schema, + torch.parse_type_comment, + torch.set_anomaly_enabled, + torch.set_flush_denormal, + torch.set_num_interop_threads, + torch.set_num_threads, + torch.wait, + torch.as_tensor, + torch.from_numpy, + torch.get_device, + torch.tensor, + torch.default_generator, + torch.has_cuda, + torch.has_cudnn, + torch.has_lapack, + torch.cpp, + torch.device, + torch.dtype, + torch.finfo, + torch.has_mkl, + torch.has_mkldnn, + torch.has_openmp, + torch.iinfo, + torch.memory_format, + torch.qscheme, + torch.set_grad_enabled, + torch.no_grad, + torch.enable_grad, + torch.layout, + torch.align_tensors, + torch.arange, + torch.as_strided, + torch.bartlett_window, + torch.blackman_window, + torch.can_cast, + torch.cudnn_affine_grid_generator, + torch.cudnn_batch_norm, + torch.cudnn_convolution, + torch.cudnn_convolution_transpose, + torch.cudnn_grid_sampler, + torch.cudnn_is_acceptable, + torch.empty, + torch.empty_strided, + torch.eye, + torch.from_file, + torch.full, + torch.hamming_window, + torch.hann_window, + torch.linspace, + torch.logspace, + torch.mkldnn_adaptive_avg_pool2d, + torch.mkldnn_convolution, + torch.mkldnn_convolution_backward_weights, + torch.mkldnn_max_pool2d, + torch.ones, + torch.promote_types, + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.range, + torch.sparse_coo_tensor, + torch.zeros, +) + +# Every function in the torch API that can be overriden needs an entry +# in this tuple. +# +# Each element is itself a two-element tuple. The first entry is the +# function in the torch API to override, the second is a lambda function +# that returns -1 whose non-default positional arguments match the +# signature of the torch function in the first entry. 
+# +# The machinery below will call this function on a TensorLike or set of +# TensorLike objects that match the API of the lambda function and +# verify that we get -1 back from the torch API, verifying that +# __torch_function__ dispatch works correctly for the torch function. +TENSOR_LIKE_TORCH_IMPLEMENTATIONS = ( + (torch.abs, lambda input, out=None: -1), + (torch.adaptive_avg_pool1d, lambda input, output_size: -1), + (torch.adaptive_max_pool1d, lambda inputs, output_size: -1), + (torch.acos, lambda input, out=None: -1), + (torch.add, lambda input, other, out=None: -1), + (torch.addbmm, lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1), + (torch.addcdiv, lambda input, value, tensor1, tensor2, out=None: -1), + (torch.addcmul, lambda input, value, tensor1, tensor2, out=None: -1), + (torch.addmm, lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1), + (torch.addmv, lambda input, mat, vec, beta=1, alpha=1, out=None: -1), + (torch.addr, lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1), + (torch.affine_grid_generator, lambda theta, size, align_corners: -1), + (torch.all, lambda input: -1), + (torch.allclose, lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1), + (torch.alpha_dropout, lambda input, p, train, inplace=False: -1), + (torch.angle, lambda input, out=None: -1), + (torch.any, lambda input, dim, keepdim=False, out=None: -1), + (torch.argmax, lambda input: -1), + (torch.argmin, lambda input: -1), + (torch.argsort, lambda input: -1), + (torch.asin, lambda input, out=None: -1), + (torch.atan, lambda input, out=None: -1), + (torch.atan2, lambda input, other, out=None: -1), + (torch.avg_pool1d, lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1), + (torch.baddbmm, lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1), + (torch.batch_norm, lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1), + (torch.batch_norm_backward_elemt, lambda grad_out, input, mean, invstd, weight, mean_dy, mean_dy_xmu: -1), + (torch.batch_norm_backward_reduce, lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1), + (torch.batch_norm_elemt, lambda input, weight, bias, mean, invstd, eps: -1), + (torch.batch_norm_gather_stats, lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1), + (torch.batch_norm_gather_stats_with_counts, lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1), + (torch.batch_norm_stats, lambda input, eps: -1), + (torch.batch_norm_update_stats, lambda input, running_mean, running_var, momentum: -1), + (torch.bernoulli, lambda input, generator=None, out=None: -1), + (torch.bilinear, lambda input1, input2, weight, bias: -1), + (torch.binary_cross_entropy_with_logits, lambda input, target, weight=None, size_average=None, reduce=None, reduction='mean', + pos_weight=None: -1), + (torch.bincount, lambda input, weights=None, minlength=0: -1), + (torch.bitwise_not, lambda input, out=None: -1), + (torch.bitwise_xor, lambda input, other, out=None: -1), + (torch.bmm, lambda input, mat2, out=None: -1), + (torch.broadcast_tensors, lambda *tensors: -1), + (torch.cartesian_prod, lambda *tensors: -1), + (torch.cat, lambda tensors, dim=0, out=None: -1), + (torch.cdist, lambda x1, c2, p=2, compute_mode=None: -1), + (torch.ceil, lambda input, out=None: -1), + (torch.celu, lambda input, alhpa=1., inplace=False: -1), + (torch.chain_matmul, lambda *matrices: -1), + (torch.cholesky, lambda input, upper=False, 
out=None: -1), + (torch.cholesky_inverse, lambda input, upper=False, out=None: -1), + (torch.cholesky_solve, lambda input1, input2, upper=False, out=None: -1), + (torch.chunk, lambda input, chunks, dim=0: -1), + (torch.clamp, lambda input, min, max, out=None: -1), + (torch.clamp_min, lambda input, min, out=None: -1), + (torch.clamp_max, lambda input, max, out=None: -1), + (torch.clone, lambda input: -1), + (torch.combinations, lambda input, r=2, with_replacement=False: -1), + (torch.conj, lambda input, out=None: -1), + (torch.constant_pad_nd, lambda input, pad, value=0: -1), + (torch.conv1d, lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1), + (torch.conv2d, lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1), + (torch.conv3d, lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1), + (torch.convolution, lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1), + (torch.conv_tbc, lambda input, weight, bias, pad=0: -1), + (torch.conv_transpose1d, lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1), + (torch.conv_transpose2d, lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1), + (torch.conv_transpose3d, lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1), + (torch.cos, lambda input, out=None: -1), + (torch.cosine_embedding_loss, lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1), + (torch.cosh, lambda input, out=None: -1), + (torch.cosine_similarity, lambda x1, x2, dim=1, eps=1e-8: -1), + (torch.cross, lambda input, other, dim=-1, out=None: -1), + (torch.ctc_loss, lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False: -1), + (torch.cumprod, lambda input, dim, out=None, dtype=None: -1), + (torch.cumsum, lambda input, dim, out=None, dtype=None: -1), + (torch.dequantize, lambda input: -1), + (torch.det, lambda input: -1), + (torch.detach, lambda input: -1), + (torch.diag, lambda input, diagonal=0, out=None: -1), + (torch.diag_embed, lambda input, diagonal=0, out=None: -1), + (torch.diagflat, lambda input, offset=0: -1), + (torch.diagonal, lambda input, offset=0, dim1=0, dim2=1: -1), + (torch.digamma, lambda input, out=None: -1), + (torch.dist, lambda input, other, p=2: -1), + (torch.div, lambda input, other, out=None: -1), + (torch.dot, lambda mat1, mat2: -1), + (torch.dropout, lambda input, p, train, inplace=False: -1), + (torch.dsmm, lambda input, mat2: -1), + (torch.hsmm, lambda mat1, mat2: -1), + (torch.eig, lambda input, eigenvectors=False, out=None: -1), + (torch.einsum, lambda equation, *operands: -1), + (torch.einsum, lambda equation, *operands: -1), + (torch.embedding, lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, + sparse=False: -1), + (torch.embedding_bag, lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, + mode='mean', sparse=False, per_sample_weights=None: -1), + (torch.empty_like, lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1), + (torch.eq, lambda input, other, out=None: -1), + (torch.equal, lambda input, other: -1), + (torch.erf, lambda input, out=None: -1), + (torch.erfc, lambda input, out=None: -1), + (torch.erfinv, lambda input, out=None: -1), + (torch.exp, lambda input, out=None: -1), + (torch.expm1, lambda input, 
out=None: -1), + (torch.fake_quantize_per_channel_affine, lambda input, scale, zero_point, axis, quant_min, quant_max: -1), + (torch.fake_quantize_per_tensor_affine, lambda input, scale, zero_point, quant_min, quant_max: -1), + (torch.fbgemm_linear_fp16_weight, lambda input, packed_weight, bias: -1), + (torch.fbgemm_linear_fp16_weight_fp32_activation, lambda input, packed_weight, bias: -1), + (torch.fbgemm_linear_int8_weight, lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1), + (torch.fbgemm_linear_int8_weight_fp32_activation, lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, + bias: -1), + (torch.fbgemm_linear_quantize_weight, lambda input: -1), + (torch.fbgemm_pack_gemm_matrix_fp16, lambda input: -1), + (torch.fbgemm_pack_quantized_matrix, lambda input, K, N: -1), + (torch.feature_alpha_dropout, lambda input, p, train: -1), + (torch.feature_dropout, lambda input, p, train: -1), + (torch.fft, lambda input, signal_ndim, normalized=False: -1), + (torch.flatten, lambda input, start_dim=0, end_dim=-1: -1), + (torch.flip, lambda input, dims: -1), + (torch.frobenius_norm, lambda input, dim=None, keepdim=False, out=None: -1), + (torch.floor, lambda input, out=None: -1), + (torch.fmod, lambda input, other, out=None: -1), + (torch.frac, lambda input, out=None: -1), + (torch.full_like, lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1), + (torch.gather, lambda input, dim, index, out=None, sparse_grad=False: -1), + (torch.ge, lambda input, other, out=None: -1), + (torch.geqrf, lambda input, out=None: -1), + (torch.ger, lambda input, vec2, out=None: -1), + (torch.grid_sampler, lambda input, grid, interpolation_mode, padding_mode, align_corners: -1), + (torch.grid_sampler_2d, lambda input, grid, interpolation_mode, padding_mode, align_corners: -1), + (torch.grid_sampler_3d, lambda input, grid, interpolation_mode, padding_mode, align_corners: -1), + (torch.group_norm, lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1), + (torch.gru, lambda input, hx, params, has_biases, num_layers, gropout, train, bidirectional, batch_first: -1), + (torch.gru_cell, lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1), + (torch.gt, lambda input, other, out=None: -1), + (torch.hardshrink, lambda input, lambd=0.5: -1), + (torch.hinge_embedding_loss, lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1), + (torch.histc, lambda input, bins=100, min=0, max=0, out=None: -1), + (torch.hspmm, lambda mat1, mat2, out=None: -1), + (torch.ifft, lambda input, signal_ndim, normalized=False: -1), + (torch.imag, lambda input, out=None: -1), + (torch.index_add, lambda input, dim, index, source: -1), + (torch.index_copy, lambda input, dim, index, source: -1), + (torch.index_put, lambda input, indices, values, accumulate=False: -1), + (torch.index_select, lambda input, dim, index, out=None: -1), + (torch.index_fill, lambda input, dim, index, value: -1), + (torch.isfinite, lambda tensor: -1), + (torch.isinf, lambda tensor: -1), + (torch.instance_norm, lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1), + (torch.int_repr, lambda input: -1), + (torch.inverse, lambda input, out=None: -1), + (torch.irfft, lambda input, signal_ndim, normalized=False, onesided=True, signal_sizes=None: -1), + (torch.is_complex, lambda input: -1), + (torch.is_distributed, lambda input: -1), + (torch.is_floating_point, lambda input: -1), + 
(torch.is_nonzero, lambda input: -1), + (torch.is_same_size, lambda input, other: -1), + (torch.is_signed, lambda input: -1), + (torch.isclose, lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1), + (torch.isnan, lambda input: -1), + (torch.kl_div, lambda input, target, size_average=None, reduce=None, reduction='mean': -1), + (torch.kthvalue, lambda input, k, dim=None, keepdim=False, out=None: -1), + (torch.layer_norm, lambda input, normalized_shape, weight=None, bias=None, esp=1e-05: -1), + (torch.le, lambda input, other, out=None: -1), + (torch.lerp, lambda input, end, weight, out=None: -1), + (torch.lgamma, lambda input, out=None: -1), + (torch.log, lambda input, out=None: -1), + (torch.log_softmax, lambda input, dim, dtype: -1), + (torch.log10, lambda input, out=None: -1), + (torch.log1p, lambda input, out=None: -1), + (torch.log2, lambda input, out=None: -1), + (torch.logdet, lambda input: -1), + (torch.logical_not, lambda input, out=None: -1), + (torch.logical_xor, lambda input, other, out=None: -1), + (torch.logsumexp, lambda input, names, keepdim, out=None: -1), + (torch.lstm, lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1), + (torch.lstm_cell, lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1), + (torch.lstsq, lambda input, A, out=None: -1), + (torch.lt, lambda input, other, out=None: -1), + (torch.lu, lambda A, pivot=True, get_infos=False, out=None: -1), + (torch.lu_solve, lambda input, LU_data, LU_pivots, out=None: -1), + (torch.margin_ranking_loss, lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1), + (torch.masked_fill, lambda input, mask, value: -1), + (torch.masked_scatter, lambda input, mask, source: -1), + (torch.masked_select, lambda input, mask, out=None: -1), + (torch.matmul, lambda input, other, out=None: -1), + (torch.matrix_power, lambda input, n: -1), + (torch.matrix_rank, lambda input, tol=None, symmetric=False: -1), + (torch.max, lambda input, out=None: -1), + (torch.max_pool1d, lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1), + (torch.max_pool2d, lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1), + (torch.max_pool3d, lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1), + (torch.max_pool1d_with_indices, lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, + ceil_mode=False: -1), + (torch.mean, lambda input: -1), + (torch.median, lambda input: -1), + (torch.meshgrid, lambda *tensors, **kwargs: -1), + (torch.min, lambda input, out=None: -1), + (torch.miopen_batch_norm, lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, + epsilon: -1), + (torch.miopen_convolution, lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1), + (torch.miopen_convolution_transpose, lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, + deterministic: -1), + (torch.miopen_depthwise_convolution, lambda input, weight, bias, padding, stride, dilation, groups, benchmark, + deterministic: -1), + (torch.miopen_rnn, lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, + bidirectional, batch_sizes, dropout_state: -1), + (torch.mm, lambda input, mat2, out=None: -1), + (torch.mode, lambda input: -1), + (torch.mul, lambda input, 
other, out=None: -1), + (torch.multinomial, lambda input, num_samples, replacement=False, out=None: -1), + (torch.mv, lambda input, vec, out=None: -1), + (torch.mvlgamma, lambda input, p: -1), + (torch.narrow, lambda input, dim, start, length: -1), + (torch.native_batch_norm, lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1), + (torch.native_layer_norm, lambda input, weight, bias, M, N, eps: -1), + (torch.native_norm, lambda input, p=2: -1), + (torch.ne, lambda input, other, out=None: -1), + (torch.neg, lambda input, out=None: -1), + (torch.nonzero, lambda input, out=None, as_tuple=False: -1), + (torch.norm, lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1), + (torch.norm_except_dim, lambda v, pow=2, dim=0: -1), + (torch.normal, lambda mean, std, out=None: -1), + (torch.nuclear_norm, lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1), + (torch.numel, lambda input: -1), + (torch.orgqr, lambda input1, input2: -1), + (torch.ormqr, lambda input, input2, input3, left=True, transpose=False: -1), + (torch.pairwise_distance, lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1), + (torch.pdist, lambda input, p=2: -1), + (torch.pinverse, lambda input, rcond=1e-15: -1), + (torch.pixel_shuffle, lambda input, upscale_factor: -1), + (torch.poisson, lambda input, generator=None: -1), + (torch.poisson_nll_loss, lambda input, target, log_input, full, eps, reduction: -1), + (torch.polygamma, lambda input, n, out=None: -1), + (torch.prelu, lambda input, weight: -1), + (torch.ones_like, lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1), + (torch.pow, lambda input, exponent, out=None: -1), + (torch.prod, lambda input: -1), + (torch.q_per_channel_axis, lambda input: -1), + (torch.q_per_channel_scales, lambda input: -1), + (torch.q_per_channel_zero_points, lambda input: -1), + (torch.q_scale, lambda input: -1), + (torch.q_zero_point, lambda input: -1), + (torch.qr, lambda input, some=True, out=None: -1), + (torch.quantize_per_channel, lambda input, scales, zero_points, axis, dtype: -1), + (torch.quantize_per_tensor, lambda input, scale, zero_point, dtype: -1), + (torch.quantized_gru, lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1), + (torch.quantized_gru_cell, lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, + scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + (torch.quantized_lstm, lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, + dtype=None, use_dynamic=False: -1), + (torch.quantized_lstm_cell, lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, + scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + (torch.quantized_max_pool2d, lambda input, kernel_size, stride, padding, dilation, ceil_mode=False: -1), + (torch.quantized_rnn_relu_cell, lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, + col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + (torch.quantized_rnn_tanh_cell, lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, + col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + (torch.rand_like, lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1), + (torch.randint_like, lambda input, low, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1), + (torch.randn_like, lambda 
input, dtype=None, layout=None, device=None, requires_grad=False: -1), + (torch.real, lambda input, out=None: -1), + (torch.reciprocal, lambda input, out=None: -1), + (torch.relu, lambda input, inplace=False: -1), + (torch.remainder, lambda input, other, out=None: -1), + (torch.renorm, lambda input, p, dim, maxnorm, out=None: -1), + (torch.repeat_interleave, lambda input, repeats, dim=None: -1), + (torch.reshape, lambda input, shape: -1), + (torch.result_type, lambda tensor1, tensor2: -1), + (torch.rfft, lambda input, signal_ndim, normalized=False, onesided=True: -1), + (torch.rnn_relu, lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1), + (torch.rnn_relu_cell, lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1), + (torch.rnn_tanh, lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1), + (torch.rnn_tanh_cell, lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1), + (torch.roll, lambda input, shifts, dims=None: -1), + (torch.rot90, lambda input, k, dims: -1), + (torch.round, lambda input, out=None: -1), + (torch.rrelu, lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1), + (torch.rsqrt, lambda input, out=None: -1), + (torch.rsub, lambda input, other, alpha=1: -1), + (torch.saddmm, lambda input, mat1, mat2, beta, alpha, out=None: -1), + (torch.scalar_tensor, lambda s, dtype=None, layour=None, device=None, pin_memory=None: -1), + (torch.scatter, lambda input, dim, index, src: -1), + (torch.scatter_add, lambda input, dim, index, src: -1), + (torch.select, lambda input, dim, index: -1), + (torch.selu, lambda input, inplace=False: -1), + (torch.sigmoid, lambda input, out=None: -1), + (torch.sign, lambda input, out=None: -1), + (torch.sin, lambda input, out=None: -1), + (torch.sinh, lambda input, out=None: -1), + (torch.slogdet, lambda input: -1), + (torch.smm, lambda input, mat2: -1), + (torch.spmm, lambda input, mat2: -1), + (torch.softmax, lambda input, dim, dtype=None: -1), + (torch.solve, lambda input, A, out=None: -1), + (torch.sort, lambda input, dim=-1, descending=False, out=None: -1), + (torch.split, lambda tensor, split_size_or_sections, dim=0: -1), + (torch.split_with_sizes, lambda tensor, split_size_or_sections, dim=0: -1), + (torch.sqrt, lambda input, out=None: -1), + (torch.squeeze, lambda input, dim=None, out=None: -1), + (torch.sspaddmm, lambda input, mat1, mat2, beta, alpha, out=None: -1), + (torch.stack, lambda tensors, dim=0, out=None: -1), + (torch.std, lambda input: -1), + (torch.std_mean, lambda input: -1), + (torch.stft, lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', + normalized=False, onesided=True: -1), + (torch.sub, lambda input, other, out=None: -1), + (torch.sum, lambda input: -1), + (torch.svd, lambda input, some=True, compute_uv=True, out=None: -1), + (torch.symeig, lambda input, eigenvectors=False, upper=True, out=None: -1), + (torch.t, lambda input: -1), + (torch.take, lambda input, index: -1), + (torch.tan, lambda input, out=None: -1), + (torch.tanh, lambda input, out=None: -1), + (torch.tensordot, lambda a, b, dims=2: -1), + (torch.threshold, lambda input, threshold, value, inplace=False: -1), + (torch.topk, lambda input, k, dim=-1, descending=False, out=None: -1), + (torch.trace, lambda input: -1), + (torch.transpose, lambda input, dim0, dim1: -1), + (torch.trapz, lambda y, x, dim=-1: -1), + (torch.triangular_solve, lambda input, A, upper=True, transpose=False, unitriangular=False: -1), + 
(torch.tril, lambda input, diagonal=0, out=None: -1), + (torch.tril_indices, lambda row, col, offset=0, dtype=torch.long, device='cpu', layout=torch.strided: -1), + (torch.triplet_margin_loss, lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, + reduce=None, reduction='mean': -1), + (torch.triu, lambda input, diagonal=0, out=None: -1), + (torch.triu_indices, lambda row, col, offset=0, dtype=torch.long, device='cpu', layout=torch.strided: -1), + (torch.trunc, lambda input, out=None: -1), + (torch.unbind, lambda input, dim=0: -1), + (torch.unique, lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1), + (torch.unique_consecutive, lambda input, return_inverse=False, return_counts=False, dim=None: -1), + (torch.unsqueeze, lambda input, dim, out=None: -1), + (torch.var, lambda input: -1), + (torch.var_mean, lambda input: -1), + (torch.where, lambda condition, x, y: -1), + (torch.zeros_like, lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1), +) + +TENSOR_LIKE_OVERRIDES = tuple(t[0] for t in TENSOR_LIKE_TORCH_IMPLEMENTATIONS) + +def generate_tensor_like_torch_implementations(): + torch_vars = vars(torch) + untested_funcs = [] + for func_name in torch.__all__ + dir(torch._C._VariableFunctions): + # ignore private functions or functions that are deleted in torch.__init__ + if func_name.startswith('_') or func_name == 'unique_dim': + continue + func = getattr(torch, func_name) + # IGNORED_TORCH_FUNCTIONS are functions that are public but cannot be + # overriden by __torch_function__ + if func in IGNORED_TORCH_FUNCTIONS: + msg = "torch.{} is in IGNORED_TORCH_FUNCTIONS but still has an explicit override" + assert func not in TENSOR_LIKE_OVERRIDES, msg.format(func.__name__) + continue + # ignore in-place operators + if func_name.endswith('_'): + continue + # only consider objects with lowercase names + if not func_name.islower(): + continue + if func not in TENSOR_LIKE_OVERRIDES: + untested_funcs.append("torch.{}".format(func.__name__)) + msg = ( + "The following functions are not tested for __torch_function__ " + "support, please either add an entry in " + "TENSOR_LIKE_TORCH_IMPLEMENTATIONS for this function or if a " + "__torch_function__ override does not make sense, add an entry to " + "IGNORED_TORCH_FUNCTIONS.\n\n{}" + ) + assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) + for func, override in TENSOR_LIKE_TORCH_IMPLEMENTATIONS: + # decorate the overrides with implements_tensor_like + implements_tensor_like(func)(override) + +generate_tensor_like_torch_implementations() + +class TensorLike(object): + """A class that overrides the full torch API + + This class is used to explicitly test that the full torch.tensor API + can be overriden with a class that defines __torch_function__. 
+ """ + def __torch_function__(self, func, args=(), kwargs=None): + if(kwargs is None): + kwargs = {} + + if func not in HANDLED_FUNCTIONS_TENSOR_LIKE: + return NotImplemented + # In this case _torch_function_ should override TensorLike objects + return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs) + +class TestTorchFunctionOverride(TestCase): + def test_mean(self): + """Test that a function with one argument can be overrided""" + t1 = DiagonalTensor(5, 2) + t2 = SubTensor([[1, 2], [1, 2]]) + t3 = SubDiagonalTensor(5, 2) + self.assertEqual(torch.mean(t1), 12.5) + self.assertEqual(bar(t1), -1) + self.assertEqual(torch.mean(t2), 0) + self.assertEqual(bar(t2), t2) + self.assertEqual(torch.mean(t3), 125) + self.assertEqual(bar(t3), 0) + + def test_mm(self): + """Test that a function with multiple arguments can be overrided""" + t1 = DiagonalTensor(5, 2) + t2 = torch.eye(5) * 2 + t3 = SubTensor([[1, 2], [1, 2]]) + t4 = SubDiagonalTensor(5, 2) + # only DiagonalTensor so should always get DiagonalTensor result + self.assertEqual(torch.mm(t1, t1), 0) + # tensor and DiagonalTensor, always return DiagonalTensor result + self.assertEqual(torch.mm(t1, t2), 0) + self.assertEqual(torch.mm(t2, t1), 0) + # only SubTensor so should always get SubTensor result + self.assertEqual(torch.mm(t3, t3), -1) + # tensor and SubTensor so should always get SubTensor result + self.assertEqual(torch.mm(t3, t2), -1) + self.assertEqual(torch.mm(t2, t3), -1) + # DiagonalTensor and SubTensor are unrelated classes so the result + # depends on which argument appears first + self.assertEqual(torch.mm(t3, t1), -1) + self.assertEqual(torch.mm(t1, t3), 0) + # SubDiagonalTensor should take precedence over DiagonalTensor + # but should behave otherwise the same as DiagonalTensor + self.assertEqual(torch.mm(t4, t4), 1) + self.assertEqual(torch.mm(t4, t1), 1) + self.assertEqual(torch.mm(t1, t4), 1) + self.assertEqual(torch.mm(t4, t2), 1) + self.assertEqual(torch.mm(t2, t4), 1) + self.assertEqual(torch.mm(t3, t4), -1) + self.assertEqual(torch.mm(t4, t3), 0) + + def test_precedence_semantics(self): + """Test semantics for __torch_function__ for functions that take + multiple arugments + + For functions that take multiple arguments, the appropriate + __torch_function__ implementation to call is determined by + examining the types of the arguments. The precedence order is + left-to-right in the argument list, except subclasses are always + checked before superclasses. The first result of calling the + implementations in precedence order that is not NotImplemented + is returned to the user. If all implementations return + NotImplemented, a TypeError is raised. + + All cases are tested with functions implemented in C++ and + either foo or baz, which are python functions defined above that + are instrumented to obey the same dispatch rules as the + functions in torch.functional. 
+ """ + # DiagonalTensor has a valid override and SubDiagonal has an + # override that returns NotImplemented so we should call the + # DiagonalTensor implementation, returning -1 + t1 = DiagonalTensor(5, 2) + t2 = SubDiagonalTensor(5, 2) + self.assertEqual(torch.div(t1, t2), -1) + self.assertEqual(torch.div(t2, t1), -1) + self.assertEqual(foo(t1, t2), -1) + self.assertEqual(foo(t2, t1), -1) + + # SubTensor has an implementation that returns NotImplemented as + # well so it should behave exactly like SubDiagonalTensor in the + # test above + t3 = SubTensor([[1, 2], [1, 2]]) + self.assertEqual(torch.div(t1, t3), -1) + self.assertEqual(torch.div(t3, t1), -1) + self.assertEqual(foo(t1, t3), -1) + self.assertEqual(foo(t3, t1), -1) + + # div between SubTensor and SubDiagonalTensor should raise + # TypeError since both have an implementation that + # explicitly returns NotImplemented + with self.assertRaises(TypeError): + torch.div(t2, t3) + with self.assertRaises(TypeError): + torch.div(t3, t2) + with self.assertRaises(TypeError): + foo(t2, t3) + with self.assertRaises(TypeError): + foo(t3, t2) + + # none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a + # mul or a baz implementation so all ops should raise TypeError + with self.assertRaises(TypeError): + torch.mul(t1, t1) + with self.assertRaises(TypeError): + torch.mul(t1, t2) + with self.assertRaises(TypeError): + torch.mul(t1, t3) + with self.assertRaises(TypeError): + torch.mul(t2, t1) + with self.assertRaises(TypeError): + torch.mul(t2, t2) + with self.assertRaises(TypeError): + torch.mul(t2, t3) + with self.assertRaises(TypeError): + torch.mul(t3, t1) + with self.assertRaises(TypeError): + torch.mul(t3, t2) + with self.assertRaises(TypeError): + torch.mul(t3, t3) + with self.assertRaises(TypeError): + baz(t1, t1) + with self.assertRaises(TypeError): + baz(t1, t2) + with self.assertRaises(TypeError): + baz(t1, t3) + with self.assertRaises(TypeError): + baz(t2, t1) + with self.assertRaises(TypeError): + baz(t2, t2) + with self.assertRaises(TypeError): + baz(t2, t3) + with self.assertRaises(TypeError): + baz(t3, t1) + with self.assertRaises(TypeError): + baz(t3, t2) + with self.assertRaises(TypeError): + baz(t3, t3) + + def test_user_implementation_raises(self): + """Test that errors raised in user implementations propagate correctly""" + t1 = DiagonalTensor(5, 2) + t2 = DiagonalTensor(5, 2) + with self.assertRaises(ValueError): + torch.add(t1, t2) + with self.assertRaises(ValueError): + quux(t1) + +def generate_tensor_like_override_tests(cls): + def test_generator(func, override): + if torch._six.PY3: + args = inspect.getfullargspec(override) + else: + args = inspect.getargspec(override) + nargs = len(args.args) + if args.defaults is not None: + nargs -= len(args.defaults) + func_args = [TensorLike() for _ in range(nargs)] + if args.varargs is not None: + func_args += [TensorLike(), TensorLike()] + + def test(self): + self.assertEqual(func(*func_args), -1) + + return test + + for func, override in TENSOR_LIKE_TORCH_IMPLEMENTATIONS: + test_method = test_generator(func, override) + name = 'test_{}'.format(func.__name__) + test_method.__name__ = name + setattr(cls, name, test_method) + +generate_tensor_like_override_tests(TestTorchFunctionOverride) + +if __name__ == '__main__': + unittest.main() diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 62de2e29427fb2..105e23c00a26f5 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -80,6 +80,7 
@@ ${unpack_self} ParsedArgs<${max_args}> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); + ${check_has_torch_function} ${declare_namedtuple_return_types} ${dispatch} Py_RETURN_NONE; @@ -87,6 +88,16 @@ } """) +TORCH_FUNCTION_CHECK = """\ +if(r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPVariableFunctions); +} +""" + +PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\ +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); +""") + PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\ static PyObject * ${pycname}(PyObject* self_, PyObject* args) { @@ -310,6 +321,7 @@ def get_type_default(declaration): def create_python_bindings(python_functions, has_self, is_module=False): """Generates Python bindings to ATen functions""" + py_signatures = [] py_methods = [] py_method_defs = [] py_method_dispatch = [] @@ -728,6 +740,7 @@ def process_function(name, declarations): 'unpack_self': [], 'dispatch': [], 'declare_namedtuple_return_types': '', + 'check_has_torch_function': '', } if has_self: @@ -770,6 +783,8 @@ def process_function(name, declarations): if not is_module and not has_self: env['flags'] += ' | METH_STATIC' + env['check_has_torch_function'] = TORCH_FUNCTION_CHECK + py_signatures.append(PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION.substitute(env)) py_methods.append(tmpl.substitute(env)) if name in BINARY_OP_NAMES: @@ -781,6 +796,7 @@ def process_function(name, declarations): process_function(name, python_functions[name]) return { + 'py_signatures': py_signatures, 'py_methods': py_methods, 'py_method_defs': py_method_defs, 'py_method_dispatch': py_method_dispatch, diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 03c90d257ceca6..4078440e40c898 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -15,6 +15,7 @@ #include "torch/csrc/Dtype.h" #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/Exceptions.h" +#include "torch/csrc/utils/pybind.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/tensor_layouts.h" #include "torch/csrc/utils/tensor_new.h" @@ -326,30 +327,7 @@ static std::vector dispatch_nonzero_numpy(const Tensor & self) { return self.nonzero_numpy(); } -static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "nonzero(Tensor input, *, Tensor out=None)|deprecated", - "nonzero(Tensor input, *, bool as_tuple)", - }); - ParsedArgs<2> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - if (r.idx == 0) { - if (r.isNone(1)) { - return wrap(dispatch_nonzero(r.tensor(0))); - } else { - return wrap(dispatch_nonzero(r.tensor(0), r.tensor(1))); - } - } else { - if (r.toBool(1)) { - return wrap(dispatch_nonzero_numpy(r.tensor(0))); - } else { - return wrap(dispatch_nonzero(r.tensor(0))); - } - } - END_HANDLE_TH_ERRORS -} +static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs); static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs) { @@ -378,6 +356,7 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObje ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); + if (r.idx == 0) { return wrap(r.tensor(0).get_device()); } @@ -385,22 +364,11 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* 
args, PyObje END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "numel(Tensor input)", - }, /*traceable=*/false); +static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs); - ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); +// generated forward declarations start here - if (r.idx == 0) { - return wrap(r.tensor(0).numel()); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} +${py_signatures} // Wrapper converts a raised TypeError into returning NotImplemented // Used to implement binary arithmetic operators @@ -415,10 +383,6 @@ static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, P return ret; } -// generated methods start here - -${py_methods} - // XXX: ops that are bound here are not exposed to the C++ api nor the JIT. // Any new ops added here should be accompanied with a comment why they are not // being registered through native_functions.yaml, and be tagged cpp / JIT @@ -492,4 +456,132 @@ void initTorchFunctions(PyObject* module) { } } +/* + * + * Calls __torch_function__ on the overloaded arguments to a torch API + * function in order of precedence, returning the first result that is + * not NotImplemented. If all arguments return NotImplemented, raises a + * TypeError. + * + * Assumes overloaded_args has at least one entry. All entries must have + * a __torch_function__ attribute that resolves to a callable that + * accepts a torch API function, arguments, and keyword arguments for + * the torch API function. + * + * It is sufficient to call PythonArgs::has_torch_function before + * calling this function to verify that there are valid arguments + * present. If that is not done then special care must be taken to + * ensure there are arguments that are overloaded with + * __torch_function__. + * + * See torch._overrides._implement_torch_function for the equivalent + * code in the pure-python implementation. + * + * 'r' is a parsed PythonArgs instance, returned from + * PythonArgParser::parse. + * + * 'args' is a reference to the python tuple of arguments to the torch + * API function. + * + * 'kwargs' is a reference to the python dict of keyword arguments to + * the torch API function. + * + * 'torch_api' is a reference to python torch API namespace. + * + */ + +PyObject* handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyTypeObject &torch_api) { + py::object torch_api_function = PyObject_FastGetAttrString((PyObject*)&torch_api, const_cast(r.get_func_name().data())); + TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != NULL, "torch API function must exist"); + py::object ret; + for (auto &arg : r.signature.overloaded_args) { + py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__"); + ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), args, kwargs, NULL)); + if (ret.ptr() != Py_NotImplemented) { + // Return the reference to the result. This also covers the case where ret + // is NULL and __torch_function__ raised an exception, which we throw below + break; + } + } + if (ret.ptr() == nullptr) { + // if an exception occurred in a user's implementation of + // __array_function__, throw it + throw python_error(); + } + else if (ret.ptr() == Py_NotImplemented) { + // all __torch_function__ implementations in overloaded_args + // returned NotImplemented, so we raise a TypeError. 
+ std::stringstream ss; + ss << "no implementation found for 'torch." << r.get_func_name() + << "' on types that implement __torch_function__: ["; + for (auto &arg : r.signature.overloaded_args) { + ss << arg.ptr()->ob_type->tp_name; + if (!arg.is(r.signature.overloaded_args.back())) { + ss << ", "; + } + else { + ss << "]"; + } + } + const std::string& tmp = ss.str(); + PyErr_SetString(PyExc_TypeError, tmp.c_str()); + throw python_error(); + } + return ret.release().ptr(); +} + +// generated methods start here + +${py_methods} + +static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "nonzero(Tensor input, *, Tensor out=None)|deprecated", + "nonzero(Tensor input, *, bool as_tuple)", + }); + ParsedArgs<2> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, args, kwargs, THPVariableFunctions); + } + + if (r.idx == 0) { + if (r.isNone(1)) { + return wrap(dispatch_nonzero(r.tensor(0))); + } else { + return wrap(dispatch_nonzero(r.tensor(0), r.tensor(1))); + } + } else { + if (r.toBool(1)) { + return wrap(dispatch_nonzero_numpy(r.tensor(0))); + } else { + return wrap(dispatch_nonzero(r.tensor(0))); + } + } + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "numel(Tensor input)", + }, /*traceable=*/false); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, args, kwargs, THPVariableFunctions); + } + + if (r.idx == 0) { + return wrap(r.tensor(0).numel()); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} }} // namespace torch::autograd diff --git a/torch/_overrides.py b/torch/_overrides.py new file mode 100644 index 00000000000000..6dccba7c3aae9b --- /dev/null +++ b/torch/_overrides.py @@ -0,0 +1,280 @@ +""" +Python implementation of __torch_function__ + +While most of the torch API and handling for __torch_function__ happens +at the C++ level, some of the torch API is written in Python so we need +python-level handling for __torch_function__ overrides as well. The main +developer-facing functionality in this file is the +torch_function_dispatch decorator. This function can be applied to +python functions in the torch.functional module to enable +__torch_function__ overrides for that function. See the examples in the +docstrings for torch_function_dispatch for details. + +NOTE: heavily inspired by NumPy's ``__array_function__`` (see: +https://github.com/pytorch/pytorch/issues/24015 and +https://www.numpy.org/neps/nep-0018-array-function-protocol.html +) + +""" + +import functools +import textwrap +from . import _six +if _six.PY3: + from inspect import getfullargspec + import collections + ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') + + def getargspec(func): + spec = getfullargspec(func) + return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults) +else: + from inspect import getargspec + +from .tensor import Tensor + + +_TENSOR_ONLY = [Tensor] + +def _get_overloaded_types_and_args(relevant_args): + """Returns a list of arguments on which to call __torch_function__. 
+ + Checks arguments in relevant_args for __torch_function__ implementations, + storing references to the arguments and their types in overloaded_args and + overloaded_types in order of calling precedence. Only distinct types are + considered. If a type is a subclass of another type it will have higher + precedence, otherwise the precedence order is the same as the order of + arguments in relevant_args, that is, from left-to-right in the argument list. + + The precedence-determining algorithm implemented in this function is + described in `NEP-0018`_. + + See torch::append_overloaded_arg for the equivalent function in the C++ + implementation. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of array-like arguments to check for __torch_function__ + methods. + + Returns + ------- + overloaded_types : collection of types + Types of arguments from relevant_args with __torch_function__ methods. + overloaded_args : list + Arguments from relevant_args on which to call __torch_function__ + methods, in the order in which they should be called. + + .. _NEP-0018: + https://numpy.org/neps/nep-0018-array-function-protocol.html + + """ + # Runtime is O(num_arguments * num_unique_types) + overloaded_types = [] + overloaded_args = [] + for arg in relevant_args: + arg_type = type(arg) + # We only collect arguments if they have a unique type, which ensures + # reasonable performance even with a long list of possibly overloaded + # arguments. + if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')): + # Create lists explicitly for the first type (usually the only one + # done) to avoid setting up the iterator for overloaded_args. + if overloaded_types: + overloaded_types.append(arg_type) + # By default, insert argument at the end, but if it is + # subclass of another argument, insert it before that argument. + # This ensures "subclasses before superclasses". + index = len(overloaded_args) + for i, old_arg in enumerate(overloaded_args): + if issubclass(arg_type, type(old_arg)): + index = i + break + overloaded_args.insert(index, arg) + else: + overloaded_types = [arg_type] + overloaded_args = [arg] + + return overloaded_types, overloaded_args + + +def _implement_torch_function( + implementation, public_api, relevant_args, args, kwargs): + """Implement a function with checks for __torch_function__ overrides. + + See torch::autograd::handle_torch_function for the equivalent of this + function in the C++ implementation. + + Arguments + --------- + implementation : function + Function that implements the operation on ``torch.Tensor`` without + overrides when called like ``implementation(*args, **kwargs)``. + public_api : function + Function exposed by the public torch API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being + checked. + relevant_args : iterable + Iterable of arguments to check for __torch_function__ methods. + args : tuple + Arbitrary positional arguments originally passed into ``public_api``. + kwargs : tuple + Arbitrary keyword arguments originally passed into ``public_api``. + + Returns + ------- + Result from calling `implementation()` or an `__torch_function__` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + + """ + # Check for __torch_function__ methods. 
+ types, overloaded_args = _get_overloaded_types_and_args(relevant_args) + # Short-cut for common cases: no overload or only Tensor overload + # (directly or with subclasses that do not override __torch_function__). + if not overloaded_args or types == _TENSOR_ONLY: + return implementation(*args, **kwargs) + + # Call overrides + for overloaded_arg in overloaded_args: + # Use `public_api` instead of `implementation` so __torch_function__ + # implementations can do equality/identity comparisons. + result = overloaded_arg.__torch_function__(public_api, args, kwargs) + + if result is not NotImplemented: + return result + + func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) + raise TypeError("no implementation found for '{}' on types that implement " + '__torch_function__: {}' + .format(func_name, list(map(type, overloaded_args)))) + + +def _verify_matching_signatures(implementation, dispatcher): + """Verify that a dispatcher function has the right signature.""" + implementation_spec = getargspec(implementation) + dispatcher_spec = getargspec(dispatcher) + + if (implementation_spec.args != dispatcher_spec.args or + implementation_spec.varargs != dispatcher_spec.varargs or + implementation_spec.keywords != dispatcher_spec.keywords or + (bool(implementation_spec.defaults) != + bool(dispatcher_spec.defaults)) or + (implementation_spec.defaults is not None and + len(implementation_spec.defaults) != + len(dispatcher_spec.defaults))): + raise RuntimeError('implementation and dispatcher for %s have ' + 'different function signatures' % implementation) + + +_wrapped_func_source = textwrap.dedent(""" + @functools.wraps(implementation) + def {name}(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return implement_torch_function( + implementation, {name}, relevant_args, args, kwargs) + """) + +def torch_function_dispatch(dispatcher, module=None, verify=True): + """Decorator for adding dispatch with the __torch_function__ protocol. + + If you define a function in Python and would like to permit user-defined + tensor-like types to override it using __torch_function__, please apply this + decorator on this function together with a custom dispatcher that indicates + which arguments should be checked for the presence of __torch_function__. + + Suppose we'd like to apply this function to torch.frob, which has the + following definition: + + def frob(input, bias, option=None): + return input + bias + + We'd need to define a dispatcher for frob that has the same signature and + returns the elements of the signature that should be checked for + `__torch_function__`. If any of the arguments has a `__torch_function__` + attribute, that function will be called to handle custom dispatch. Assuming + that `bias` can be a tensor-like, our dispatcher would look like: + + def _frob_dispatcher(input, bias, option=None): + return (input, bias) + + The dispatcher must return an iterable, so return a single-element tuple if + only one argument should be checked. We would then modify the original + definition for torch.frob to look like: + + @torch_function_dispatch(_frob_dispatcher) + def frob(input, bias, option=None): + return input + bias + + See ``torch/functional.py`` for more usage examples. + + Parameters + ---------- + dispatcher : callable + Function that when called like ``dispatcher(*args, **kwargs)`` with + arguments from the NumPy function call returns an iterable of + array-like arguments to check for ``__torch_function__``. 
+ module : str, optional + ``__module__`` attribute to set on new function, e.g., + ``module='torch'``. By default, module is copied from the decorated + function. + verify : bool, optional + If True, verify the that the signature of the dispatcher and decorated + function signatures match exactly: all required and optional arguments + should appear in order with the same names, but the default values for + all optional arguments should be ``None``. Only disable verification + if the dispatcher's signature needs to deviate for some particular + reason, e.g., because the function has a signature like + ``func(*args, **kwargs)``. + + Returns + ------- + dispatcher : callable + Function suitable for decorating the implementation of a NumPy + function. + + Notes + ----- + The dispatcher should normally return a tuple containing all input + arguments that may have a ``__torch_function__`` attribute. + + In some cases where that's not easily possible, e.g. ``torch.cat``, it is + also valid (if a little slower) to make the dispatcher function a generator + (i.e. use ``yield`` to return arguments one by one). + + """ + def decorator(implementation): + if verify: + _verify_matching_signatures(implementation, dispatcher) + + # Equivalently, we could define this function directly instead of using + # exec. This version has the advantage of giving the helper function a + # more interpretable name. Otherwise, the original function does not + # show up at all in many cases, e.g., if it's written in C++ or if the + # dispatcher gets an invalid keyword argument. + source = _wrapped_func_source.format(name=implementation.__name__) + + source_object = compile( + source, filename='<__torch_function__ internals>', mode='exec') + scope = { + 'implementation': implementation, + 'dispatcher': dispatcher, + 'functools': functools, + 'implement_torch_function': _implement_torch_function, + } + _six.exec_(source_object, scope) + + public_api = scope[implementation.__name__] + + if module is not None: + public_api.__module__ = module + + public_api._implementation = implementation + + return public_api + + return decorator diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 9ce30b23adb626..e57057e2326e2a 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -132,10 +132,86 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) } } -bool FunctionParameter::check(PyObject* obj) { +/* + * obj has a __torch_function__ implementation and may either be a + * subclass of Tensor or a Tensor-like duck type. We may need to + * append this object to the overloaded_args vector, which tracks all + * of the arguments with distinct __torch_function__ implementations + * we've seen so far. + * + * If this is the first argument we've seen with __torch_function__ + * defined, we unconditionally add obj to the overloaded_args vector. + * + * If we've already seen arguments with __torch_function__ defined, + * then we first need to check if obj is the same type as any of the + * entries in overloaded_args. If so, we can ignore obj since we + * already have an entry in overloaded_args with the same + * __torch_function__ implementation. + * + * If it's a different type, we then need to check if it's a subclass + * of one of the types we've already seen. If so, we need to insert an + * entry in overloaded_args for this type with higher precedence than + * the superclass. 
+ * + * See torch._overrides._get_overloaded_types_and_args for the equivalent + * function in the Python __torch_function__ implementation. + * + * The precedence-determining algorithm implemented in this function is + * described in NEP-0018: + * https://numpy.org/neps/nep-0018-array-function-protocol.html + * + * 'overloaded_args' is a reference to a vector of pybind11 handles + * that have distinct __torch_function__ implementations, in order of calling + * precedence. + * + * 'obj' is an object to check for a __torch_function__ implementation + * + */ + +void append_overloaded_arg(std::vector &overloaded_args, PyObject* obj) { + bool class_not_seen_yet = true; + for (auto &arg : overloaded_args) { + if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) { + // obj is the same type as another parameter we've seen in a prior + // iteration of the loop over parameters so we already have an entry + // with the proper __torch_function__ implementation to call, so skip + // this parameter + class_not_seen_yet = false; + break; + } + } + if (class_not_seen_yet) { + int arg_index = overloaded_args.size(); + for (int j = 0; j < arg_index; j++) { + if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE(overloaded_args[j].ptr())))) { + // obj is a subclass of another object we've seen already so its + // __torch_function__ should be called first, therefore we + // insert it into overloaded_args before the superclass + arg_index = j; + break; + } + } + // add object to overloaded_args. If it's a subclass of another class + // we've already seen it will be inserted before the superclass, + // otherwise it will be inserted at the end of the array + overloaded_args.insert(overloaded_args.begin() + arg_index, obj); + } +} + +auto FunctionParameter::check(PyObject* obj, std::vector &overloaded_args) -> bool +{ switch (type_) { case ParameterType::TENSOR: { - return THPVariable_Check(obj) || (allow_numbers_as_tensors && THPUtils_checkScalar(obj)); + if (THPVariable_CheckExact(obj)) { + return true; + } + if (THPVariable_Check(obj)) { + if (check_has_torch_function(obj)) { + append_overloaded_arg(overloaded_args, obj); + } + return true; + } + return allow_numbers_as_tensors && THPUtils_checkScalar(obj); } case ParameterType::SCALAR: case ParameterType::COMPLEX: @@ -500,6 +576,10 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], return false; } + if (!overloaded_args.empty()) { + overloaded_args.clear(); + } + int i = 0; for (auto& param : params) { PyObject* obj = nullptr; @@ -532,11 +612,14 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], missing_args(*this, i); } return false; - } else if (param.check(obj)) { + } else if (param.check(obj, this->overloaded_args)) { dst[i++] = obj; // XXX: the Variable check is necessary because sizes become tensors when // tracer is enabled. This behavior easily leads to ambiguities, and we // should avoid having complex signatures that make use of it... 
+ } else if (check_has_torch_function(obj)) { + append_overloaded_arg(overloaded_args, obj); + dst[i++] = obj; } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd && THPUtils_checkIndex(obj)) { // take all positional arguments as this parameter @@ -574,7 +657,6 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], } return false; } - return true; } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index b4e2ab4152a1ad..3432ab21b54132 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -59,6 +59,7 @@ #include #include #include +#include #include #include #include @@ -96,6 +97,7 @@ struct ParsedArgs { struct PythonArgParser { explicit PythonArgParser(std::vector fmts, bool traceable=false); + // meant only for `torch` functions. template inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst); @@ -122,6 +124,8 @@ struct PythonArgs { const FunctionSignature& signature; PyObject** args; + inline bool has_torch_function(); + inline std::string get_func_name(); inline at::Tensor tensor(int i); inline at::Scalar scalar(int i); inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar); @@ -168,14 +172,16 @@ struct PythonArgs { at::Scalar scalar_slow(int i); }; -struct FunctionSignature { +struct PYBIND11_EXPORT FunctionSignature { explicit FunctionSignature(const std::string& fmt); bool parse(PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception); + std::string toString() const; std::string name; std::vector params; + std::vector overloaded_args; ssize_t min_args; ssize_t max_args; ssize_t max_pos_args; @@ -186,7 +192,8 @@ struct FunctionSignature { struct FunctionParameter { FunctionParameter(const std::string& fmt, bool keyword_only); - bool check(PyObject* obj); + bool check(PyObject* obj, std::vector &overloaded_args); + void set_default_str(const std::string& str); std::string type_name() const; @@ -222,6 +229,14 @@ inline PythonArgs PythonArgParser::parse(PyObject* args, PyObject* kwargs, Parse return raw_parse(args, kwargs, dst.args); } +inline bool PythonArgs::has_torch_function(){ + return !this->signature.overloaded_args.empty(); +} + +inline std::string PythonArgs::get_func_name(){ + return signature.name; +} + inline at::Tensor PythonArgs::tensor(int i) { if (args[i] && THPVariable_CheckExact(args[i])) { return reinterpret_cast(args[i])->cdata; @@ -530,4 +545,129 @@ inline PyObject* PythonArgs::pyobject(int i) { return args[i]; } +/* + * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 + * + * Stripped down version of PyObject_GetAttrString, + * avoids lookups for None, tuple, and List objects, + * and doesn't create a PyErr since this code ignores it. + * + * This can be much faster then PyObject_GetAttrString where + * exceptions are not used by caller. + * + * 'obj' is the object to search for attribute. + * + * 'name' is the attribute to search for. + * + * Returns a py::object wrapping the return value. If the attribute lookup failed + * the value will be NULL. 
+ * + */ + +static py::object PyObject_FastGetAttrString(PyObject *obj, char *name) +{ + PyTypeObject *tp = Py_TYPE(obj); + PyObject *res = (PyObject *)NULL; + + /* Attribute referenced by (char *)name */ + if (tp->tp_getattr != NULL) { + res = (*tp->tp_getattr)(obj, name); + if (res == NULL) { + PyErr_Clear(); + } + } + /* Attribute referenced by (PyObject *)name */ + else if (tp->tp_getattro != NULL) { + PyObject *w = THPUtils_internString(name); + if (w == NULL) { + return py::object(); + } + res = (*tp->tp_getattro)(obj, w); + Py_DECREF(w); + if (res == NULL) { + PyErr_Clear(); + } + } + return py::reinterpret_steal(res); +} + +// Makes sure that we don't check for __torch_function__ on basic Python types +static bool _is_basic_python_type(PyTypeObject *tp) +{ + return ( + /* Basic number types */ + tp == &PyBool_Type || + + tp == &PyLong_Type || + tp == &PyFloat_Type || + tp == &PyComplex_Type || + + /* Basic sequence types */ + tp == &PyList_Type || + tp == &PyTuple_Type || + tp == &PyDict_Type || + tp == &PySet_Type || + tp == &PyFrozenSet_Type || + tp == &PyUnicode_Type || + tp == &PyBytes_Type || + +#if PY_MAJOR_VERSION == 2 + tp == &PyString_Type || +#endif + + /* other builtins */ + tp == &PySlice_Type || + tp == Py_TYPE(Py_None) || + tp == Py_TYPE(Py_Ellipsis) || + tp == Py_TYPE(Py_NotImplemented) || + + PyModule_Check(tp) || + /* sentinel to swallow trailing || */ + false + ); +} + +/* + * Lookup a special method, following the python approach of looking up + * on the type object, rather than on the instance itself. + * + * Assumes that the special method is a torch-specific one, so does not + * look at builtin types, nor does it look at a base Tensor. + * + * If no special method is found, return NULL, otherwise returns a new + * reference to the function object + * + * In future, could be made more like _Py_LookupSpecial + */ + +static py::object PyTorch_LookupSpecial(PyObject *obj, char* name) +{ + PyTypeObject *tp = Py_TYPE(obj); + if (THPVariable_CheckExact(obj)) { + return py::object(); + } + if (_is_basic_python_type(tp)) { + return py::object(); + } + if(PyObject_HasAttrString(obj, name) == 0){ + return py::object(); + } + return PyObject_FastGetAttrString((PyObject *)tp, name); +} + +/* + * Checks if obj has a __torch_function__ implementation + * + * Returns true if an implementation is found and false otherwise + * + */ +static auto check_has_torch_function(PyObject* obj) -> bool +{ + py::object method = PyTorch_LookupSpecial(obj, "__torch_function__"); + if(method.ptr() != nullptr){ + return true; + } + return false; +} + } // namespace torch diff --git a/torch/functional.py b/torch/functional.py index 6ed6942e79c384..4620535645a2a9 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -3,6 +3,8 @@ from torch._six import inf from itertools import product +from ._overrides import torch_function_dispatch + __all__ = [ 'align_tensors', # BUILD_NAMEDTENSOR only 'broadcast_tensors', @@ -22,7 +24,10 @@ 'unique_consecutive', ] +def _broadcast_tensors_dispatcher(*tensors): + return tensors +@torch_function_dispatch(_broadcast_tensors_dispatcher) def broadcast_tensors(*tensors): r"""broadcast_tensors(*tensors) -> List of Tensors @@ -52,6 +57,11 @@ def broadcast_tensors(*tensors): return torch._C._VariableFunctions.broadcast_tensors(tensors) +def _split_dispatcher(tensor, split_size_or_sections, dim=0): + return (tensor,) + + +@torch_function_dispatch(_split_dispatcher) def split(tensor, split_size_or_sections, dim=0): r"""Splits the tensor into chunks. 
@@ -168,6 +178,11 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): return P, L, U +def _einsum_dispatcher(equation, *operands): + return operands + + +@torch_function_dispatch(_einsum_dispatcher) def einsum(equation, *operands): r"""einsum(equation, *operands) -> Tensor @@ -241,6 +256,11 @@ def einsum(equation, *operands): return torch._C._VariableFunctions.einsum(equation, operands) +def _isinf_dispatcher(tensor): + return (tensor,) + + +@torch_function_dispatch(_isinf_dispatcher) def isinf(tensor): r"""Returns a new tensor with boolean elements representing if each element is `+/-INF` or not. @@ -262,6 +282,11 @@ def isinf(tensor): return tensor.abs() == inf +def _meshgrid_dispatcher(*tensors, **kwargs): + return tensors + + +@torch_function_dispatch(_meshgrid_dispatcher) def meshgrid(*tensors, **kwargs): r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by @@ -299,6 +324,13 @@ def meshgrid(*tensors, **kwargs): return torch._C._VariableFunctions.meshgrid(tensors) +def _stft_dispatcher(input, n_fft, hop_length=None, win_length=None, + window=None, center=True, pad_mode='reflect', + normalized=False, onesided=True): + return (input,) + + +@torch_function_dispatch(_stft_dispatcher) def stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True): # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor @@ -395,7 +427,12 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, del torch.unique_dim +def _unique_dispatcher(input, sorted=None, return_inverse=None, + return_counts=None, dim=None): + return (input,) + +@torch_function_dispatch(_unique_dispatcher) def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None): r"""Returns the unique elements of the input tensor. @@ -479,7 +516,11 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No else: return output +def _unique_consecutive_dispatcher( + input, return_inverse=None, return_counts=None, dim=None): + return (input,) +@torch_function_dispatch(_unique_consecutive_dispatcher) def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None): r"""Eliminates all but the first element from every consecutive group of equivalent elements. @@ -540,7 +581,11 @@ def unique_consecutive(input, return_inverse=False, return_counts=False, dim=Non return output, counts return output +def _tensordot_dispatcher(a, b, dims=None): + return (a, b) + +@torch_function_dispatch(_tensordot_dispatcher) def tensordot(a, b, dims=2): r"""Returns a contraction of a and b over multiple dimensions. @@ -595,7 +640,7 @@ def tensordot(a, b, dims=2): dims_b = list(range(dims)) return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b) - +@torch_function_dispatch(_broadcast_tensors_dispatcher) def cartesian_prod(*tensors): """Do cartesian product of the given sequence of tensors. The behavior is similar to python's `itertools.product`. 
@@ -626,6 +671,10 @@ def cartesian_prod(*tensors): """ return torch._C._VariableFunctions.cartesian_prod(tensors) +def _cdist_dispatcher(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): + return (x1, x2) + +@torch_function_dispatch(_cdist_dispatcher) def cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): r"""Computes batched the p-norm distance between each pair of the two collections of row vectors. @@ -677,6 +726,11 @@ def cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): raise ValueError("{} is not a valid value for compute_mode".format(compute_mode)) +def _norm_dispatcher(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): + return (input,) + + +@torch_function_dispatch(_norm_dispatcher) def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): r"""Returns the matrix norm or vector norm of a given tensor. @@ -774,6 +828,11 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out) +def _chain_matmul_dispatcher(*matrices): + return matrices + + +@torch_function_dispatch(_chain_matmul_dispatcher) def chain_matmul(*matrices): r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms @@ -805,7 +864,11 @@ def chain_matmul(*matrices): """ return torch._C._VariableFunctions.chain_matmul(matrices) +def _lu_dispatcher(A, pivot=None, get_infos=None, out=None): + return (A,) + +@torch_function_dispatch(_lu_dispatcher) def lu(A, pivot=True, get_infos=False, out=None): r"""Computes the LU factorization of a matrix or batches of matrices :attr:`A`. Returns a tuple containing the LU factorization and