Skip to content

Commit 1c1ab38

Browse files
hawkinspTensorFlow MLIR Team
authored and
TensorFlow MLIR Team
committed
Migrate StableHLO Python extension to nanobind.
I'm working towards moving the MLIR Python core code to use nanobind instead of pybind11: * llvm/llvm-project#117922, which was merged recently, allows downstream Python dialect extensions to be defined using either pybind11 or nanobind. * llvm/llvm-project#118583 is a PR in review that ports the Python core code to use nanobind instead of pybind11. This PR migrates StableHLO and related dialects to use nanobind rather than pybind11, with the goal of migrating JAX away from pybind11. PiperOrigin-RevId: 707537037
1 parent c22d98e commit 1c1ab38

File tree

6 files changed

+209
-144
lines changed

6 files changed

+209
-144
lines changed

stablehlo/stablehlo/integrations/python/CheckModule.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ limitations under the License.
1111
==============================================================================*/
1212

1313
#include "mlir-c/IR.h"
14-
#include "mlir/Bindings/Python/PybindAdaptors.h"
14+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
15+
#include "nanobind/nanobind.h"
1516
#include "stablehlo/integrations/c/CheckDialect.h"
1617

17-
namespace py = pybind11;
18+
namespace nb = nanobind;
1819

19-
PYBIND11_MODULE(_check, m) {
20+
NB_MODULE(_check, m) {
2021
m.doc() = "check main python extension";
2122

2223
//
@@ -32,5 +33,5 @@ PYBIND11_MODULE(_check, m) {
3233
mlirDialectHandleLoadDialect(dialect, context);
3334
}
3435
},
35-
py::arg("context"), py::arg("load") = true);
36+
nb::arg("context"), nb::arg("load") = true);
3637
}

stablehlo/stablehlo/integrations/python/ChloModule.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,23 @@ limitations under the License.
1212
==============================================================================*/
1313

1414
#include "mlir-c/IR.h"
15-
#include "mlir/Bindings/Python/PybindAdaptors.h"
15+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
16+
#include "nanobind/nanobind.h"
17+
#include "nanobind/stl/string_view.h"
1618
#include "stablehlo/integrations/c/ChloAttributes.h"
1719
#include "stablehlo/integrations/c/ChloDialect.h"
1820

19-
namespace py = pybind11;
21+
namespace nb = nanobind;
2022

2123
namespace {
2224

2325
auto toPyString(MlirStringRef mlirStringRef) {
24-
return py::str(mlirStringRef.data, mlirStringRef.length);
26+
return nb::str(mlirStringRef.data, mlirStringRef.length);
2527
}
2628

2729
} // namespace
2830

29-
PYBIND11_MODULE(_chlo, m) {
31+
NB_MODULE(_chlo, m) {
3032
m.doc() = "chlo main python extension";
3133

3234
//
@@ -42,35 +44,37 @@ PYBIND11_MODULE(_chlo, m) {
4244
mlirDialectHandleLoadDialect(dialect, context);
4345
}
4446
},
45-
py::arg("context"), py::arg("load") = true);
47+
nb::arg("context"), nb::arg("load") = true);
4648

4749
//
4850
// Attributes.
4951
//
5052

51-
mlir::python::adaptors::mlir_attribute_subclass(
53+
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
5254
m, "ComparisonDirectionAttr", chloAttributeIsAComparisonDirectionAttr)
5355
.def_classmethod(
5456
"get",
55-
[](py::object cls, const std::string &value, MlirContext ctx) {
57+
[](nb::object cls, std::string_view value, MlirContext ctx) {
5658
return cls(chloComparisonDirectionAttrGet(
57-
ctx, mlirStringRefCreate(value.c_str(), value.size())));
59+
ctx, mlirStringRefCreate(value.data(), value.size())));
5860
},
59-
py::arg("cls"), py::arg("value"), py::arg("context") = py::none(),
61+
nb::arg("cls"), nb::arg("value"),
62+
nb::arg("context").none() = nb::none(),
6063
"Creates a ComparisonDirection attribute with the given value.")
6164
.def_property_readonly("value", [](MlirAttribute self) {
6265
return toPyString(chloComparisonDirectionAttrGetValue(self));
6366
});
6467

65-
mlir::python::adaptors::mlir_attribute_subclass(
68+
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
6669
m, "ComparisonTypeAttr", chloAttributeIsAComparisonTypeAttr)
6770
.def_classmethod(
6871
"get",
69-
[](py::object cls, const std::string &value, MlirContext ctx) {
72+
[](nb::object cls, std::string_view value, MlirContext ctx) {
7073
return cls(chloComparisonTypeAttrGet(
71-
ctx, mlirStringRefCreate(value.c_str(), value.size())));
74+
ctx, mlirStringRefCreate(value.data(), value.size())));
7275
},
73-
py::arg("cls"), py::arg("value"), py::arg("context") = py::none(),
76+
nb::arg("cls"), nb::arg("value"),
77+
nb::arg("context").none() = nb::none(),
7478
"Creates a ComparisonType attribute with the given value.")
7579
.def_property_readonly("value", [](MlirAttribute self) {
7680
return toPyString(chloComparisonTypeAttrGetValue(self));

stablehlo/stablehlo/integrations/python/StablehloApi.cpp

Lines changed: 84 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,22 @@ limitations under the License.
1515

1616
#include "stablehlo/integrations/python/StablehloApi.h"
1717

18+
#include <stdexcept>
1819
#include <string>
1920
#include <string_view>
2021

2122
#include "llvm/Support/raw_ostream.h"
2223
#include "mlir-c/BuiltinAttributes.h"
2324
#include "mlir-c/IR.h"
2425
#include "mlir-c/Support.h"
25-
#include "mlir/Bindings/Python/PybindAdaptors.h"
26+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
27+
#include "nanobind/nanobind.h"
28+
#include "nanobind/stl/string.h"
29+
#include "nanobind/stl/string_view.h"
30+
#include "nanobind/stl/vector.h"
2631
#include "stablehlo/integrations/c/StablehloApi.h"
2732

28-
namespace py = pybind11;
33+
namespace nb = nanobind;
2934

3035
namespace mlir {
3136
namespace stablehlo {
@@ -63,14 +68,18 @@ static MlirStringRef toMlirStringRef(std::string_view s) {
6368
return mlirStringRefCreate(s.data(), s.size());
6469
}
6570

66-
void AddStablehloApi(py::module &m) {
71+
static MlirStringRef toMlirStringRef(const nb::bytes &s) {
72+
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
73+
}
74+
75+
void AddStablehloApi(nb::module_ &m) {
6776
// Portable API is a subset of StableHLO API
6877
AddPortableApi(m);
6978

7079
//
7180
// Utility APIs.
7281
//
73-
py::enum_<MlirStablehloCompatibilityRequirement>(
82+
nb::enum_<MlirStablehloCompatibilityRequirement>(
7483
m, "StablehloCompatibilityRequirement")
7584
.value("NONE", MlirStablehloCompatibilityRequirement::NONE)
7685
.value("WEEK_4", MlirStablehloCompatibilityRequirement::WEEK_4)
@@ -79,48 +88,57 @@ void AddStablehloApi(py::module &m) {
7988

8089
m.def(
8190
"get_version_from_compatibility_requirement",
82-
[](MlirStablehloCompatibilityRequirement requirement) -> py::str {
91+
[](MlirStablehloCompatibilityRequirement requirement) -> std::string {
8392
StringWriterHelper accumulator;
8493
stablehloVersionFromCompatibilityRequirement(
8594
requirement, accumulator.getMlirStringCallback(),
8695
accumulator.getUserData());
8796
return accumulator.toString();
8897
},
89-
py::arg("requirement"));
98+
nb::arg("requirement"));
9099

91100
//
92101
// Serialization APIs.
93102
//
94103
m.def(
95104
"serialize_portable_artifact",
96-
[](MlirModule module, std::string_view target) -> py::bytes {
105+
[](MlirModule module, std::string_view target) -> nb::bytes {
97106
StringWriterHelper accumulator;
98107
if (mlirLogicalResultIsFailure(
99108
stablehloSerializePortableArtifactFromModule(
100109
module, toMlirStringRef(target),
101110
accumulator.getMlirStringCallback(),
102111
accumulator.getUserData()))) {
103-
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
104-
return "";
112+
throw nb::value_error("failed to serialize module");
105113
}
106114

107-
return py::bytes(accumulator.toString());
115+
std::string serialized = accumulator.toString();
116+
return nb::bytes(serialized.data(), serialized.size());
108117
},
109-
py::arg("module"), py::arg("target"));
118+
nb::arg("module"), nb::arg("target"));
110119

111120
m.def(
112121
"deserialize_portable_artifact",
113122
[](MlirContext context, std::string_view artifact) -> MlirModule {
114123
auto module = stablehloDeserializePortableArtifactNoError(
115124
toMlirStringRef(artifact), context);
116125
if (mlirModuleIsNull(module)) {
117-
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
118-
return {};
126+
throw nb::value_error("failed to deserialize module");
119127
}
120128
return module;
121129
},
122-
py::arg("context"), py::arg("artifact"));
123-
130+
nb::arg("context"), nb::arg("artifact"));
131+
m.def(
132+
"deserialize_portable_artifact",
133+
[](MlirContext context, nb::bytes artifact) -> MlirModule {
134+
auto module = stablehloDeserializePortableArtifactNoError(
135+
toMlirStringRef(artifact), context);
136+
if (mlirModuleIsNull(module)) {
137+
throw nb::value_error("failed to deserialize module");
138+
}
139+
return module;
140+
},
141+
nb::arg("context"), nb::arg("artifact"));
124142
//
125143
// Reference APIs
126144
//
@@ -130,9 +148,7 @@ void AddStablehloApi(py::module &m) {
130148
std::vector<MlirAttribute> &args) -> std::vector<MlirAttribute> {
131149
for (auto arg : args) {
132150
if (!mlirAttributeIsADenseElements(arg)) {
133-
PyErr_SetString(PyExc_ValueError,
134-
"input args must be DenseElementsAttr");
135-
return {};
151+
throw nb::value_error("input args must be DenseElementsAttr");
136152
}
137153
}
138154

@@ -141,8 +157,7 @@ void AddStablehloApi(py::module &m) {
141157
stablehloEvalModule(module, args.size(), args.data(), &errorCode);
142158

143159
if (errorCode != 0) {
144-
PyErr_SetString(PyExc_ValueError, "interpreter failed");
145-
return {};
160+
throw nb::value_error("interpreter failed");
146161
}
147162

148163
std::vector<MlirAttribute> pyResults;
@@ -151,39 +166,39 @@ void AddStablehloApi(py::module &m) {
151166
}
152167
return pyResults;
153168
},
154-
py::arg("module"), py::arg("args"));
169+
nb::arg("module"), nb::arg("args"));
155170
}
156171

157-
void AddPortableApi(py::module &m) {
172+
void AddPortableApi(nb::module_ &m) {
158173
//
159174
// Utility APIs.
160175
//
161176
m.def("get_api_version", []() { return stablehloGetApiVersion(); });
162177

163178
m.def(
164179
"get_smaller_version",
165-
[](const std::string &version1, const std::string &version2) -> py::str {
180+
[](const std::string &version1,
181+
const std::string &version2) -> std::string {
166182
StringWriterHelper accumulator;
167183
if (mlirLogicalResultIsFailure(stablehloGetSmallerVersion(
168184
toMlirStringRef(version1), toMlirStringRef(version2),
169185
accumulator.getMlirStringCallback(),
170186
accumulator.getUserData()))) {
171-
PyErr_SetString(PyExc_ValueError,
172-
"failed to convert version to stablehlo version");
173-
return "";
187+
throw nb::value_error(
188+
"failed to convert version to stablehlo version");
174189
}
175190
return accumulator.toString();
176191
},
177-
py::arg("version1"), py::arg("version2"));
192+
nb::arg("version1"), nb::arg("version2"));
178193

179-
m.def("get_current_version", []() -> py::str {
194+
m.def("get_current_version", []() -> std::string {
180195
StringWriterHelper accumulator;
181196
stablehloGetCurrentVersion(accumulator.getMlirStringCallback(),
182197
accumulator.getUserData());
183198
return accumulator.toString();
184199
});
185200

186-
m.def("get_minimum_version", []() -> py::str {
201+
m.def("get_minimum_version", []() -> std::string {
187202
StringWriterHelper accumulator;
188203
stablehloGetMinimumVersion(accumulator.getMlirStringCallback(),
189204
accumulator.getUserData());
@@ -196,34 +211,64 @@ void AddPortableApi(py::module &m) {
196211
m.def(
197212
"serialize_portable_artifact_str",
198213
[](std::string_view moduleStrOrBytecode,
199-
std::string_view targetVersion) -> py::bytes {
214+
std::string_view targetVersion) -> nb::bytes {
215+
StringWriterHelper accumulator;
216+
if (mlirLogicalResultIsFailure(
217+
stablehloSerializePortableArtifactFromStringRef(
218+
toMlirStringRef(moduleStrOrBytecode),
219+
toMlirStringRef(targetVersion),
220+
accumulator.getMlirStringCallback(),
221+
accumulator.getUserData()))) {
222+
throw nb::value_error("failed to serialize module");
223+
}
224+
std::string serialized = accumulator.toString();
225+
return nb::bytes(serialized.data(), serialized.size());
226+
},
227+
nb::arg("module_str"), nb::arg("target_version"));
228+
m.def(
229+
"serialize_portable_artifact_str",
230+
[](nb::bytes moduleStrOrBytecode,
231+
std::string_view targetVersion) -> nb::bytes {
200232
StringWriterHelper accumulator;
201233
if (mlirLogicalResultIsFailure(
202234
stablehloSerializePortableArtifactFromStringRef(
203235
toMlirStringRef(moduleStrOrBytecode),
204236
toMlirStringRef(targetVersion),
205237
accumulator.getMlirStringCallback(),
206238
accumulator.getUserData()))) {
207-
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
208-
return "";
239+
throw nb::value_error("failed to serialize module");
209240
}
210-
return py::bytes(accumulator.toString());
241+
std::string serialized = accumulator.toString();
242+
return nb::bytes(serialized.data(), serialized.size());
211243
},
212-
py::arg("module_str"), py::arg("target_version"));
244+
nb::arg("module_str"), nb::arg("target_version"));
213245

214246
m.def(
215247
"deserialize_portable_artifact_str",
216-
[](std::string_view artifact) -> py::bytes {
248+
[](std::string_view artifact) -> nb::bytes {
249+
StringWriterHelper accumulator;
250+
if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
251+
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
252+
accumulator.getUserData()))) {
253+
throw nb::value_error("failed to deserialize module");
254+
}
255+
std::string serialized = accumulator.toString();
256+
return nb::bytes(serialized.data(), serialized.size());
257+
},
258+
nb::arg("artifact_str"));
259+
m.def(
260+
"deserialize_portable_artifact_str",
261+
[](const nb::bytes& artifact) -> nb::bytes {
217262
StringWriterHelper accumulator;
218263
if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
219264
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
220265
accumulator.getUserData()))) {
221-
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
222-
return "";
266+
throw nb::value_error("failed to deserialize module");
223267
}
224-
return py::bytes(accumulator.toString());
268+
std::string serialized = accumulator.toString();
269+
return nb::bytes(serialized.data(), serialized.size());
225270
},
226-
py::arg("artifact_str"));
271+
nb::arg("artifact_str"));
227272
}
228273

229274
} // namespace stablehlo

stablehlo/stablehlo/integrations/python/StablehloApi.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,20 @@ limitations under the License.
1616
#ifndef STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H
1717
#define STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H
1818

19-
#include "pybind11/pybind11.h"
19+
#include "nanobind/nanobind.h"
2020

2121
namespace mlir {
2222
namespace stablehlo {
2323

24-
// Add StableHLO APIs to the pybind11 module.
24+
// Add StableHLO APIs to the nanobind module.
2525
// Signatures of these APIs have no dependency on C++ MLIR types and all must
2626
// use C API passthrough.
27-
void AddStablehloApi(pybind11::module& m);
27+
void AddStablehloApi(nanobind::module_& m);
2828

2929
// Adds a subset of the StableHLO API that doesn't use MLIR in any definitions,
3030
// and is methods only, introducing no new objects / enums to avoid potential
3131
// redefinition issues in complex build environments.
32-
void AddPortableApi(pybind11::module& m);
32+
void AddPortableApi(nanobind::module_& m);
3333

3434
} // namespace stablehlo
3535
} // namespace mlir

0 commit comments

Comments
 (0)