Skip to content

Commit 69d656b

Browse files
committed
[mlir python] Port in-tree dialects to nanobind.
This is a companion to #118583, although it can be landed independently because since #117922 dialects do not have to use the same Python binding framework as the Python core code. This PR ports all of the in-tree dialect and pass extensions to nanobind, with the exception of those that remain for testing pybind11 support. It would make sense to merge this PR after merging #118583, if we have agreed that we are migrating the core to nanobind. This PR also: * removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This was overlooked in a previous PR and it is duplicated in Diagnostics.h.
1 parent 310e798 commit 69d656b

20 files changed

+288
-296
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

+3
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ function(declare_mlir_python_extension name)
142142
mlir_python_DEPENDS ""
143143
mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}"
144144
)
145+
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
146+
set_target_properties(${name} PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic")
147+
endif()
145148

146149
# Set the interface source and link_libs properties of the target
147150
# These properties support generator expressions and are automatically exported

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

-34
Original file line numberDiff line numberDiff line change
@@ -631,40 +631,6 @@ class mlir_value_subclass : public pure_subclass {
631631

632632
} // namespace nanobind_adaptors
633633

634-
/// RAII scope intercepting all diagnostics into a string. The message must be
635-
/// checked before this goes out of scope.
636-
class CollectDiagnosticsToStringScope {
637-
public:
638-
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
639-
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
640-
/*deleteUserData=*/nullptr);
641-
}
642-
~CollectDiagnosticsToStringScope() {
643-
assert(errorMessage.empty() && "unchecked error message");
644-
mlirContextDetachDiagnosticHandler(context, handlerID);
645-
}
646-
647-
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
648-
649-
private:
650-
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
651-
auto printer = +[](MlirStringRef message, void *data) {
652-
*static_cast<std::string *>(data) +=
653-
llvm::StringRef(message.data, message.length);
654-
};
655-
MlirLocation loc = mlirDiagnosticGetLocation(diag);
656-
*static_cast<std::string *>(data) += "at ";
657-
mlirLocationPrint(loc, printer, data);
658-
*static_cast<std::string *>(data) += ": ";
659-
mlirDiagnosticPrint(diag, printer, data);
660-
return mlirLogicalResultSuccess();
661-
}
662-
663-
MlirContext context;
664-
MlirDiagnosticHandlerID handlerID;
665-
std::string errorMessage = "";
666-
};
667-
668634
} // namespace python
669635
} // namespace mlir
670636

mlir/lib/Bindings/Python/AsyncPasses.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88

99
#include "mlir-c/Dialect/Async.h"
1010

11-
#include <pybind11/detail/common.h>
12-
#include <pybind11/pybind11.h>
11+
#include <nanobind/nanobind.h>
1312

1413
// -----------------------------------------------------------------------------
1514
// Module initialization.
1615
// -----------------------------------------------------------------------------
1716

18-
PYBIND11_MODULE(_mlirAsyncPasses, m) {
17+
NB_MODULE(_mlirAsyncPasses, m) {
1918
m.doc() = "MLIR Async Dialect Passes";
2019

2120
// Register all Async passes on load.

mlir/lib/Bindings/Python/DialectGPU.cpp

+24-22
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,23 @@
99
#include "mlir-c/Dialect/GPU.h"
1010
#include "mlir-c/IR.h"
1111
#include "mlir-c/Support.h"
12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

14-
#include <pybind11/detail/common.h>
15-
#include <pybind11/pybind11.h>
14+
#include <nanobind/nanobind.h>
15+
#include <nanobind/stl/optional.h>
16+
17+
namespace nb = nanobind;
18+
using namespace nanobind::literals;
1619

17-
namespace py = pybind11;
1820
using namespace mlir;
1921
using namespace mlir::python;
20-
using namespace mlir::python::adaptors;
22+
using namespace mlir::python::nanobind_adaptors;
2123

2224
// -----------------------------------------------------------------------------
2325
// Module initialization.
2426
// -----------------------------------------------------------------------------
2527

26-
PYBIND11_MODULE(_mlirDialectsGPU, m) {
28+
NB_MODULE(_mlirDialectsGPU, m) {
2729
m.doc() = "MLIR GPU Dialect";
2830
//===-------------------------------------------------------------------===//
2931
// AsyncTokenType
@@ -34,11 +36,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
3436

3537
mlirGPUAsyncTokenType.def_classmethod(
3638
"get",
37-
[](py::object cls, MlirContext ctx) {
39+
[](nb::object cls, MlirContext ctx) {
3840
return cls(mlirGPUAsyncTokenTypeGet(ctx));
3941
},
40-
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
41-
py::arg("ctx") = py::none());
42+
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
43+
nb::arg("ctx").none() = nb::none());
4244

4345
//===-------------------------------------------------------------------===//
4446
// ObjectAttr
@@ -47,12 +49,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
4749
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
4850
.def_classmethod(
4951
"get",
50-
[](py::object cls, MlirAttribute target, uint32_t format,
51-
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
52+
[](nb::object cls, MlirAttribute target, uint32_t format,
53+
nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
5254
std::optional<MlirAttribute> mlirKernelsAttr) {
53-
py::buffer_info info(py::buffer(object).request());
54-
MlirStringRef objectStrRef =
55-
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
55+
MlirStringRef objectStrRef = mlirStringRefCreate(
56+
static_cast<char *>(const_cast<void *>(object.data())),
57+
object.size());
5658
return cls(mlirGPUObjectAttrGetWithKernels(
5759
mlirAttributeGetContext(target), target, format, objectStrRef,
5860
mlirObjectProps.has_value() ? *mlirObjectProps
@@ -61,7 +63,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
6163
: MlirAttribute{nullptr}));
6264
},
6365
"cls"_a, "target"_a, "format"_a, "object"_a,
64-
"properties"_a = py::none(), "kernels"_a = py::none(),
66+
"properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
6567
"Gets a gpu.object from parameters.")
6668
.def_property_readonly(
6769
"target",
@@ -73,18 +75,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
7375
"object",
7476
[](MlirAttribute self) {
7577
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
76-
return py::bytes(stringRef.data, stringRef.length);
78+
return nb::bytes(stringRef.data, stringRef.length);
7779
})
7880
.def_property_readonly("properties",
79-
[](MlirAttribute self) {
81+
[](MlirAttribute self) -> nb::object {
8082
if (mlirGPUObjectAttrHasProperties(self))
81-
return py::cast(
83+
return nb::cast(
8284
mlirGPUObjectAttrGetProperties(self));
83-
return py::none().cast<py::object>();
85+
return nb::none();
8486
})
85-
.def_property_readonly("kernels", [](MlirAttribute self) {
87+
.def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
8688
if (mlirGPUObjectAttrHasKernels(self))
87-
return py::cast(mlirGPUObjectAttrGetKernels(self));
88-
return py::none().cast<py::object>();
89+
return nb::cast(mlirGPUObjectAttrGetKernels(self));
90+
return nb::none();
8991
});
9092
}

mlir/lib/Bindings/Python/DialectLLVM.cpp

+33-25
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,23 @@
1212
#include "mlir-c/IR.h"
1313
#include "mlir-c/Support.h"
1414
#include "mlir/Bindings/Python/Diagnostics.h"
15-
#include "mlir/Bindings/Python/PybindAdaptors.h"
15+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
16+
17+
#include <nanobind/nanobind.h>
18+
#include <nanobind/stl/optional.h>
19+
#include <nanobind/stl/string.h>
20+
#include <nanobind/stl/vector.h>
21+
22+
namespace nb = nanobind;
23+
24+
using namespace nanobind::literals;
1625

17-
namespace py = pybind11;
1826
using namespace llvm;
1927
using namespace mlir;
2028
using namespace mlir::python;
21-
using namespace mlir::python::adaptors;
29+
using namespace mlir::python::nanobind_adaptors;
2230

23-
void populateDialectLLVMSubmodule(const pybind11::module &m) {
31+
void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
2432

2533
//===--------------------------------------------------------------------===//
2634
// StructType
@@ -31,58 +39,58 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
3139

3240
llvmStructType.def_classmethod(
3341
"get_literal",
34-
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
42+
[](nb::object cls, const std::vector<MlirType> &elements, bool packed,
3543
MlirLocation loc) {
3644
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
3745

3846
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
3947
loc, elements.size(), elements.data(), packed);
4048
if (mlirTypeIsNull(type)) {
41-
throw py::value_error(scope.takeMessage());
49+
throw nb::value_error(scope.takeMessage().c_str());
4250
}
4351
return cls(type);
4452
},
45-
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
46-
"loc"_a = py::none());
53+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
54+
"loc"_a.none() = nb::none());
4755

4856
llvmStructType.def_classmethod(
4957
"get_identified",
50-
[](py::object cls, const std::string &name, MlirContext context) {
58+
[](nb::object cls, const std::string &name, MlirContext context) {
5159
return cls(mlirLLVMStructTypeIdentifiedGet(
5260
context, mlirStringRefCreate(name.data(), name.size())));
5361
},
54-
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
62+
"cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());
5563

5664
llvmStructType.def_classmethod(
5765
"get_opaque",
58-
[](py::object cls, const std::string &name, MlirContext context) {
66+
[](nb::object cls, const std::string &name, MlirContext context) {
5967
return cls(mlirLLVMStructTypeOpaqueGet(
6068
context, mlirStringRefCreate(name.data(), name.size())));
6169
},
62-
"cls"_a, "name"_a, "context"_a = py::none());
70+
"cls"_a, "name"_a, "context"_a.none() = nb::none());
6371

6472
llvmStructType.def(
6573
"set_body",
6674
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
6775
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
6876
self, elements.size(), elements.data(), packed);
6977
if (!mlirLogicalResultIsSuccess(result)) {
70-
throw py::value_error(
78+
throw nb::value_error(
7179
"Struct body already set to different content.");
7280
}
7381
},
74-
"elements"_a, py::kw_only(), "packed"_a = false);
82+
"elements"_a, nb::kw_only(), "packed"_a = false);
7583

7684
llvmStructType.def_classmethod(
7785
"new_identified",
78-
[](py::object cls, const std::string &name,
86+
[](nb::object cls, const std::string &name,
7987
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
8088
return cls(mlirLLVMStructTypeIdentifiedNewGet(
8189
ctx, mlirStringRefCreate(name.data(), name.length()),
8290
elements.size(), elements.data(), packed));
8391
},
84-
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
85-
"context"_a = py::none());
92+
"cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
93+
"context"_a.none() = nb::none());
8694

8795
llvmStructType.def_property_readonly(
8896
"name", [](MlirType type) -> std::optional<std::string> {
@@ -93,12 +101,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
93101
return StringRef(stringRef.data, stringRef.length).str();
94102
});
95103

96-
llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
104+
llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
97105
// Don't crash in absence of a body.
98106
if (mlirLLVMStructTypeIsOpaque(type))
99-
return py::none();
107+
return nb::none();
100108

101-
py::list body;
109+
nb::list body;
102110
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
103111
++i) {
104112
body.append(mlirLLVMStructTypeGetElementType(type, i));
@@ -119,24 +127,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
119127
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
120128
.def_classmethod(
121129
"get",
122-
[](py::object cls, std::optional<unsigned> addressSpace,
130+
[](nb::object cls, std::optional<unsigned> addressSpace,
123131
MlirContext context) {
124132
CollectDiagnosticsToStringScope scope(context);
125133
MlirType type = mlirLLVMPointerTypeGet(
126134
context, addressSpace.has_value() ? *addressSpace : 0);
127135
if (mlirTypeIsNull(type)) {
128-
throw py::value_error(scope.takeMessage());
136+
throw nb::value_error(scope.takeMessage().c_str());
129137
}
130138
return cls(type);
131139
},
132-
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
133-
"context"_a = py::none())
140+
"cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
141+
"context"_a.none() = nb::none())
134142
.def_property_readonly("address_space", [](MlirType type) {
135143
return mlirLLVMPointerTypeGetAddressSpace(type);
136144
});
137145
}
138146

139-
PYBIND11_MODULE(_mlirDialectsLLVM, m) {
147+
NB_MODULE(_mlirDialectsLLVM, m) {
140148
m.doc() = "MLIR LLVM Dialect";
141149

142150
populateDialectLLVMSubmodule(m);

mlir/lib/Bindings/Python/DialectLinalg.cpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,24 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <nanobind/nanobind.h>
10+
911
#include "mlir-c/Dialect/Linalg.h"
1012
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/PybindAdaptors.h"
13+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1214

13-
namespace py = pybind11;
15+
namespace nb = nanobind;
1416

15-
static void populateDialectLinalgSubmodule(py::module m) {
17+
static void populateDialectLinalgSubmodule(nb::module_ m) {
1618
m.def(
1719
"fill_builtin_region",
1820
[](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
19-
py::arg("op"),
21+
nb::arg("op"),
2022
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
2123
"op.");
2224
}
2325

24-
PYBIND11_MODULE(_mlirDialectsLinalg, m) {
26+
NB_MODULE(_mlirDialectsLinalg, m) {
2527
m.doc() = "MLIR Linalg dialect.";
2628

2729
populateDialectLinalgSubmodule(m);

mlir/lib/Bindings/Python/DialectNVGPU.cpp

+11-10
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,36 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <nanobind/nanobind.h>
10+
911
#include "mlir-c/Dialect/NVGPU.h"
1012
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12-
#include <pybind11/pybind11.h>
13+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1314

14-
namespace py = pybind11;
15+
namespace nb = nanobind;
1516
using namespace llvm;
1617
using namespace mlir;
1718
using namespace mlir::python;
18-
using namespace mlir::python::adaptors;
19+
using namespace mlir::python::nanobind_adaptors;
1920

20-
static void populateDialectNVGPUSubmodule(const pybind11::module &m) {
21+
static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
2122
auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
2223
m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
2324

2425
nvgpuTensorMapDescriptorType.def_classmethod(
2526
"get",
26-
[](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
27+
[](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
2728
int oobFill, int interleave, MlirContext ctx) {
2829
return cls(mlirNVGPUTensorMapDescriptorTypeGet(
2930
ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
3031
},
3132
"Gets an instance of TensorMapDescriptorType in the same context",
32-
py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"),
33-
py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"),
34-
py::arg("ctx") = py::none());
33+
nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
34+
nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
35+
nb::arg("ctx").none() = nb::none());
3536
}
3637

37-
PYBIND11_MODULE(_mlirDialectsNVGPU, m) {
38+
NB_MODULE(_mlirDialectsNVGPU, m) {
3839
m.doc() = "MLIR NVGPU dialect.";
3940

4041
populateDialectNVGPUSubmodule(m);

0 commit comments

Comments
 (0)