Skip to content

Commit 5cd4274

Browse files
hawkinspjpienaar
andauthored
[mlir python] Port in-tree dialects to nanobind. (llvm#119924)
This is a companion to llvm#118583, although it can be landed independently because since llvm#117922 dialects do not have to use the same Python binding framework as the Python core code. This PR ports all of the in-tree dialect and pass extensions to nanobind, with the exception of those that remain for testing pybind11 support. This PR also: * removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This was overlooked in a previous PR and it is duplicated in Diagnostics.h. --------- Co-authored-by: Jacques Pienaar <[email protected]>
1 parent 559f080 commit 5cd4274

35 files changed

+357
-360
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,18 @@ function(add_mlir_python_extension libname extname)
661661
NB_DOMAIN mlir
662662
${ARG_SOURCES}
663663
)
664+
665+
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
666+
# Avoids warnings from upstream nanobind.
667+
target_compile_options(nanobind-static
668+
PRIVATE
669+
-Wno-cast-qual
670+
-Wno-zero-length-array
671+
-Wno-nested-anon-types
672+
-Wno-c++98-compat-extra-semi
673+
-Wno-covered-switch-default
674+
)
675+
endif()
664676
endif()
665677

666678
# The extension itself must be compiled with RTTI and exceptions enabled.

mlir/cmake/modules/MLIRDetectPythonEnv.cmake

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,17 @@ function(mlir_detect_nanobind_install)
9595
endif()
9696
message(STATUS "found (${PACKAGE_DIR})")
9797
set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
98+
execute_process(
99+
COMMAND "${Python3_EXECUTABLE}"
100+
-c "import nanobind;print(nanobind.include_dir(), end='')"
101+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
102+
RESULT_VARIABLE STATUS
103+
OUTPUT_VARIABLE PACKAGE_DIR
104+
ERROR_QUIET)
105+
if(NOT STATUS EQUAL "0")
106+
message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
107+
return()
108+
endif()
109+
set(nanobind_INCLUDE_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
98110
endif()
99111
endfunction()

mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
//
1010
//===----------------------------------------------------------------------===//
1111

12-
#include <nanobind/nanobind.h>
13-
1412
#include "Standalone-c/Dialects.h"
13+
#include "mlir/Bindings/Python/Nanobind.h"
1514
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1615

1716
namespace nb = nanobind;
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===- Nanobind.h - Trampoline header with ignored warnings ---------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This file is a trampoline for the nanobind headers while disabling warnings
9+
// reported by the LLVM/MLIR build. This file avoids adding complexity build
10+
// system side.
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_BINDINGS_PYTHON_NANOBIND_H
14+
#define MLIR_BINDINGS_PYTHON_NANOBIND_H
15+
16+
#if defined(__clang__) || defined(__GNUC__)
17+
#pragma GCC diagnostic push
18+
#pragma GCC diagnostic ignored "-Wzero-length-array"
19+
#pragma GCC diagnostic ignored "-Wcast-qual"
20+
#pragma GCC diagnostic ignored "-Wnested-anon-types"
21+
#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
22+
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
23+
#endif
24+
#include <nanobind/nanobind.h>
25+
#include <nanobind/ndarray.h>
26+
#include <nanobind/stl/function.h>
27+
#include <nanobind/stl/optional.h>
28+
#include <nanobind/stl/pair.h>
29+
#include <nanobind/stl/string.h>
30+
#include <nanobind/stl/string_view.h>
31+
#include <nanobind/stl/tuple.h>
32+
#include <nanobind/stl/vector.h>
33+
#if defined(__clang__) || defined(__GNUC__)
34+
#pragma GCC diagnostic pop
35+
#endif
36+
37+
#endif // MLIR_BINDINGS_PYTHON_NANOBIND_H

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@
1919
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
2020
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
2121

22-
#include <nanobind/nanobind.h>
23-
#include <nanobind/stl/string.h>
24-
2522
#include <cstdint>
2623

27-
#include "mlir-c/Bindings/Python/Interop.h"
2824
#include "mlir-c/Diagnostics.h"
2925
#include "mlir-c/IR.h"
26+
#include "mlir/Bindings/Python/Nanobind.h"
27+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
3028
#include "llvm/ADT/Twine.h"
3129

3230
// Raw CAPI type casters need to be declared before use, so always include them
@@ -631,40 +629,6 @@ class mlir_value_subclass : public pure_subclass {
631629

632630
} // namespace nanobind_adaptors
633631

634-
/// RAII scope intercepting all diagnostics into a string. The message must be
635-
/// checked before this goes out of scope.
636-
class CollectDiagnosticsToStringScope {
637-
public:
638-
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
639-
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
640-
/*deleteUserData=*/nullptr);
641-
}
642-
~CollectDiagnosticsToStringScope() {
643-
assert(errorMessage.empty() && "unchecked error message");
644-
mlirContextDetachDiagnosticHandler(context, handlerID);
645-
}
646-
647-
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
648-
649-
private:
650-
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
651-
auto printer = +[](MlirStringRef message, void *data) {
652-
*static_cast<std::string *>(data) +=
653-
llvm::StringRef(message.data, message.length);
654-
};
655-
MlirLocation loc = mlirDiagnosticGetLocation(diag);
656-
*static_cast<std::string *>(data) += "at ";
657-
mlirLocationPrint(loc, printer, data);
658-
*static_cast<std::string *>(data) += ": ";
659-
mlirDiagnosticPrint(diag, printer, data);
660-
return mlirLogicalResultSuccess();
661-
}
662-
663-
MlirContext context;
664-
MlirDiagnosticHandlerID handlerID;
665-
std::string errorMessage = "";
666-
};
667-
668632
} // namespace python
669633
} // namespace mlir
670634

mlir/lib/Bindings/Python/AsyncPasses.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88

99
#include "mlir-c/Dialect/Async.h"
1010

11-
#include <pybind11/detail/common.h>
12-
#include <pybind11/pybind11.h>
11+
#include "mlir/Bindings/Python/Nanobind.h"
1312

1413
// -----------------------------------------------------------------------------
1514
// Module initialization.
1615
// -----------------------------------------------------------------------------
1716

18-
PYBIND11_MODULE(_mlirAsyncPasses, m) {
17+
NB_MODULE(_mlirAsyncPasses, m) {
1918
m.doc() = "MLIR Async Dialect Passes";
2019

2120
// Register all Async passes on load.

mlir/lib/Bindings/Python/DialectGPU.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,21 @@
99
#include "mlir-c/Dialect/GPU.h"
1010
#include "mlir-c/IR.h"
1111
#include "mlir-c/Support.h"
12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
13+
#include "mlir/Bindings/Python/Nanobind.h"
1314

14-
#include <pybind11/detail/common.h>
15-
#include <pybind11/pybind11.h>
15+
namespace nb = nanobind;
16+
using namespace nanobind::literals;
1617

17-
namespace py = pybind11;
1818
using namespace mlir;
1919
using namespace mlir::python;
20-
using namespace mlir::python::adaptors;
20+
using namespace mlir::python::nanobind_adaptors;
2121

2222
// -----------------------------------------------------------------------------
2323
// Module initialization.
2424
// -----------------------------------------------------------------------------
2525

26-
PYBIND11_MODULE(_mlirDialectsGPU, m) {
26+
NB_MODULE(_mlirDialectsGPU, m) {
2727
m.doc() = "MLIR GPU Dialect";
2828
//===-------------------------------------------------------------------===//
2929
// AsyncTokenType
@@ -34,11 +34,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
3434

3535
mlirGPUAsyncTokenType.def_classmethod(
3636
"get",
37-
[](py::object cls, MlirContext ctx) {
37+
[](nb::object cls, MlirContext ctx) {
3838
return cls(mlirGPUAsyncTokenTypeGet(ctx));
3939
},
40-
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
41-
py::arg("ctx") = py::none());
40+
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
41+
nb::arg("ctx").none() = nb::none());
4242

4343
//===-------------------------------------------------------------------===//
4444
// ObjectAttr
@@ -47,12 +47,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
4747
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
4848
.def_classmethod(
4949
"get",
50-
[](py::object cls, MlirAttribute target, uint32_t format,
51-
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
50+
[](nb::object cls, MlirAttribute target, uint32_t format,
51+
nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
5252
std::optional<MlirAttribute> mlirKernelsAttr) {
53-
py::buffer_info info(py::buffer(object).request());
54-
MlirStringRef objectStrRef =
55-
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
53+
MlirStringRef objectStrRef = mlirStringRefCreate(
54+
static_cast<char *>(const_cast<void *>(object.data())),
55+
object.size());
5656
return cls(mlirGPUObjectAttrGetWithKernels(
5757
mlirAttributeGetContext(target), target, format, objectStrRef,
5858
mlirObjectProps.has_value() ? *mlirObjectProps
@@ -61,7 +61,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
6161
: MlirAttribute{nullptr}));
6262
},
6363
"cls"_a, "target"_a, "format"_a, "object"_a,
64-
"properties"_a = py::none(), "kernels"_a = py::none(),
64+
"properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
6565
"Gets a gpu.object from parameters.")
6666
.def_property_readonly(
6767
"target",
@@ -73,18 +73,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
7373
"object",
7474
[](MlirAttribute self) {
7575
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
76-
return py::bytes(stringRef.data, stringRef.length);
76+
return nb::bytes(stringRef.data, stringRef.length);
7777
})
7878
.def_property_readonly("properties",
79-
[](MlirAttribute self) {
79+
[](MlirAttribute self) -> nb::object {
8080
if (mlirGPUObjectAttrHasProperties(self))
81-
return py::cast(
81+
return nb::cast(
8282
mlirGPUObjectAttrGetProperties(self));
83-
return py::none().cast<py::object>();
83+
return nb::none();
8484
})
85-
.def_property_readonly("kernels", [](MlirAttribute self) {
85+
.def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
8686
if (mlirGPUObjectAttrHasKernels(self))
87-
return py::cast(mlirGPUObjectAttrGetKernels(self));
88-
return py::none().cast<py::object>();
87+
return nb::cast(mlirGPUObjectAttrGetKernels(self));
88+
return nb::none();
8989
});
9090
}

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,19 @@
1212
#include "mlir-c/IR.h"
1313
#include "mlir-c/Support.h"
1414
#include "mlir/Bindings/Python/Diagnostics.h"
15-
#include "mlir/Bindings/Python/PybindAdaptors.h"
15+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
16+
#include "mlir/Bindings/Python/Nanobind.h"
17+
18+
namespace nb = nanobind;
19+
20+
using namespace nanobind::literals;
1621

17-
namespace py = pybind11;
1822
using namespace llvm;
1923
using namespace mlir;
2024
using namespace mlir::python;
21-
using namespace mlir::python::adaptors;
25+
using namespace mlir::python::nanobind_adaptors;
2226

23-
void populateDialectLLVMSubmodule(const pybind11::module &m) {
27+
void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
2428

2529
//===--------------------------------------------------------------------===//
2630
// StructType
@@ -31,58 +35,58 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
3135

3236
llvmStructType.def_classmethod(
3337
"get_literal",
34-
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
38+
[](nb::object cls, const std::vector<MlirType> &elements, bool packed,
3539
MlirLocation loc) {
3640
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
3741

3842
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
3943
loc, elements.size(), elements.data(), packed);
4044
if (mlirTypeIsNull(type)) {
41-
throw py::value_error(scope.takeMessage());
45+
throw nb::value_error(scope.takeMessage().c_str());
4246
}
4347
return cls(type);
4448
},
45-
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
46-
"loc"_a = py::none());
49+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
50+
"loc"_a.none() = nb::none());
4751

4852
llvmStructType.def_classmethod(
4953
"get_identified",
50-
[](py::object cls, const std::string &name, MlirContext context) {
54+
[](nb::object cls, const std::string &name, MlirContext context) {
5155
return cls(mlirLLVMStructTypeIdentifiedGet(
5256
context, mlirStringRefCreate(name.data(), name.size())));
5357
},
54-
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
58+
"cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());
5559

5660
llvmStructType.def_classmethod(
5761
"get_opaque",
58-
[](py::object cls, const std::string &name, MlirContext context) {
62+
[](nb::object cls, const std::string &name, MlirContext context) {
5963
return cls(mlirLLVMStructTypeOpaqueGet(
6064
context, mlirStringRefCreate(name.data(), name.size())));
6165
},
62-
"cls"_a, "name"_a, "context"_a = py::none());
66+
"cls"_a, "name"_a, "context"_a.none() = nb::none());
6367

6468
llvmStructType.def(
6569
"set_body",
6670
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
6771
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
6872
self, elements.size(), elements.data(), packed);
6973
if (!mlirLogicalResultIsSuccess(result)) {
70-
throw py::value_error(
74+
throw nb::value_error(
7175
"Struct body already set to different content.");
7276
}
7377
},
74-
"elements"_a, py::kw_only(), "packed"_a = false);
78+
"elements"_a, nb::kw_only(), "packed"_a = false);
7579

7680
llvmStructType.def_classmethod(
7781
"new_identified",
78-
[](py::object cls, const std::string &name,
82+
[](nb::object cls, const std::string &name,
7983
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
8084
return cls(mlirLLVMStructTypeIdentifiedNewGet(
8185
ctx, mlirStringRefCreate(name.data(), name.length()),
8286
elements.size(), elements.data(), packed));
8387
},
84-
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
85-
"context"_a = py::none());
88+
"cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
89+
"context"_a.none() = nb::none());
8690

8791
llvmStructType.def_property_readonly(
8892
"name", [](MlirType type) -> std::optional<std::string> {
@@ -93,12 +97,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
9397
return StringRef(stringRef.data, stringRef.length).str();
9498
});
9599

96-
llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
100+
llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
97101
// Don't crash in absence of a body.
98102
if (mlirLLVMStructTypeIsOpaque(type))
99-
return py::none();
103+
return nb::none();
100104

101-
py::list body;
105+
nb::list body;
102106
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
103107
++i) {
104108
body.append(mlirLLVMStructTypeGetElementType(type, i));
@@ -119,24 +123,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
119123
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
120124
.def_classmethod(
121125
"get",
122-
[](py::object cls, std::optional<unsigned> addressSpace,
126+
[](nb::object cls, std::optional<unsigned> addressSpace,
123127
MlirContext context) {
124128
CollectDiagnosticsToStringScope scope(context);
125129
MlirType type = mlirLLVMPointerTypeGet(
126130
context, addressSpace.has_value() ? *addressSpace : 0);
127131
if (mlirTypeIsNull(type)) {
128-
throw py::value_error(scope.takeMessage());
132+
throw nb::value_error(scope.takeMessage().c_str());
129133
}
130134
return cls(type);
131135
},
132-
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
133-
"context"_a = py::none())
136+
"cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
137+
"context"_a.none() = nb::none())
134138
.def_property_readonly("address_space", [](MlirType type) {
135139
return mlirLLVMPointerTypeGetAddressSpace(type);
136140
});
137141
}
138142

139-
PYBIND11_MODULE(_mlirDialectsLLVM, m) {
143+
NB_MODULE(_mlirDialectsLLVM, m) {
140144
m.doc() = "MLIR LLVM Dialect";
141145

142146
populateDialectLLVMSubmodule(m);

0 commit comments

Comments
 (0)