Skip to content

Commit b02e440

Browse files
committed
Adding string constructor for enum
1 parent 64a99b9 commit b02e440

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

include/pybind11/pybind11.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1358,7 +1358,7 @@ template <typename Type> class enum_ : public class_<Type> {
13581358

13591359
template <typename... Extra>
13601360
enum_(const handle &scope, const char *name, const Extra&... extra)
1361-
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
1361+
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope), m_name(name) {
13621362

13631363
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
13641364

@@ -1377,6 +1377,15 @@ template <typename Type> class enum_ : public class_<Type> {
13771377
return m;
13781378
}, return_value_policy::copy);
13791379
def(init([](Scalar i) { return static_cast<Type>(i); }));
1380+
def(init([this, m_entries_ptr](std::string value) -> Type {
1381+
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
1382+
std::string key = cast<str>(kv.first);
1383+
if(value == key || key == m_name + "::" + value) {
1384+
return cast<Type>(kv.second);
1385+
}
1386+
}
1387+
throw value_error("\"" + value + "\" is not a valid value for enum type " + m_name);
1388+
}));
13801389
def("__int__", [](Type value) { return (Scalar) value; });
13811390
#if PY_MAJOR_VERSION < 3
13821391
def("__long__", [](Type value) { return (Scalar) value; });
@@ -1436,6 +1445,7 @@ template <typename Type> class enum_ : public class_<Type> {
14361445
private:
14371446
dict m_entries;
14381447
handle m_parent;
1448+
std::string m_name;
14391449
};
14401450

14411451
NAMESPACE_BEGIN(detail)

tests/test_enum.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ TEST_SUBMODULE(enums, m) {
2020
.value("ETwo", ETwo)
2121
.export_values();
2222

23+
// test_conversion_enum
24+
enum class ConversionEnum {
25+
Convert1 = 1,
26+
Convert2
27+
};
28+
29+
py::enum_<ConversionEnum>(m, "ConversionEnum", py::arithmetic())
30+
.value("Convert1", ConversionEnum::Convert1)
31+
.value("Convert2", ConversionEnum::Convert2)
32+
;
33+
py::implicitly_convertible<py::str, ConversionEnum>();
34+
35+
m.def("test_conversion_enum", [](ConversionEnum z) {
36+
return "ConversionEnum::" + std::string(z == ConversionEnum::Convert1 ? "Convert1" : "Convert2");
37+
});
38+
2339
// test_scoped_enum
2440
enum class ScopedEnum {
2541
Two = 2,

tests/test_enum.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def test_unscoped_enum():
2525

2626
assert int(m.UnscopedEnum.ETwo) == 2
2727
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
28+
assert str(m.UnscopedEnum("ETwo")) == "UnscopedEnum.ETwo"
2829

2930
# order
3031
assert m.UnscopedEnum.EOne < m.UnscopedEnum.ETwo
@@ -40,9 +41,17 @@ def test_unscoped_enum():
4041
assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne)
4142
assert not (2 < m.UnscopedEnum.EOne)
4243

44+
def test_converstion_enum():
45+
assert m.test_conversion_enum(m.ConversionEnum.Convert1) == "ConversionEnum::Convert1"
46+
assert m.test_conversion_enum(m.ConversionEnum("Convert1")) == "ConversionEnum::Convert1"
47+
assert m.test_conversion_enum("Convert1") == "ConversionEnum::Convert1"
48+
assert m.test_conversion_enum(m.ConversionEnum.Convert1) == "ConversionEnum::Convert1"
49+
4350

4451
def test_scoped_enum():
4552
assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three"
53+
with pytest.raises(TypeError):
54+
m.test_scoped_enum("Three")
4655
z = m.ScopedEnum.Two
4756
assert m.test_scoped_enum(z) == "ScopedEnum::Two"
4857

0 commit comments

Comments
 (0)