Skip to content

Commit 6374488

Browse files
committed
Adding string constructor for enum
1 parent 8fbb559 commit 6374488

File tree

4 files changed

+60
-0
lines changed

4 files changed

+60
-0
lines changed

docs/classes.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,21 @@ The ``name`` property returns the name of the enum value as a unicode string.
506506
>>> pet_type.name
507507
'Cat'
508508
509+
You can also access the enumeration using a string using the enum's constructor,
510+
such as ``Pet('Cat')``. This makes it possible to automatically convert a string
511+
to an enumeration in an API if the enumeration is marked implicitly convertible
512+
from a string, with a line such as:
513+
514+
.. code-block:: cpp
515+
516+
py::implicitly_convertible<std::string, Pet::Kind>();
517+
518+
Now, in Python, the following code will also correctly construct a cat:
519+
520+
.. code-block:: pycon
521+
522+
>>> p = Pet('Lucy', 'Cat')
523+
509524
.. note::
510525

511526
When the special tag ``py::arithmetic()`` is specified to the ``enum_``

include/pybind11/pybind11.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,14 @@ template <typename Type> class enum_ : public class_<Type> {
14111411
return m;
14121412
}, return_value_policy::copy);
14131413
def(init([](Scalar i) { return static_cast<Type>(i); }));
1414+
def(init([name, m_entries_ptr](std::string value) -> Type {
1415+
pybind11::dict values = reinterpret_borrow<pybind11::dict>(m_entries_ptr);
1416+
pybind11::str key = pybind11::str(value);
1417+
if (values.contains(key))
1418+
return pybind11::cast<Type>(values[key]);
1419+
else
1420+
throw value_error("\"" + value + "\" is not a valid value for enum type " + name);
1421+
}));
14141422
def("__int__", [](Type value) { return (Scalar) value; });
14151423
#if PY_MAJOR_VERSION < 3
14161424
def("__long__", [](Type value) { return (Scalar) value; });

tests/test_enum.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ TEST_SUBMODULE(enums, m) {
2020
.value("ETwo", ETwo, "Docstring for ETwo")
2121
.export_values();
2222

23+
// test_conversion_enum
24+
enum class ConversionEnum {
25+
Convert1 = 1,
26+
Convert2
27+
};
28+
29+
py::enum_<ConversionEnum>(m, "ConversionEnum", py::arithmetic())
30+
.value("Convert1", ConversionEnum::Convert1)
31+
.value("Convert2", ConversionEnum::Convert2)
32+
;
33+
py::implicitly_convertible<py::str, ConversionEnum>();
34+
35+
m.def("test_conversion_enum", [](ConversionEnum z) {
36+
return "ConversionEnum::" + std::string(z == ConversionEnum::Convert1 ? "Convert1" : "Convert2");
37+
});
38+
2339
// test_scoped_enum
2440
enum class ScopedEnum {
2541
Two = 2,

tests/test_enum.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_unscoped_enum():
5454

5555
assert int(m.UnscopedEnum.ETwo) == 2
5656
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
57+
assert str(m.UnscopedEnum("ETwo")) == "UnscopedEnum.ETwo"
5758

5859
# order
5960
assert m.UnscopedEnum.EOne < m.UnscopedEnum.ETwo
@@ -70,8 +71,28 @@ def test_unscoped_enum():
7071
assert not (2 < m.UnscopedEnum.EOne)
7172

7273

74+
def test_converstion_enum():
75+
assert m.test_conversion_enum(m.ConversionEnum.Convert1) == "ConversionEnum::Convert1"
76+
assert m.test_conversion_enum(m.ConversionEnum("Convert1")) == "ConversionEnum::Convert1"
77+
assert m.test_conversion_enum("Convert1") == "ConversionEnum::Convert1"
78+
79+
80+
def test_conversion_enum_raises():
81+
with pytest.raises(ValueError) as excinfo:
82+
m.ConversionEnum("Convert0")
83+
assert str(excinfo.value) == "\"Convert0\" is not a valid value for enum type ConversionEnum"
84+
85+
86+
def test_conversion_enum_raises_implicit():
87+
with pytest.raises(ValueError) as excinfo:
88+
m.test_conversion_enum("Convert0")
89+
assert str(excinfo.value) == "\"Convert0\" is not a valid value for enum type ConversionEnum"
90+
91+
7392
def test_scoped_enum():
7493
assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three"
94+
with pytest.raises(TypeError):
95+
m.test_scoped_enum("Three")
7596
z = m.ScopedEnum.Two
7697
assert m.test_scoped_enum(z) == "ScopedEnum::Two"
7798

0 commit comments

Comments
 (0)