Skip to content

Commit aaee5c9

Browse files
ngoldbaumttumiel
authored andcommitted
__torch_function__ overrides for torch.functional and torch.nn.functional (pytorch#32799)
Summary: This adds `__torch_function__` support for all functions in `torch.functional` and `torch.nn.functional`. The changes to C++ code and codegen scripts are to facilitate adding `__torch_function__` support for the native functions in `torch._C._nn`. Note that I moved the `handle_torch_function` C++ function to a header that both `python_torch_functions.cpp` and `python_nn_functions.cpp` include. The changes to `python_nn_functions.cpp` mirror the changes I made to `python_torch_functions.cpp` when `__torch_function__` support was first added in pytorch#27064. Due to the somewhat different way the `torch._C` and `torch._C._nn` namespaces are initialized I needed to create a new static reference to the `torch._C._nn` namespace (`THPNNVariableFunctions`). I'm not sure if that is the best way to do this. In principle I could import these namespaces in each kernel and avoid the global variable but that would have a runtime cost. I added `__torch_function__` support to the Python functions in `torch.nn.functional` following the approach in pytorch#32194. I re-enabled the test that checks if all functions in the `torch` namespace are explicitly tested for `__torch_function__` support. I also generalized the check to work for `torch.functional` and `torch.nn.functional` as well. This test was explicitly disabled in pytorch#30730 and I'm happy to disable it again if you think that's appropriate. I figured now was as good a time as any to try to re-enable it. Finally I adjusted the existing torch API tests to suppress deprecation warnings and add keyword arguments used by some of the code in `torch.nn.functional` that were missed when I originally added the tests in pytorch#27064. Pull Request resolved: pytorch#32799 Differential Revision: D19956809 Pulled By: ezyang fbshipit-source-id: 40d34e0109cc4b9f3ef62f409d2d35a1d84e3d22
1 parent 9709407 commit aaee5c9

File tree

8 files changed

+769
-169
lines changed

8 files changed

+769
-169
lines changed

test/test_overrides.py

Lines changed: 182 additions & 34 deletions
Large diffs are not rendered by default.

tools/autograd/gen_python_functions.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@
8383
'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
8484
]
8585

86+
NATIVE_NAMESPACE_MAPPING = {
87+
"torch": "THPVariableFunctionsModule",
88+
"torch.nn": "THPNNVariableFunctionsModule"
89+
}
90+
8691
def should_generate_python_binding(declaration):
8792
name = declaration['name']
8893
for pattern in SKIP_PYTHON_BINDINGS:
@@ -122,7 +127,8 @@ def gen_py_variable_methods(out, declarations, template_path):
122127

123128
py_variable_methods = get_py_variable_methods(declarations)
124129

125-
env = create_python_bindings(py_variable_methods, is_python_method=True, is_module=False)
130+
env = create_python_bindings(py_variable_methods, is_python_method=True, module=None)
131+
126132
write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
127133

128134

@@ -146,7 +152,8 @@ def gen_py_nn_functions(out, declarations, template_path):
146152

147153
py_nn_functions = get_py_nn_functions(declarations)
148154

149-
env = create_python_bindings(py_nn_functions, is_python_method=False, is_module=True)
155+
env = create_python_bindings(py_nn_functions, is_python_method=False, module="torch.nn")
156+
150157
write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
151158

152159

@@ -171,7 +178,8 @@ def gen_py_torch_functions(out, declarations, template_path):
171178

172179
py_torch_functions = get_py_torch_functions(declarations)
173180

174-
env = create_python_bindings(py_torch_functions, is_python_method=False, is_module=False)
181+
env = create_python_bindings(py_torch_functions, is_python_method=False, module="torch")
182+
175183
write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
176184

177185

@@ -182,17 +190,17 @@ def group_declarations_by_op_name(declarations):
182190
return groups
183191

184192

185-
def create_python_bindings(python_functions, is_python_method, is_module):
193+
def create_python_bindings(python_functions, is_python_method, module):
186194
"""Generates Python bindings to ATen functions"""
187195
py_methods = []
188196
py_method_defs = []
189197
py_forwards = []
190198

191199
for name in sorted(python_functions.keys()):
192200
overload_decls = python_functions[name]
193-
py_methods.append(method_impl(name, overload_decls, is_python_method, is_module))
194-
py_method_defs.append(method_def(name, overload_decls, is_python_method, is_module))
195-
py_forwards.extend(forward_decls(name, overload_decls, is_python_method, is_module))
201+
py_methods.append(method_impl(name, overload_decls, is_python_method, module))
202+
py_method_defs.append(method_def(name, overload_decls, is_python_method, module))
203+
py_forwards.extend(forward_decls(name, overload_decls, is_python_method, module))
196204

197205
return {
198206
'py_forwards': py_forwards,
@@ -714,7 +722,6 @@ def get_field_name(x):
714722
return x['field_name']
715723
return [get_field_name(x) for x in returns]
716724

717-
718725
PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\
719726
static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} };
720727
""")
@@ -841,19 +848,19 @@ def is_noarg_binding(overloads):
841848
842849
""")
843850

844-
TORCH_FUNCTION_CHECK = """\
845-
if (_r.has_torch_function()) {
846-
return handle_torch_function(_r, args, kwargs, THPVariableFunctions);
851+
TORCH_FUNCTION_CHECK = CodeTemplate("""\
852+
if(_r.has_torch_function()) {
853+
return handle_torch_function(_r, args, kwargs, ${namespace}, ${modulename});
847854
}
848-
"""
855+
""")
849856

850857
# NOTE: we type the unpacked self as Tensor not Variable to avoid return type
851858
# discrepancies on method resolution (e.g. Variable::detach_ returns void
852859
# rather than Tensor &)
853860
UNPACK_SELF = "Tensor& self = reinterpret_cast<THPVariable*>(self_)->cdata;"
854861

855862

856-
def method_impl(name, declarations, is_python_method, is_module):
863+
def method_impl(name, declarations, is_python_method, module):
857864
"""
858865
Generate a python binding for all overloads of an op.
859866
"""
@@ -898,8 +905,11 @@ def method_impl(name, declarations, is_python_method, is_module):
898905
else:
899906
template = PY_VARIABLE_METHOD_VARARGS
900907

901-
if not is_module and not is_python_method:
902-
check_has_torch_function = TORCH_FUNCTION_CHECK
908+
if module:
909+
check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
910+
namespace=NATIVE_NAMESPACE_MAPPING[module],
911+
modulename='"' + module + '"',
912+
)
903913
else:
904914
check_has_torch_function = ''
905915

@@ -932,8 +942,8 @@ def method_impl(name, declarations, is_python_method, is_module):
932942
""")
933943

934944

935-
def forward_decls(name, declarations, is_python_method, is_module):
936-
if is_module or is_python_method:
945+
def forward_decls(name, declarations, is_python_method, module):
946+
if is_python_method:
937947
return []
938948

939949
if is_noarg_binding(declarations):
@@ -980,7 +990,7 @@ def forward_decls(name, declarations, is_python_method, is_module):
980990
{"${name}", (PyCFunction)${pycfunc_voidcast}${pycname}, ${flags}, NULL},""")
981991

982992

983-
def method_def(name, declarations, is_python_method, is_module):
993+
def method_def(name, declarations, is_python_method, module):
984994
"""
985995
Generate method def entry.
986996
"""
@@ -993,7 +1003,7 @@ def method_def(name, declarations, is_python_method, is_module):
9931003
pycfunc_voidcast = '(void(*)(void))'
9941004
flags = 'METH_VARARGS | METH_KEYWORDS'
9951005

996-
if not is_module and not is_python_method:
1006+
if module == "torch":
9971007
flags += ' | METH_STATIC'
9981008

9991009
if name in BINARY_OP_NAMES:

tools/autograd/templates/python_nn_functions.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,18 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObje
5353
END_HANDLE_TH_ERRORS
5454
}
5555

56-
${py_methods}
56+
// generated forward declarations start here
57+
58+
${py_forwards}
5759

5860
static PyMethodDef nn_functions[] = {
5961
{"_parse_to", (PyCFunction)(void(*)(void))THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr},
6062
${py_method_defs}
6163
{NULL}
6264
};
6365

66+
static PyObject* THPNNVariableFunctionsModule = NULL;
67+
6468
void initNNFunctions(PyObject* module) {
6569
#if PY_MAJOR_VERSION == 2
6670
PyObject* nn = Py_InitModule("torch._C._nn", nn_functions);
@@ -75,6 +79,7 @@ void initNNFunctions(PyObject* module) {
7579
};
7680
PyObject* nn = PyModule_Create(&def);
7781
#endif
82+
THPNNVariableFunctionsModule = nn;
7883
if (!nn) {
7984
throw python_error();
8085
}
@@ -84,4 +89,8 @@ void initNNFunctions(PyObject* module) {
8489
}
8590
}
8691

92+
// generated methods start here
93+
94+
${py_methods}
95+
8796
}} // namespace torch::autograd

tools/autograd/templates/python_torch_functions.cpp

Lines changed: 8 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -450,88 +450,19 @@ static PyTypeObject THPVariableFunctions = {
450450
0 /* tp_new */
451451
};
452452

453+
static PyObject* THPVariableFunctionsModule = NULL;
454+
453455
void initTorchFunctions(PyObject* module) {
454456
if (PyType_Ready(&THPVariableFunctions) < 0) {
455457
throw python_error();
456458
}
457459
Py_INCREF(&THPVariableFunctions);
458-
if (PyModule_AddObject(module, "_VariableFunctions", (PyObject*)&THPVariableFunctions) < 0) {
459-
throw python_error();
460-
}
461-
}
462-
463-
/*
464-
*
465-
* Calls __torch_function__ on the overloaded arguments to a torch API
466-
* function in order of precedence, returning the first result that is
467-
* not NotImplemented. If all arguments return NotImplemented, raises a
468-
* TypeError.
469-
*
470-
* Assumes overloaded_args has at least one entry. All entries must have
471-
* a __torch_function__ attribute that resolves to a callable that
472-
* accepts a torch API function, arguments, and keyword arguments for
473-
* the torch API function.
474-
*
475-
* It is sufficient to call PythonArgs::has_torch_function before
476-
* calling this function to verify that there are valid arguments
477-
* present. If that is not done then special care must be taken to
478-
* ensure there are arguments that are overloaded with
479-
* __torch_function__.
480-
*
481-
* See torch._overrides._implement_torch_function for the equivalent
482-
* code in the pure-python implementation.
483-
*
484-
* 'r' is a parsed PythonArgs instance, returned from
485-
* PythonArgParser::parse.
486-
*
487-
* 'args' is a reference to the python tuple of arguments to the torch
488-
* API function.
489-
*
490-
* 'kwargs' is a reference to the python dict of keyword arguments to
491-
* the torch API function.
492-
*
493-
* 'torch_api' is a reference to python torch API namespace.
494-
*
495-
*/
496-
497-
PyObject* handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyTypeObject &torch_api) {
498-
py::object torch_api_function = PyObject_FastGetAttrString((PyObject*)&torch_api, const_cast<char*>(r.get_func_name().data()));
499-
TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != NULL, "torch API function must exist");
500-
py::object ret;
501-
for (auto &arg : r.signature.overloaded_args) {
502-
py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__");
503-
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), args, kwargs, NULL));
504-
if (ret.ptr() != Py_NotImplemented) {
505-
// Return the reference to the result. This also covers the case where ret
506-
// is NULL and __torch_function__ raised an exception, which we throw below
507-
break;
508-
}
509-
}
510-
if (ret.ptr() == nullptr) {
511-
// if an exception occurred in a user's implementation of
512-
// __array_function__, throw it
513-
throw python_error();
514-
}
515-
else if (ret.ptr() == Py_NotImplemented) {
516-
// all __torch_function__ implementations in overloaded_args
517-
// returned NotImplemented, so we raise a TypeError.
518-
std::stringstream ss;
519-
ss << "no implementation found for 'torch." << r.get_func_name()
520-
<< "' on types that implement __torch_function__: [";
521-
for (auto &arg : r.signature.overloaded_args) {
522-
ss << arg.ptr()->ob_type->tp_name;
523-
if (!arg.is(r.signature.overloaded_args.back())) {
524-
ss << ", ";
525-
}
526-
else {
527-
ss << "]";
528-
}
529-
}
530-
const std::string& tmp = ss.str();
531-
PyErr_SetString(PyExc_TypeError, tmp.c_str());
460+
// PyType_GenericNew returns a new reference
461+
THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
462+
// PyModule_AddObject steals a reference
463+
if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
532464
throw python_error();
533465
}
534-
return ret.release().ptr();
535466
}
536467

537468
// generated methods start here
@@ -549,7 +480,7 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject*
549480
auto r = parser.parse(args, kwargs, parsed_args);
550481

551482
if(r.has_torch_function()){
552-
return handle_torch_function(r, args, kwargs, THPVariableFunctions);
483+
return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
553484
}
554485

555486
if (r.idx == 0) {
@@ -579,7 +510,7 @@ static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* k
579510
auto r = parser.parse(args, kwargs, parsed_args);
580511

581512
if(r.has_torch_function()){
582-
return handle_torch_function(r, args, kwargs, THPVariableFunctions);
513+
return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
583514
}
584515

585516
if (r.idx == 0) {

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,46 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
131131
}
132132
}
133133

134+
auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* {
135+
py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)r.get_func_name().c_str());
136+
TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist");
137+
py::object ret;
138+
for (auto &arg : r.signature.overloaded_args) {
139+
py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__");
140+
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), args, kwargs, NULL));
141+
if (ret.ptr() != Py_NotImplemented) {
142+
// Return the reference to the result. This also covers the case where ret
143+
// is NULL and __torch_function__ raised an exception, which we throw below
144+
break;
145+
}
146+
}
147+
if (ret.ptr() == nullptr) {
148+
// if an exception occurred in a user's implementation of
149+
// __array_function__, throw it
150+
throw python_error();
151+
}
152+
else if (ret.ptr() == Py_NotImplemented) {
153+
// all __torch_function__ implementations in overloaded_args
154+
// returned NotImplemented, so we raise a TypeError.
155+
std::stringstream ss;
156+
ss << "no implementation found for '" << module_name << "." << r.get_func_name()
157+
<< "' on types that implement __torch_function__: [";
158+
for (auto &arg : r.signature.overloaded_args) {
159+
ss << arg.ptr()->ob_type->tp_name;
160+
if (!arg.is(r.signature.overloaded_args.back())) {
161+
ss << ", ";
162+
}
163+
else {
164+
ss << "]";
165+
}
166+
}
167+
const std::string& tmp = ss.str();
168+
PyErr_SetString(PyExc_TypeError, tmp.c_str());
169+
throw python_error();
170+
}
171+
return ret.release().ptr();
172+
}
173+
134174
/*
135175
* obj has a __torch_function__ implementation and may either be a
136176
* subclass of Tensor or a Tensor-like duck type. We may need to

torch/csrc/utils/python_arg_parser.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,4 +676,47 @@ static auto check_has_torch_function(PyObject* obj) -> bool
676676
return false;
677677
}
678678

679+
/*
680+
*
681+
* Handle __torch_function__ overrides if we know that there are overloaded
682+
* arguments. All objects stored in r.overloaded_args must have a
683+
* __torch_function__ implementation and the arguments must be ordered in order
684+
* of precedence. Precedence goes from left to right in the order of the
685+
* signature of the function the overloaded arguments were passed to, except
686+
* subclasses are always considered before superclasses.
687+
*
688+
* If the result of calling __torch_function__ is NotImplemented, the
689+
* next implementation in the precedence order is called. If all
690+
* arguments return NotImplemented from their __torch_function__
691+
* implementation, a TypeError is raised in Python.
692+
*
693+
* Assumes overloaded_args has at least one entry. All entries must have
694+
* a __torch_function__ attribute that resolves to a callable that
695+
* accepts a torch API function, a tuple of arguments, and a dict of
696+
* keyword arguments for the torch API function.
697+
*
698+
* It is sufficient to call PythonArgs::has_torch_function before
699+
* calling this function to verify that there are valid arguments
700+
* present. If that is not done then special care must be taken to
701+
* ensure there are arguments that are overloaded with
702+
* __torch_function__.
703+
*
704+
* See torch._overrides.handle_torch_function for the equivalent
705+
* code in the pure-python implementation.
706+
*
707+
* 'r' is a parsed PythonArgs instance, returned from
708+
* PythonArgParser::parse.
709+
*
710+
* 'args' is a reference to the python tuple of arguments to the torch
711+
* API function.
712+
*
713+
* 'kwargs' is a reference to the python dict of keyword arguments to
714+
* the torch API function.
715+
*
716+
* 'torch_api' is a reference to a python torch API namespace.
717+
*
718+
*/
719+
720+
auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*;
721+
679722
} // namespace torch

0 commit comments

Comments
 (0)