@@ -15,17 +15,22 @@ limitations under the License.
15
15
16
16
#include " stablehlo/integrations/python/StablehloApi.h"
17
17
18
+ #include < stdexcept>
18
19
#include < string>
19
20
#include < string_view>
20
21
21
22
#include " llvm/Support/raw_ostream.h"
22
23
#include " mlir-c/BuiltinAttributes.h"
23
24
#include " mlir-c/IR.h"
24
25
#include " mlir-c/Support.h"
25
- #include " mlir/Bindings/Python/PybindAdaptors.h"
26
+ #include " mlir/Bindings/Python/NanobindAdaptors.h"
27
+ #include " nanobind/nanobind.h"
28
+ #include " nanobind/stl/string.h"
29
+ #include " nanobind/stl/string_view.h"
30
+ #include " nanobind/stl/vector.h"
26
31
#include " stablehlo/integrations/c/StablehloApi.h"
27
32
28
- namespace py = pybind11 ;
33
+ namespace nb = nanobind ;
29
34
30
35
namespace mlir {
31
36
namespace stablehlo {
@@ -63,14 +68,18 @@ static MlirStringRef toMlirStringRef(std::string_view s) {
63
68
return mlirStringRefCreate (s.data (), s.size ());
64
69
}
65
70
66
- void AddStablehloApi (py::module &m) {
71
+ static MlirStringRef toMlirStringRef (const nb::bytes &s) {
72
+ return mlirStringRefCreate (static_cast <const char *>(s.data ()), s.size ());
73
+ }
74
+
75
+ void AddStablehloApi (nb::module_ &m) {
67
76
// Portable API is a subset of StableHLO API
68
77
AddPortableApi (m);
69
78
70
79
//
71
80
// Utility APIs.
72
81
//
73
- py ::enum_<MlirStablehloCompatibilityRequirement>(
82
+ nb ::enum_<MlirStablehloCompatibilityRequirement>(
74
83
m, " StablehloCompatibilityRequirement" )
75
84
.value (" NONE" , MlirStablehloCompatibilityRequirement::NONE)
76
85
.value (" WEEK_4" , MlirStablehloCompatibilityRequirement::WEEK_4)
@@ -79,48 +88,57 @@ void AddStablehloApi(py::module &m) {
79
88
80
89
m.def (
81
90
" get_version_from_compatibility_requirement" ,
82
- [](MlirStablehloCompatibilityRequirement requirement) -> py::str {
91
+ [](MlirStablehloCompatibilityRequirement requirement) -> std::string {
83
92
StringWriterHelper accumulator;
84
93
stablehloVersionFromCompatibilityRequirement (
85
94
requirement, accumulator.getMlirStringCallback (),
86
95
accumulator.getUserData ());
87
96
return accumulator.toString ();
88
97
},
89
- py ::arg (" requirement" ));
98
+ nb ::arg (" requirement" ));
90
99
91
100
//
92
101
// Serialization APIs.
93
102
//
94
103
m.def (
95
104
" serialize_portable_artifact" ,
96
- [](MlirModule module, std::string_view target) -> py ::bytes {
105
+ [](MlirModule module, std::string_view target) -> nb ::bytes {
97
106
StringWriterHelper accumulator;
98
107
if (mlirLogicalResultIsFailure (
99
108
stablehloSerializePortableArtifactFromModule (
100
109
module, toMlirStringRef (target),
101
110
accumulator.getMlirStringCallback (),
102
111
accumulator.getUserData ()))) {
103
- PyErr_SetString (PyExc_ValueError, " failed to serialize module" );
104
- return " " ;
112
+ throw nb::value_error (" failed to serialize module" );
105
113
}
106
114
107
- return py::bytes (accumulator.toString ());
115
+ std::string serialized = accumulator.toString ();
116
+ return nb::bytes (serialized.data (), serialized.size ());
108
117
},
109
- py ::arg (" module" ), py ::arg (" target" ));
118
+ nb ::arg (" module" ), nb ::arg (" target" ));
110
119
111
120
m.def (
112
121
" deserialize_portable_artifact" ,
113
122
[](MlirContext context, std::string_view artifact) -> MlirModule {
114
123
auto module = stablehloDeserializePortableArtifactNoError (
115
124
toMlirStringRef (artifact), context);
116
125
if (mlirModuleIsNull (module)) {
117
- PyErr_SetString (PyExc_ValueError, " failed to deserialize module" );
118
- return {};
126
+ throw nb::value_error (" failed to deserialize module" );
119
127
}
120
128
return module;
121
129
},
122
- py::arg (" context" ), py::arg (" artifact" ));
123
-
130
+ nb::arg (" context" ), nb::arg (" artifact" ));
131
+ m.def (
132
+ " deserialize_portable_artifact" ,
133
+ [](MlirContext context, nb::bytes artifact) -> MlirModule {
134
+ auto module = stablehloDeserializePortableArtifactNoError (
135
+ toMlirStringRef (artifact), context);
136
+ if (mlirModuleIsNull (module)) {
137
+ throw nb::value_error (" failed to deserialize module" );
138
+ }
139
+ return module;
140
+ },
141
+ nb::arg (" context" ), nb::arg (" artifact" ));
124
142
//
125
143
// Reference APIs
126
144
//
@@ -130,9 +148,7 @@ void AddStablehloApi(py::module &m) {
130
148
std::vector<MlirAttribute> &args) -> std::vector<MlirAttribute> {
131
149
for (auto arg : args) {
132
150
if (!mlirAttributeIsADenseElements (arg)) {
133
- PyErr_SetString (PyExc_ValueError,
134
- " input args must be DenseElementsAttr" );
135
- return {};
151
+ throw nb::value_error (" input args must be DenseElementsAttr" );
136
152
}
137
153
}
138
154
@@ -141,8 +157,7 @@ void AddStablehloApi(py::module &m) {
141
157
stablehloEvalModule (module, args.size (), args.data (), &errorCode);
142
158
143
159
if (errorCode != 0 ) {
144
- PyErr_SetString (PyExc_ValueError, " interpreter failed" );
145
- return {};
160
+ throw nb::value_error (" interpreter failed" );
146
161
}
147
162
148
163
std::vector<MlirAttribute> pyResults;
@@ -151,39 +166,39 @@ void AddStablehloApi(py::module &m) {
151
166
}
152
167
return pyResults;
153
168
},
154
- py ::arg (" module" ), py ::arg (" args" ));
169
+ nb ::arg (" module" ), nb ::arg (" args" ));
155
170
}
156
171
157
- void AddPortableApi (py::module &m) {
172
+ void AddPortableApi (nb::module_ &m) {
158
173
//
159
174
// Utility APIs.
160
175
//
161
176
m.def (" get_api_version" , []() { return stablehloGetApiVersion (); });
162
177
163
178
m.def (
164
179
" get_smaller_version" ,
165
- [](const std::string &version1, const std::string &version2) -> py::str {
180
+ [](const std::string &version1,
181
+ const std::string &version2) -> std::string {
166
182
StringWriterHelper accumulator;
167
183
if (mlirLogicalResultIsFailure (stablehloGetSmallerVersion (
168
184
toMlirStringRef (version1), toMlirStringRef (version2),
169
185
accumulator.getMlirStringCallback (),
170
186
accumulator.getUserData ()))) {
171
- PyErr_SetString (PyExc_ValueError,
172
- " failed to convert version to stablehlo version" );
173
- return " " ;
187
+ throw nb::value_error (
188
+ " failed to convert version to stablehlo version" );
174
189
}
175
190
return accumulator.toString ();
176
191
},
177
- py ::arg (" version1" ), py ::arg (" version2" ));
192
+ nb ::arg (" version1" ), nb ::arg (" version2" ));
178
193
179
- m.def (" get_current_version" , []() -> py::str {
194
+ m.def (" get_current_version" , []() -> std::string {
180
195
StringWriterHelper accumulator;
181
196
stablehloGetCurrentVersion (accumulator.getMlirStringCallback (),
182
197
accumulator.getUserData ());
183
198
return accumulator.toString ();
184
199
});
185
200
186
- m.def (" get_minimum_version" , []() -> py::str {
201
+ m.def (" get_minimum_version" , []() -> std::string {
187
202
StringWriterHelper accumulator;
188
203
stablehloGetMinimumVersion (accumulator.getMlirStringCallback (),
189
204
accumulator.getUserData ());
@@ -196,34 +211,64 @@ void AddPortableApi(py::module &m) {
196
211
m.def (
197
212
" serialize_portable_artifact_str" ,
198
213
[](std::string_view moduleStrOrBytecode,
199
- std::string_view targetVersion) -> py::bytes {
214
+ std::string_view targetVersion) -> nb::bytes {
215
+ StringWriterHelper accumulator;
216
+ if (mlirLogicalResultIsFailure (
217
+ stablehloSerializePortableArtifactFromStringRef (
218
+ toMlirStringRef (moduleStrOrBytecode),
219
+ toMlirStringRef (targetVersion),
220
+ accumulator.getMlirStringCallback (),
221
+ accumulator.getUserData ()))) {
222
+ throw nb::value_error (" failed to serialize module" );
223
+ }
224
+ std::string serialized = accumulator.toString ();
225
+ return nb::bytes (serialized.data (), serialized.size ());
226
+ },
227
+ nb::arg (" module_str" ), nb::arg (" target_version" ));
228
+ m.def (
229
+ " serialize_portable_artifact_str" ,
230
+ [](nb::bytes moduleStrOrBytecode,
231
+ std::string_view targetVersion) -> nb::bytes {
200
232
StringWriterHelper accumulator;
201
233
if (mlirLogicalResultIsFailure (
202
234
stablehloSerializePortableArtifactFromStringRef (
203
235
toMlirStringRef (moduleStrOrBytecode),
204
236
toMlirStringRef (targetVersion),
205
237
accumulator.getMlirStringCallback (),
206
238
accumulator.getUserData ()))) {
207
- PyErr_SetString (PyExc_ValueError, " failed to serialize module" );
208
- return " " ;
239
+ throw nb::value_error (" failed to serialize module" );
209
240
}
210
- return py::bytes (accumulator.toString ());
241
+ std::string serialized = accumulator.toString ();
242
+ return nb::bytes (serialized.data (), serialized.size ());
211
243
},
212
- py ::arg (" module_str" ), py ::arg (" target_version" ));
244
+ nb ::arg (" module_str" ), nb ::arg (" target_version" ));
213
245
214
246
m.def (
215
247
" deserialize_portable_artifact_str" ,
216
- [](std::string_view artifact) -> py::bytes {
248
+ [](std::string_view artifact) -> nb::bytes {
249
+ StringWriterHelper accumulator;
250
+ if (mlirLogicalResultIsFailure (stablehloDeserializePortableArtifact (
251
+ toMlirStringRef (artifact), accumulator.getMlirStringCallback (),
252
+ accumulator.getUserData ()))) {
253
+ throw nb::value_error (" failed to deserialize module" );
254
+ }
255
+ std::string serialized = accumulator.toString ();
256
+ return nb::bytes (serialized.data (), serialized.size ());
257
+ },
258
+ nb::arg (" artifact_str" ));
259
+ m.def (
260
+ " deserialize_portable_artifact_str" ,
261
+ [](const nb::bytes& artifact) -> nb::bytes {
217
262
StringWriterHelper accumulator;
218
263
if (mlirLogicalResultIsFailure (stablehloDeserializePortableArtifact (
219
264
toMlirStringRef (artifact), accumulator.getMlirStringCallback (),
220
265
accumulator.getUserData ()))) {
221
- PyErr_SetString (PyExc_ValueError, " failed to deserialize module" );
222
- return " " ;
266
+ throw nb::value_error (" failed to deserialize module" );
223
267
}
224
- return py::bytes (accumulator.toString ());
268
+ std::string serialized = accumulator.toString ();
269
+ return nb::bytes (serialized.data (), serialized.size ());
225
270
},
226
- py ::arg (" artifact_str" ));
271
+ nb ::arg (" artifact_str" ));
227
272
}
228
273
229
274
} // namespace stablehlo
0 commit comments