Skip to content

[mlir python] Port in-tree dialects to nanobind. #119924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/cmake/modules/AddMLIRPython.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,18 @@ function(add_mlir_python_extension libname extname)
NB_DOMAIN mlir
${ARG_SOURCES}
)

if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
target_compile_options(nanobind-static
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array
-Wno-nested-anon-types
-Wno-c++98-compat-extra-semi
-Wno-covered-switch-default
)
endif()
endif()

# The extension itself must be compiled with RTTI and exceptions enabled.
Expand Down
12 changes: 12 additions & 0 deletions mlir/cmake/modules/MLIRDetectPythonEnv.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,17 @@ function(mlir_detect_nanobind_install)
endif()
message(STATUS "found (${PACKAGE_DIR})")
set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
execute_process(
COMMAND "${Python3_EXECUTABLE}"
-c "import nanobind;print(nanobind.include_dir(), end='')"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE STATUS
OUTPUT_VARIABLE PACKAGE_DIR
ERROR_QUIET)
if(NOT STATUS EQUAL "0")
message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
return()
endif()
set(nanobind_INCLUDE_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
endif()
endfunction()
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
//
//===----------------------------------------------------------------------===//

#include <nanobind/nanobind.h>

#include "Standalone-c/Dialects.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Bindings/Python/Nanobind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- Nanobind.h - Trampoline header with ignored warnings ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file is a trampoline for the nanobind headers while disabling warnings
// reported by the LLVM/MLIR build. This file avoids adding complexity build
// system side.
//===----------------------------------------------------------------------===//

#ifndef MLIR_BINDINGS_PYTHON_NANOBIND_H
#define MLIR_BINDINGS_PYTHON_NANOBIND_H

#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wzero-length-array"
#pragma GCC diagnostic ignored "-Wcast-qual"
#pragma GCC diagnostic ignored "-Wnested-anon-types"
#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
#endif
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/string_view.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/vector.h>
#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

#endif // MLIR_BINDINGS_PYTHON_NANOBIND_H
40 changes: 2 additions & 38 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>

#include <cstdint>

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "llvm/ADT/Twine.h"

// Raw CAPI type casters need to be declared before use, so always include them
Expand Down Expand Up @@ -631,40 +629,6 @@ class mlir_value_subclass : public pure_subclass {

} // namespace nanobind_adaptors

/// RAII scope intercepting all diagnostics into a string. The message must be
/// checked before this goes out of scope.
class CollectDiagnosticsToStringScope {
public:
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
assert(errorMessage.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}

[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }

private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
*static_cast<std::string *>(data) +=
llvm::StringRef(message.data, message.length);
};
MlirLocation loc = mlirDiagnosticGetLocation(diag);
*static_cast<std::string *>(data) += "at ";
mlirLocationPrint(loc, printer, data);
*static_cast<std::string *>(data) += ": ";
mlirDiagnosticPrint(diag, printer, data);
return mlirLogicalResultSuccess();
}

MlirContext context;
MlirDiagnosticHandlerID handlerID;
std::string errorMessage = "";
};

} // namespace python
} // namespace mlir

Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Bindings/Python/AsyncPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@

#include "mlir-c/Dialect/Async.h"

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include "mlir/Bindings/Python/Nanobind.h"

// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------

PYBIND11_MODULE(_mlirAsyncPasses, m) {
NB_MODULE(_mlirAsyncPasses, m) {
m.doc() = "MLIR Async Dialect Passes";

// Register all Async passes on load.
Expand Down
44 changes: 22 additions & 22 deletions mlir/lib/Bindings/Python/DialectGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
namespace nb = nanobind;
using namespace nanobind::literals;

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
using namespace mlir::python::nanobind_adaptors;

// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------

PYBIND11_MODULE(_mlirDialectsGPU, m) {
NB_MODULE(_mlirDialectsGPU, m) {
m.doc() = "MLIR GPU Dialect";
//===-------------------------------------------------------------------===//
// AsyncTokenType
Expand All @@ -34,11 +34,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {

mlirGPUAsyncTokenType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
[](nb::object cls, MlirContext ctx) {
return cls(mlirGPUAsyncTokenTypeGet(ctx));
},
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
py::arg("ctx") = py::none());
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
nb::arg("ctx").none() = nb::none());

//===-------------------------------------------------------------------===//
// ObjectAttr
Expand All @@ -47,12 +47,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
.def_classmethod(
"get",
[](py::object cls, MlirAttribute target, uint32_t format,
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
[](nb::object cls, MlirAttribute target, uint32_t format,
nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
std::optional<MlirAttribute> mlirKernelsAttr) {
py::buffer_info info(py::buffer(object).request());
MlirStringRef objectStrRef =
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
MlirStringRef objectStrRef = mlirStringRefCreate(
static_cast<char *>(const_cast<void *>(object.data())),
object.size());
return cls(mlirGPUObjectAttrGetWithKernels(
mlirAttributeGetContext(target), target, format, objectStrRef,
mlirObjectProps.has_value() ? *mlirObjectProps
Expand All @@ -61,7 +61,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
: MlirAttribute{nullptr}));
},
"cls"_a, "target"_a, "format"_a, "object"_a,
"properties"_a = py::none(), "kernels"_a = py::none(),
"properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
"Gets a gpu.object from parameters.")
.def_property_readonly(
"target",
Expand All @@ -73,18 +73,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
"object",
[](MlirAttribute self) {
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
return py::bytes(stringRef.data, stringRef.length);
return nb::bytes(stringRef.data, stringRef.length);
})
.def_property_readonly("properties",
[](MlirAttribute self) {
[](MlirAttribute self) -> nb::object {
if (mlirGPUObjectAttrHasProperties(self))
return py::cast(
return nb::cast(
mlirGPUObjectAttrGetProperties(self));
return py::none().cast<py::object>();
return nb::none();
})
.def_property_readonly("kernels", [](MlirAttribute self) {
.def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
if (mlirGPUObjectAttrHasKernels(self))
return py::cast(mlirGPUObjectAttrGetKernels(self));
return py::none().cast<py::object>();
return nb::cast(mlirGPUObjectAttrGetKernels(self));
return nb::none();
});
}
54 changes: 29 additions & 25 deletions mlir/lib/Bindings/Python/DialectLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

namespace nb = nanobind;

using namespace nanobind::literals;

namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
using namespace mlir::python::nanobind_adaptors;

void populateDialectLLVMSubmodule(const pybind11::module &m) {
void populateDialectLLVMSubmodule(const nanobind::module_ &m) {

//===--------------------------------------------------------------------===//
// StructType
Expand All @@ -31,58 +35,58 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {

llvmStructType.def_classmethod(
"get_literal",
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
[](nb::object cls, const std::vector<MlirType> &elements, bool packed,
MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));

MlirType type = mlirLLVMStructTypeLiteralGetChecked(
loc, elements.size(), elements.data(), packed);
if (mlirTypeIsNull(type)) {
throw py::value_error(scope.takeMessage());
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
"loc"_a = py::none());
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"loc"_a.none() = nb::none());

llvmStructType.def_classmethod(
"get_identified",
[](py::object cls, const std::string &name, MlirContext context) {
[](nb::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
"cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());

llvmStructType.def_classmethod(
"get_opaque",
[](py::object cls, const std::string &name, MlirContext context) {
[](nb::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
"cls"_a, "name"_a, "context"_a = py::none());
"cls"_a, "name"_a, "context"_a.none() = nb::none());

llvmStructType.def(
"set_body",
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
self, elements.size(), elements.data(), packed);
if (!mlirLogicalResultIsSuccess(result)) {
throw py::value_error(
throw nb::value_error(
"Struct body already set to different content.");
}
},
"elements"_a, py::kw_only(), "packed"_a = false);
"elements"_a, nb::kw_only(), "packed"_a = false);

llvmStructType.def_classmethod(
"new_identified",
[](py::object cls, const std::string &name,
[](nb::object cls, const std::string &name,
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
return cls(mlirLLVMStructTypeIdentifiedNewGet(
ctx, mlirStringRefCreate(name.data(), name.length()),
elements.size(), elements.data(), packed));
},
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
"context"_a = py::none());
"cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"context"_a.none() = nb::none());

llvmStructType.def_property_readonly(
"name", [](MlirType type) -> std::optional<std::string> {
Expand All @@ -93,12 +97,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
return StringRef(stringRef.data, stringRef.length).str();
});

llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
// Don't crash in absence of a body.
if (mlirLLVMStructTypeIsOpaque(type))
return py::none();
return nb::none();

py::list body;
nb::list body;
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
++i) {
body.append(mlirLLVMStructTypeGetElementType(type, i));
Expand All @@ -119,24 +123,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
.def_classmethod(
"get",
[](py::object cls, std::optional<unsigned> addressSpace,
[](nb::object cls, std::optional<unsigned> addressSpace,
MlirContext context) {
CollectDiagnosticsToStringScope scope(context);
MlirType type = mlirLLVMPointerTypeGet(
context, addressSpace.has_value() ? *addressSpace : 0);
if (mlirTypeIsNull(type)) {
throw py::value_error(scope.takeMessage());
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
"context"_a = py::none())
"cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
"context"_a.none() = nb::none())
.def_property_readonly("address_space", [](MlirType type) {
return mlirLLVMPointerTypeGetAddressSpace(type);
});
}

PYBIND11_MODULE(_mlirDialectsLLVM, m) {
NB_MODULE(_mlirDialectsLLVM, m) {
m.doc() = "MLIR LLVM Dialect";

populateDialectLLVMSubmodule(m);
Expand Down
Loading
Loading