Skip to content

add __torch_function__ API override mechanism #27064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 173 commits into from
Closed
Show file tree
Hide file tree
Changes from 95 commits
Commits
Show all changes
173 commits
Select commit Hold shift + click to select a range
e966598
first try
prasunanand Aug 17, 2019
cda2057
modify dispatcher
prasunanand Aug 19, 2019
9f710a4
signatures matched
prasunanand Aug 19, 2019
fee5ef1
Fix mixed tabs/spaces
rgommers Aug 19, 2019
1bc13fd
Add implement_torch_function (in Python) implementation.
rgommers Aug 19, 2019
f9ea2ac
Move code from torch/__init__.py to torch/_overrides.py
rgommers Aug 19, 2019
5965063
Remove TORCH_FUNCTION_ENABLED, this is an env var we don't need
rgommers Aug 19, 2019
6ea9718
Fix flake8 warnings
rgommers Aug 19, 2019
8982470
Add TODO for temporary addition to torch/__init__.py
rgommers Aug 19, 2019
04af379
Add some imports, comments, and dummy __torch_function__
rgommers Aug 19, 2019
4c05cfc
Implement __torch_function__ in Python.
rgommers Aug 20, 2019
8e80113
Add an example of using the override for gemm in test/test_overrides.py
rgommers Aug 20, 2019
4b4ca02
Add ASV benchmarks. Also fix an issue with Tensor.__torch_function__
rgommers Aug 20, 2019
c3e8731
Add some documentation for writing and running ASV benchmarks
rgommers Aug 20, 2019
3753da7
Add a few overloads, and adds docs on what dispatcher functions should
rgommers Aug 20, 2019
da6ed8e
Another documentation tweak.
rgommers Aug 20, 2019
47f4703
adopt tests from numpy
prasunanand Aug 22, 2019
fb8986e
modify the unittests and mark a few of them to be skipped
prasunanand Aug 26, 2019
fce2ef8
correct type order for subclass tests
prasunanand Aug 26, 2019
cb22ee1
Remove `assert_` again in favor of plain `assert`
rgommers Aug 28, 2019
3a35770
Remove `TORCH_FUNCTION_ENABLED`, it wasn't doing anything.
rgommers Aug 28, 2019
70e9e29
Remove torch.gemm, and change test to use torch.unique
rgommers Aug 28, 2019
c361848
Fix __torch_function__ subclass ordering test
rgommers Aug 28, 2019
a70d13b
Remove irrelevant test for `__torch_function__`
rgommers Sep 4, 2019
e0fedb9
Fix a couple more `__torch_function__` tests.
rgommers Sep 4, 2019
20e535e
Fix an issue with subclasses for `__torch_function__`.
rgommers Sep 4, 2019
b757e11
Fix one more test, and remove an unnecessary test
rgommers Sep 4, 2019
26885eb
Fix last failing test for `__torch_function__`.
rgommers Sep 4, 2019
fb71ad0
Treat imports properly in override benchmarks.
rgommers Sep 4, 2019
58de6bb
Fix some typos in comments
rgommers Sep 4, 2019
29bdccc
Make flake8 happy
rgommers Sep 4, 2019
f012e5b
Remove a spurious tab in asv.conf.json
rgommers Sep 4, 2019
d9dbf16
Fix two more pylint issues
rgommers Sep 4, 2019
a344329
Torch function overrides in cpp
prasunanand Sep 5, 2019
8c7b2d4
Added comments to the code
prasunanand Sep 24, 2019
7537bb2
Parse modified
prasunanand Sep 24, 2019
6cae1b9
Test for NN
prasunanand Sep 25, 2019
36fbd71
Skip overrides test
prasunanand Sep 25, 2019
08ecb78
Added comments to parse and removed some unused code
prasunanand Sep 26, 2019
8335d41
Find attribute only if present
prasunanand Sep 26, 2019
e05da42
Get rid of Python code and modify overrides tests
prasunanand Sep 27, 2019
2556ded
Parse works! Remove duplicate code
prasunanand Sep 27, 2019
fe954b2
Python 2 support
prasunanand Oct 1, 2019
faa517a
Lint fix: Add new line
prasunanand Oct 1, 2019
145309b
Fix Clang tidy error: {nullptr}
prasunanand Oct 1, 2019
56498fc
Flake errors fix
prasunanand Oct 1, 2019
53c0f21
Flake errors fix
prasunanand Oct 1, 2019
a69abe4
Check for overheads on add and multiply
prasunanand Oct 7, 2019
9cd4ef7
Add __torch_function__ to other Torch APIs
prasunanand Oct 7, 2019
8783bb2
Rebase with master
prasunanand Oct 7, 2019
c98161e
Modify the benchmark code
prasunanand Oct 8, 2019
8242cfd
Subclass of torch.Tensor should check for __torch_function__
prasunanand Oct 8, 2019
4b0468c
Benchmark SubTensors with __torch_function__ defined
prasunanand Oct 8, 2019
175a118
Minor tweaks for lesser overhead
prasunanand Oct 8, 2019
298bad4
Handle subclasses of Tensor, fix test
prasunanand Oct 10, 2019
de139c6
Fix Python3.7 Lint errors
prasunanand Oct 10, 2019
2931eb0
Fix Python Lint errors
prasunanand Oct 10, 2019
af89388
Fix clang tidy
prasunanand Oct 10, 2019
b63f58b
Test overrides of Torch public APIs
prasunanand Oct 12, 2019
5f87cc2
More Test overrides of Torch public APIs
prasunanand Oct 12, 2019
8dfabca
make args default to () in __torch_function__ implementations used fo…
ngoldbaum Oct 15, 2019
6496500
Fix duplicate function name
ngoldbaum Oct 15, 2019
1ab345b
reduce boilerplate in override tests by defining ImplementationMeta m…
ngoldbaum Oct 15, 2019
c3465f8
Merge pull request #1 from ngoldbaum/torch_function
prasunanand Oct 16, 2019
5a8a2c7
Merge branch 'master' into torch_function
ngoldbaum Oct 16, 2019
df60370
add support for torch functions defined in python
ngoldbaum Oct 17, 2019
b949180
autogenerate tests for the full torch API
ngoldbaum Oct 23, 2019
b67c39e
add override tests for some more functions
ngoldbaum Oct 23, 2019
89695d0
Merge pull request #2 from ngoldbaum/torch_function
prasunanand Oct 24, 2019
f547e73
include test_overrides in main test runner
ngoldbaum Oct 24, 2019
1b612a6
Move helpers from python_variable.h to python_arg_parser.h
ngoldbaum Oct 24, 2019
509a05d
fix python2.7 SyntaxError
ngoldbaum Oct 24, 2019
81a9aed
appease clang-tidy
ngoldbaum Oct 24, 2019
f687f15
Merge remote-tracking branch 'prasun/torch_function' into torch_function
ngoldbaum Oct 25, 2019
271d94d
use functools.wraps for the dispatch decorators
ngoldbaum Oct 25, 2019
6f0da73
rename HANDLED_FUNCTIONS to HANDLED_FUNCTIONS_DIAGONAL
ngoldbaum Oct 25, 2019
630a155
remove ImplementationMeta to make tests follow suggested implementation
ngoldbaum Oct 25, 2019
c391bc4
reorganize so that dispatch tables and dispatch decorators are groupe…
ngoldbaum Oct 25, 2019
d0de54c
expand comments explaining dispatch tables
ngoldbaum Oct 25, 2019
a2b046e
expand comments
ngoldbaum Oct 25, 2019
604c87c
use only one TestCase subclass
ngoldbaum Oct 25, 2019
9687422
make override tests runnable in pytest
ngoldbaum Oct 25, 2019
d6c852b
Add comments and small corrections
prasunanand Oct 28, 2019
db6ddf2
Reference to Numpy and minor edits related to review
prasunanand Oct 28, 2019
d5c9eb8
Remove check_exact, instead parse a boolean
prasunanand Oct 28, 2019
8b3741e
remove usage of getargspec from numpy
ngoldbaum Oct 29, 2019
a82499b
make it clearer that everything besides torch_function_dispatch is pr…
ngoldbaum Oct 29, 2019
5e98d3b
remove unused docs_from_dispatcher keyword for torch_function_dispatch
ngoldbaum Oct 29, 2019
f96f94a
expand docs for torch_function_dispatch decorator
ngoldbaum Oct 29, 2019
baae730
add py2/py3 compat code in test_overrides
ngoldbaum Oct 29, 2019
9102244
simplify testing somewhat
ngoldbaum Oct 29, 2019
eb25705
fix spelling and rst syntax
ngoldbaum Oct 29, 2019
7e150ac
Merge branch 'master' into torch_function
ngoldbaum Oct 29, 2019
c472d48
bring back keyword arguments for TensorLike API override tests
ngoldbaum Oct 29, 2019
e710eb2
add new keyword argument to cdist
ngoldbaum Oct 29, 2019
2b9064b
pass the function object instead of the name to __torch_function__
ngoldbaum Oct 30, 2019
bc67494
remove unused _torch_function function in torch.tensor module
ngoldbaum Oct 30, 2019
db06350
remove unused python bindings for _promote_types
ngoldbaum Oct 30, 2019
16c4122
remove unnecessary __torch_function__ checking for functions that can…
ngoldbaum Oct 30, 2019
30936bc
make torch.numel overridable
ngoldbaum Oct 30, 2019
b8b783e
simplify argument parsing logic and add explanatory comments
ngoldbaum Oct 30, 2019
92ce07a
remove breakpoint
ngoldbaum Nov 1, 2019
f1cf69d
make overloaded_args a std::vector to remove signature size limit
ngoldbaum Nov 1, 2019
2166bb1
don't initialize overloaded_args with 32 nullptr
ngoldbaum Nov 2, 2019
fd4f3d7
remove unnecessary returns
ngoldbaum Nov 2, 2019
a0faf8d
remove PythonArgs::get_overload_arg
ngoldbaum Nov 2, 2019
efe5d2e
update comment wording
ngoldbaum Nov 2, 2019
9f0a146
check if __torch_function__ returns NotImplemented and call next-high…
ngoldbaum Nov 4, 2019
52f1b5b
expand tests for __torch_function__ return semantics
ngoldbaum Nov 4, 2019
f39297c
Merge branch 'master' into torch_function
ngoldbaum Nov 4, 2019
6893a11
fix clang_tidy nit
ngoldbaum Nov 4, 2019
e096ef3
update asv benchmarks and benchmark docs
ngoldbaum Nov 4, 2019
b82cb5c
expand benchmarks
ngoldbaum Nov 5, 2019
a0aacf0
reduce code duplication in the code generation
ngoldbaum Nov 5, 2019
78b6f87
fix indentation
ngoldbaum Nov 5, 2019
bcdcf67
remove unused get_tensor_torch_function
ngoldbaum Nov 5, 2019
dc781f7
ignore torch.sparse_coo_tensor for overriding purposes
ngoldbaum Nov 5, 2019
00f3035
fix indentation
ngoldbaum Nov 5, 2019
377b6d1
explicitly check __torch_function__ doesn't return nullptr
ngoldbaum Nov 5, 2019
f10ba9f
use a range-based for loop to simplify handle_torch_function
ngoldbaum Nov 5, 2019
a4cd209
remove unused testing code
ngoldbaum Nov 6, 2019
08d6d9f
fix reference counting in handle_torch_function
ngoldbaum Nov 6, 2019
11525b1
remove unused get_torch_function
ngoldbaum Nov 6, 2019
fa37c21
move check_has_torch_function to the arg parser header
ngoldbaum Nov 6, 2019
bcb498f
fix reference leak in check_has_torch_function
ngoldbaum Nov 6, 2019
994b83e
simplify logic in PythonArgParse::parse
ngoldbaum Nov 6, 2019
5f44700
do reference counting for objects in overloaded_args vector
ngoldbaum Nov 6, 2019
c00fa4d
expand docs of new C-level helper functions
ngoldbaum Nov 6, 2019
a4651b5
revert added whitespace
ngoldbaum Nov 6, 2019
6de2423
combine THPVariable_Check and THPVariable_CheckExact to simplify logi…
ngoldbaum Nov 6, 2019
8ad8070
update tests to check for TypeError
ngoldbaum Nov 8, 2019
0892635
add doc comment for handle_torch_function
ngoldbaum Nov 8, 2019
a622a0f
refactor to use pybind11 wrappers, make handle_torch_function raise e…
ngoldbaum Nov 8, 2019
fc45687
ensure exceptions raised in user implementations are propagated
ngoldbaum Nov 8, 2019
0cf2227
cross-reference python and C++ implementations of __torch_function__ …
ngoldbaum Nov 8, 2019
7a2f055
fix deprecation warning
ngoldbaum Nov 8, 2019
e2ec5a2
refactor overloaded_args handling in parser into a helper function
ngoldbaum Nov 8, 2019
d596d56
use range-based for loops
ngoldbaum Nov 8, 2019
724d248
expand documentation for torch_override helper functions
ngoldbaum Nov 8, 2019
14b479b
fix review nits in python code
ngoldbaum Nov 8, 2019
1265134
fix compiler error
ngoldbaum Nov 8, 2019
64b1962
Merge branch 'master' into torch_function
ngoldbaum Nov 8, 2019
14b1634
use the default py::object initializer for the return value of handle…
ngoldbaum Nov 11, 2019
e28d602
use reinterpret_steal instead of reinterpret_borrow to make reference…
ngoldbaum Nov 11, 2019
00cb4e7
explicitly throw errors using python_error()
ngoldbaum Nov 11, 2019
c70c7d1
refactor tests to explicitly compare python and C++ dispatch
ngoldbaum Nov 11, 2019
8eced71
remove unnecessary comments
ngoldbaum Nov 12, 2019
98da931
reword docstring for precedence tests
ngoldbaum Nov 12, 2019
a3a98e7
add explanatory comment to overrides tests
ngoldbaum Nov 12, 2019
c55c2f1
attempt to reduce branching using templates
ngoldbaum Nov 14, 2019
9717cf7
Merge branch 'master' into torch_function
ngoldbaum Nov 14, 2019
c36da88
add bitwise_xor test
ngoldbaum Nov 14, 2019
082bd6d
use a template parameter instead of partial specialization to reduce …
ngoldbaum Nov 14, 2019
b9bfb17
use release() instead of incref() to avoid changing the reference cou…
ngoldbaum Nov 14, 2019
e2286f3
add a check for tensor types in PyTorch_LookupSpecial
ngoldbaum Nov 15, 2019
274c59d
only do exact checking on tensor operands
ngoldbaum Nov 15, 2019
ec804c4
refactor to move overloaded_args to FunctionSignature
ngoldbaum Nov 15, 2019
1a1fab6
remove unnecessary parens
ngoldbaum Nov 20, 2019
d38f2c3
add documentation for __torch_function__
ngoldbaum Nov 20, 2019
0be6167
Merge branch 'master' into torch_function
ngoldbaum Nov 20, 2019
bc6d00b
Merge remote-tracking branch 'origin/master' into torch_function
ezyang Nov 20, 2019
0aaf456
add PYBIND11_EXPORT to FunctionSignature
ngoldbaum Nov 20, 2019
ff7d3c0
update expected answers for ONNX tests
ngoldbaum Nov 20, 2019
99b0412
add hyperlinks for numpy's __array_function__ to docs
ngoldbaum Nov 20, 2019
903df1d
move __torch_function__ documentation to the end of notes/extending.rst
ngoldbaum Nov 20, 2019
7571dc8
make class definitions more copy/pasteable
ngoldbaum Nov 20, 2019
d5ebcdc
reword intro section
ngoldbaum Nov 20, 2019
ae791ca
add note that HANDLED_FUNCTIONS pattern isn't required
ngoldbaum Nov 20, 2019
488d55a
rewording
ngoldbaum Nov 20, 2019
07793ce
respond to doc comments
ngoldbaum Nov 21, 2019
403d45c
Merge remote-tracking branch 'upstream/master' into torch_function
ngoldbaum Nov 21, 2019
3e72316
Merge remote-tracking branch 'origin/master' into torch_function
ezyang Dec 3, 2019
d2d9c12
remove asv benchmarks
ngoldbaum Dec 3, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,63 @@ Please refer to each subfolder to discover each benchmark suite

* [Fast RNNs benchmarks](fastrnns/README.md)



## PyTorch ASV benchmarks

Benchmarking PyTorch with Airspeed Velocity.


Usage
-----

Airspeed Velocity manages building and Python virtualenvs or conda envs by
itself, unless told otherwise (e.g. with `--python=same`).
To run the benchmarks, you do not need to install a development version of
PyTorch to your current Python environment.
TODO: check that the isolated build feature works, so far just used
`--python=same`.

Run a benchmark against currently installed PyTorch version (don't
record the result)::

asv run --python=same

Compare change in benchmark results to another version::

TODO

Run ASV commands (record results and generate HTML)::

cd benchmarks
asv run --skip-existing-commits --steps 10 ALL
asv publish
asv preview

More on how to use ``asv`` can be found in the
[ASV documentation](https://asv.readthedocs.io).
Command-line help is available as usual via `asv --help` and `asv run --help`.



Writing benchmarks
------------------

See the ASV documentation for basics on how to write benchmarks.

Some things to consider:

- The benchmark suite should be importable with any PyTorch version.

- The benchmark parameters etc. should not depend on which PyTorch version
is installed.

- Try to keep the runtime of the benchmark reasonable.

- Prefer ASV's `time_` methods for benchmarking times rather than cooking up
time measurements via `time.clock`, even if it requires some juggling when
writing the benchmark.

- Preparing input tensors etc. should generally be put in the `setup` method
rather than the `time_` methods, to avoid counting preparation time together
with the time of the benchmarked operation.
85 changes: 85 additions & 0 deletions benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
{
// The version of the config file format. Do not change, unless
// you know what you are doing.
"version": 1,

// The name of the project being benchmarked
"project": "PyTorch",

// The project's homepage
"project_url": "https://pytorch.org/",

// The URL or local path of the source code repository for the
// project being benchmarked
"repo": "..",

// List of branches to benchmark. If not provided, defaults to "master"
// (for git) or "tip" (for mercurial).
"branches": ["master"],

// The DVCS being used. If not set, it will be automatically
// determined from "repo" by looking at the protocol in the URL
// (if remote), or by looking for special directories, such as
// ".git" (if local).
"dvcs": "git",

// The tool to use to create environments. May be "conda",
// "virtualenv" or other value depending on the plugins in use.
// If missing or the empty string, the tool will be automatically
// determined by looking for tools on the PATH environment
// variable.
"environment_type": "conda",

// the base URL to show a commit for the project.
"show_commit_url": "https://github.com/pytorch/pytorch/commit/",

// The Pythons you'd like to test against. If not provided, defaults
// to the current version of Python used to run `asv`.
"pythons": ["3.6"],

// The matrix of dependencies to test. Each key is the name of a
// package (in PyPI) and the values are version numbers. An empty
// list indicates to just test against the default (latest)
// version.
"matrix": {
"six": [],
},

// The directory (relative to the current directory) that benchmarks are
// stored in. If not provided, defaults to "benchmarks"
"benchmark_dir": "benchmarks",

// The directory (relative to the current directory) to cache the Python
// environments in. If not provided, defaults to "env"
"env_dir": "env",


// The directory (relative to the current directory) that raw benchmark
// results are stored in. If not provided, defaults to "results".
"results_dir": "results",

// The directory (relative to the current directory) that the html tree
// should be written to. If not provided, defaults to "html".
"html_dir": "html",

// The number of characters to retain in the commit hashes.
// "hash_length": 8,

// `asv` will cache wheels of the recent builds in each
// environment, making them faster to install next time. This is
// number of builds to keep, per environment.
"build_cache_size": 2,

// The commits after which the regression search in `asv publish`
// should start looking for regressions. Dictionary whose keys are
// regexps matching to benchmark names, and values corresponding to
// the commit (exclusive) after which to start looking for
// regressions. The default is to start from the first commit
// with results. If the commit is `null`, regression detection is
// skipped for the matching benchmark.
//
// "regressions_first_commits": {
// "some_benchmark": "352cdf", // Consider regressions only after this commit
// "another_benchmark": null, // Skip regression detection altogether
// }
}
1 change: 1 addition & 0 deletions benchmarks/benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import absolute_import, division, print_function
58 changes: 58 additions & 0 deletions benchmarks/benchmarks/bench_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from .common import Benchmark

import torch
from torch import Tensor

class DuckTensor(object):
def __torch_function__(self, func, args=(), kwargs=None):
pass

HANDLED_FUNCTIONS = {}

def implements(torch_function):
"Register an implementation of a torch function for a Tensor-like object."
def decorator(func):
HANDLED_FUNCTIONS[torch_function.__name__] = func
return func
return decorator

class SubTensor(Tensor):
def __torch_function__(self, func, args=(), kwargs=None):
if(kwargs is None):
kwargs = {}

if func not in HANDLED_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __torch_function__ to handle DiagonalTensor objects.
return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.add)
def add(mat1, mat2):
"Implementation of torch.mm for DiagonalTensor objects"
return 0

@implements(torch.mm)
def mm(mat1, mat2):
"Implementation of torch.mm for DiagonalTensor objects"
return 1

class TorchFunction(Benchmark):

def setup(self):
self.t1 = torch.ones(2, 2, dtype=torch.float32)
self.t2 = torch.zeros(2, 2, dtype=torch.float32)
self.t3 = SubTensor([[1, 1], [1, 1.]])
self.t4 = SubTensor([[0, 0], [0, 0.]])

def time_add(self):
torch.add(self.t1, self.t2)

def time_matmul(self):
torch.mm(self.t1, self.t2)

def time_subtensor_add(self):
torch.add(self.t3, self.t4)

def time_subtensor_multipy(self):
torch.mm(self.t3, self.t4)
18 changes: 18 additions & 0 deletions benchmarks/benchmarks/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import absolute_import, division, print_function

import random
import torch
import numpy


# Better not to use random numbers for benchmarks, but just in case,
# seed everything
random.seed(123123459)
torch.manual_seed(123123459)
numpy.random.seed(123123459)


class Benchmark(object):
# asv auto-selects number of iterations, so benchmarks runs in between
# goal_time/10 and goal_time (seconds)
goal_time = 0.25
1 change: 1 addition & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
'type_promotion',
'jit_disabled',
'function_schema',
'overrides',
]

# skip < 3.3 because mock is added in 3.3 and is used in rpc_fork and rpc_spawn
Expand Down
Loading