From a0da2bc01e3d5698e0ed1daf727f4c800c9f5d2c Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Sun, 26 Feb 2017 18:03:00 -0500 Subject: [PATCH] array_t overload resolution support This makes array_t respect overload resolution and noconvert by failing to load when `convert = false` if the src isn't already an array of the correct type. --- include/pybind11/numpy.h | 4 ++- tests/test_numpy_array.cpp | 28 +++++++++++++++++++ tests/test_numpy_array.py | 55 +++++++++++++++++++++++++++++++++++++- 3 files changed, 85 insertions(+), 2 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index ea9914a48a..74f58d1a5e 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -688,7 +688,9 @@ template struct pyobject_caster> { using type = array_t; - bool load(handle src, bool /* convert */) { + bool load(handle src, bool convert) { + if (!convert && !type::check_(src)) + return false; value = type::ensure(src); return static_cast(value); } diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 23da916951..58a20524b2 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -151,6 +151,34 @@ test_initializer numpy_array([](py::module &m) { ); }); + // Overload resolution tests: + sm.def("overloaded", [](py::array_t) { return "double"; }); + sm.def("overloaded", [](py::array_t) { return "float"; }); + sm.def("overloaded", [](py::array_t) { return "int"; }); + sm.def("overloaded", [](py::array_t) { return "unsigned short"; }); + sm.def("overloaded", [](py::array_t) { return "long long"; }); + sm.def("overloaded", [](py::array_t>) { return "double complex"; }); + sm.def("overloaded", [](py::array_t>) { return "float complex"; }); + + sm.def("overloaded2", [](py::array_t>) { return "double complex"; }); + sm.def("overloaded2", [](py::array_t) { return "double"; }); + sm.def("overloaded2", [](py::array_t>) { return "float complex"; }); + sm.def("overloaded2", [](py::array_t) { return "float"; }); + + // Only accept the exact types: + sm.def("overloaded3", [](py::array_t) { return "int"; }, py::arg().noconvert()); + sm.def("overloaded3", [](py::array_t) { return "double"; }, py::arg().noconvert()); + + // Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but + // rather that float gets converted via the safe (conversion to double) overload: + sm.def("overloaded4", [](py::array_t) { return "long long"; }); + sm.def("overloaded4", [](py::array_t) { return "double"; }); + + // But we do allow conversion to int if forcecast is enabled (but only if no overload matches + // without conversion) + sm.def("overloaded5", [](py::array_t) { return "unsigned int"; }); + sm.def("overloaded5", [](py::array_t) { return "double"; }); + // Issue 685: ndarray shouldn't go to std::string overload sm.def("issue685", [](std::string) { return "string"; }); sm.def("issue685", [](py::array) { return "array"; }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 365f4e375e..b58aa1b054 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -264,7 +264,60 @@ def test_constructors(): assert results["array_t"].dtype == np.float64 -@pytest.requires_numpy +def test_overload_resolution(msg): + from pybind11_tests.array import overloaded, overloaded2, overloaded3, overloaded4, overloaded5 + + # Exact overload matches: + assert overloaded(np.array([1], dtype='float64')) == 'double' + assert overloaded(np.array([1], dtype='float32')) == 'float' + assert overloaded(np.array([1], dtype='ushort')) == 'unsigned short' + assert overloaded(np.array([1], dtype='intc')) == 'int' + assert overloaded(np.array([1], dtype='longlong')) == 'long long' + assert overloaded(np.array([1], dtype='complex')) == 'double complex' + assert overloaded(np.array([1], dtype='csingle')) == 'float complex' + + # No exact match, should call first convertible version: + assert overloaded(np.array([1], dtype='uint8')) == 'double' + + assert overloaded2(np.array([1], dtype='float64')) == 'double' + assert overloaded2(np.array([1], dtype='float32')) == 'float' + assert overloaded2(np.array([1], dtype='complex64')) == 'float complex' + assert overloaded2(np.array([1], dtype='complex128')) == 'double complex' + assert overloaded2(np.array([1], dtype='float32')) == 'float' + + assert overloaded3(np.array([1], dtype='float64')) == 'double' + assert overloaded3(np.array([1], dtype='intc')) == 'int' + expected_exc = """ + overloaded3(): incompatible function arguments. The following argument types are supported: + 1. (arg0: numpy.ndarray[int]) -> str + 2. (arg0: numpy.ndarray[float]) -> str + + Invoked with:""" + + with pytest.raises(TypeError) as excinfo: + overloaded3(np.array([1], dtype='uintc')) + assert msg(excinfo.value) == expected_exc + " array([1], dtype=uint32)" + with pytest.raises(TypeError) as excinfo: + overloaded3(np.array([1], dtype='float32')) + assert msg(excinfo.value) == expected_exc + " array([ 1.], dtype=float32)" + with pytest.raises(TypeError) as excinfo: + overloaded3(np.array([1], dtype='complex')) + assert msg(excinfo.value) == expected_exc + " array([ 1.+0.j])" + + # Exact matches: + assert overloaded4(np.array([1], dtype='double')) == 'double' + assert overloaded4(np.array([1], dtype='longlong')) == 'long long' + # Non-exact matches requiring conversion. Since float to integer isn't a + # save conversion, it should go to the double overload, but short can go to + # either (and so should end up on the first-registered, the long long). + assert overloaded4(np.array([1], dtype='float32')) == 'double' + assert overloaded4(np.array([1], dtype='short')) == 'long long' + + assert overloaded5(np.array([1], dtype='double')) == 'double' + assert overloaded5(np.array([1], dtype='uintc')) == 'unsigned int' + assert overloaded5(np.array([1], dtype='float32')) == 'unsigned int' + + def test_greedy_string_overload(): # issue 685 from pybind11_tests.array import issue685