Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions misc/test-stubgenc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,28 @@ EXIT=0
function stubgenc_test() {
# Remove expected stubs and generate new inplace
STUBGEN_OUTPUT_FOLDER=./test-data/pybind11_mypy_demo/$1
rm -rf "${STUBGEN_OUTPUT_FOLDER:?}/*"
rm -rf "${STUBGEN_OUTPUT_FOLDER:?}"

stubgen -o "$STUBGEN_OUTPUT_FOLDER" "${@:2}"

# Check if generated stubs can actually be type checked by mypy
if ! mypy "$STUBGEN_OUTPUT_FOLDER";
then
echo "Stubgen test failed, because generated stubs failed to type check."
EXIT=1
fi

# Compare generated stubs to expected ones
if ! git diff --exit-code "$STUBGEN_OUTPUT_FOLDER";
then
echo "Stubgen test failed, because generated stubs differ from expected outputs."
EXIT=1
fi
}

# create stubs without docstrings
stubgenc_test stubgen -p pybind11_mypy_demo
stubgenc_test expected_stubs_no_docs -p pybind11_mypy_demo
# create stubs with docstrings
stubgenc_test stubgen-include-docs -p pybind11_mypy_demo --include-docstrings
stubgenc_test expected_stubs_with_docs -p pybind11_mypy_demo --include-docstrings

exit $EXIT
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
from . import demo as demo
from typing import List, Optional, Tuple

class TestStruct:
field_readwrite: int
field_readwrite_docstring: int
def __init__(self, *args, **kwargs) -> None: ...
@property
def field_readonly(self) -> int: ...

def func_incomplete_signature(*args, **kwargs): ...
def func_returning_optional() -> Optional[int]: ...
def func_returning_pair() -> Tuple[int, float]: ...
def func_returning_path() -> os.PathLike: ...
def func_returning_vector() -> List[float]: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from . import demo as demo
from typing import List, Optional, Tuple

class TestStruct:
field_readwrite: int
field_readwrite_docstring: int
def __init__(self, *args, **kwargs) -> None:
"""Initialize self. See help(type(self)) for accurate signature."""
@property
def field_readonly(self) -> int: ...

def func_incomplete_signature(*args, **kwargs):
"""func_incomplete_signature() -> dummy_sub_namespace::HasNoBinding"""
def func_returning_optional() -> Optional[int]:
"""func_returning_optional() -> Optional[int]"""
def func_returning_pair() -> Tuple[int, float]:
"""func_returning_pair() -> Tuple[int, float]"""
def func_returning_path() -> os.PathLike:
"""func_returning_path() -> os.PathLike"""
def func_returning_vector() -> List[float]:
"""func_returning_vector() -> List[float]"""
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ class Point:
degree: ClassVar[Point.AngleUnit] = ...
radian: ClassVar[Point.AngleUnit] = ...
def __init__(self, value: int) -> None:
"""__init__(self: pybind11_mypy_demo.basics.Point.AngleUnit, value: int) -> None"""
"""__init__(self: pybind11_mypy_demo.demo.Point.AngleUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
"""__index__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
"""__index__(self: pybind11_mypy_demo.demo.Point.AngleUnit) -> int"""
def __int__(self) -> int:
"""__int__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
"""__int__(self: pybind11_mypy_demo.demo.Point.AngleUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
@property
Expand All @@ -33,15 +33,15 @@ class Point:
mm: ClassVar[Point.LengthUnit] = ...
pixel: ClassVar[Point.LengthUnit] = ...
def __init__(self, value: int) -> None:
"""__init__(self: pybind11_mypy_demo.basics.Point.LengthUnit, value: int) -> None"""
"""__init__(self: pybind11_mypy_demo.demo.Point.LengthUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
"""__index__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
"""__index__(self: pybind11_mypy_demo.demo.Point.LengthUnit) -> int"""
def __int__(self) -> int:
"""__int__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
"""__int__(self: pybind11_mypy_demo.demo.Point.LengthUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
@property
Expand All @@ -60,38 +60,38 @@ class Point:
"""__init__(*args, **kwargs)
Overloaded function.

1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
1. __init__(self: pybind11_mypy_demo.demo.Point) -> None

2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
2. __init__(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> None
"""
@overload
def __init__(self, x: float, y: float) -> None:
"""__init__(*args, **kwargs)
Overloaded function.

1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
1. __init__(self: pybind11_mypy_demo.demo.Point) -> None

2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
2. __init__(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> None
"""
def as_list(self) -> List[float]:
"""as_list(self: pybind11_mypy_demo.basics.Point) -> List[float]"""
"""as_list(self: pybind11_mypy_demo.demo.Point) -> List[float]"""
@overload
def distance_to(self, x: float, y: float) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.

1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
1. distance_to(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> float

2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
2. distance_to(self: pybind11_mypy_demo.demo.Point, other: pybind11_mypy_demo.demo.Point) -> float
"""
@overload
def distance_to(self, other: Point) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.

1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
1. distance_to(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> float

2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
2. distance_to(self: pybind11_mypy_demo.demo.Point, other: pybind11_mypy_demo.demo.Point) -> float
"""
@property
def length(self) -> float: ...
Expand Down
107 changes: 94 additions & 13 deletions test-data/pybind11_mypy_demo/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,85 @@
*/

#include <cmath>
#include <filesystem>
#include <optional>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl/filesystem.h>

namespace py = pybind11;

namespace basics {
// ----------------------------------------------------------------------------
// Dedicated test cases
// ----------------------------------------------------------------------------

std::vector<float> funcReturningVector()
{
return std::vector<float>{1.0, 2.0, 3.0};
}

std::pair<int, float> funcReturningPair()
{
return std::pair{42, 1.0};
}

std::optional<int> funcReturningOptional()
{
return std::nullopt;
}

std::filesystem::path funcReturningPath()
{
return std::filesystem::path{"foobar"};
}

namespace dummy_sub_namespace {
struct HasNoBinding{};
}

// We can enforce the case of an incomplete signature by referring to a type in
// some namespace that doesn't have a pybind11 binding.
dummy_sub_namespace::HasNoBinding funcIncompleteSignature()
{
return dummy_sub_namespace::HasNoBinding{};
}

struct TestStruct
{
int field_readwrite;
int field_readwrite_docstring;
int field_readonly;
};

// Bindings

void bind_test_cases(py::module& m) {
m.def("func_returning_vector", &funcReturningVector);
m.def("func_returning_pair", &funcReturningPair);
m.def("func_returning_optional", &funcReturningOptional);
m.def("func_returning_path", &funcReturningPath);

m.def("func_incomplete_signature", &funcIncompleteSignature);

py::class_<TestStruct>(m, "TestStruct")
.def_readwrite("field_readwrite", &TestStruct::field_readwrite)
.def_readwrite("field_readwrite_docstring", &TestStruct::field_readwrite_docstring, "some docstring")
.def_property_readonly(
"field_readonly",
[](const TestStruct& x) {
return x.field_readonly;
},
"some docstring");
}

// ----------------------------------------------------------------------------
// Original demo
// ----------------------------------------------------------------------------

namespace demo {

int answer() {
return 42;
Expand Down Expand Up @@ -118,20 +191,22 @@ const Point Point::y_axis = Point(0, 1);
Point::LengthUnit Point::length_unit = Point::LengthUnit::mm;
Point::AngleUnit Point::angle_unit = Point::AngleUnit::radian;

} // namespace: basics
} // namespace: demo

void bind_basics(py::module& basics) {
// Bindings

using namespace basics;
void bind_demo(py::module& m) {

using namespace demo;

// Functions
basics.def("answer", &answer, "answer docstring, with end quote\""); // tests explicit docstrings
basics.def("sum", &sum, "multiline docstring test, edge case quotes \"\"\"'''");
basics.def("midpoint", &midpoint, py::arg("left"), py::arg("right"));
basics.def("weighted_midpoint", weighted_midpoint, py::arg("left"), py::arg("right"), py::arg("alpha")=0.5);
m.def("answer", &answer, "answer docstring, with end quote\""); // tests explicit docstrings
m.def("sum", &sum, "multiline docstring test, edge case quotes \"\"\"'''");
m.def("midpoint", &midpoint, py::arg("left"), py::arg("right"));
m.def("weighted_midpoint", weighted_midpoint, py::arg("left"), py::arg("right"), py::arg("alpha")=0.5);

// Classes
py::class_<Point> pyPoint(basics, "Point");
py::class_<Point> pyPoint(m, "Point");
py::enum_<Point::LengthUnit> pyLengthUnit(pyPoint, "LengthUnit");
py::enum_<Point::AngleUnit> pyAngleUnit(pyPoint, "AngleUnit");

Expand Down Expand Up @@ -167,11 +242,17 @@ void bind_basics(py::module& basics) {
.value("degree", Point::AngleUnit::degree);

// Module-level attributes
basics.attr("PI") = std::acos(-1);
basics.attr("__version__") = "0.0.1";
m.attr("PI") = std::acos(-1);
m.attr("__version__") = "0.0.1";
}

// ----------------------------------------------------------------------------
// Module entry point
// ----------------------------------------------------------------------------

PYBIND11_MODULE(pybind11_mypy_demo, m) {
auto basics = m.def_submodule("basics");
bind_basics(basics);
bind_test_cases(m);

auto demo = m.def_submodule("demo");
bind_demo(demo);
}

This file was deleted.

This file was deleted.