Skip to content

[mlir python] Port Python core code to nanobind. #118583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/cmake/modules/MLIRDetectPythonEnv.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
"extension = '${PYTHON_MODULE_EXTENSION}")

mlir_detect_nanobind_install()
find_package(nanobind 2.2 CONFIG REQUIRED)
find_package(nanobind 2.4 CONFIG REQUIRED)
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Bindings/Python/IRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H

#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace mlir {

Expand Down
10 changes: 4 additions & 6 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_staticmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) = py::staticmethod(cf);
return *this;
}
Expand All @@ -387,9 +386,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_classmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) =
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
return *this;
Expand Down
39 changes: 19 additions & 20 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include "PybindUtils.h"
#include <optional>
#include <string>
#include <vector>

#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

#include <optional>
#include <string>
#include <vector>

namespace mlir {
namespace python {

Expand Down Expand Up @@ -57,71 +56,71 @@ class PyGlobals {
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc,
nanobind::callable pyFunc,
bool replace = false);

/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
bool replace = false);

/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
nanobind::callable valueCaster,
bool replace = false);

/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
pybind11::object pyClass);
nanobind::object pyClass);

/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass, bool replace = false);
nanobind::object pyClass, bool replace = false);

/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
std::optional<nanobind::callable>
lookupAttributeBuilder(const std::string &attributeKind);

/// Returns the custom type caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
std::optional<nanobind::object>
lookupDialectClass(const std::string &dialectNamespace);

/// Looks up a registered operation class (deriving from OpView) by operation
/// name. Note that this may trigger a load of the dialect, which can
/// arbitrarily re-enter.
std::optional<pybind11::object>
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);

private:
static PyGlobals *instance;
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
llvm::StringMap<pybind11::object> dialectClassMap;
llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
llvm::StringMap<nanobind::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
Expand Down
Loading
Loading