Skip to content

Enabled freethreading support in MLIR python bindings #122684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion mlir/cmake/modules/AddMLIRPython.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,31 @@ function(add_mlir_python_extension libname extname)
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
nanobind_add_module(${libname}
NB_DOMAIN mlir
FREE_THREADED
${ARG_SOURCES}
)

if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
target_compile_options(nanobind-static
set(nanobind_target "nanobind-static")
if (NOT TARGET ${nanobind_target})
# Get correct nanobind target name: nanobind-static-ft or something else
# It is set by nanobind_add_module function according to the passed options
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)

# Iterate over the list of targets
foreach(target ${all_targets})
# Check if the target name matches the given string
if("${target}" MATCHES "nanobind-")
set(nanobind_target "${target}")
endif()
endforeach()

if (NOT TARGET ${nanobind_target})
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
endif()
endif()
target_compile_options(${nanobind_target}
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array
Expand Down
40 changes: 40 additions & 0 deletions mlir/docs/Bindings/Python.md
Original file line number Diff line number Diff line change
Expand Up @@ -1187,3 +1187,43 @@ or nanobind and
utilities to connect to the rest of Python API. The bindings can be located in a
separate module or in the same module as attributes and types, and
loaded along with the dialect.

## Free-threading (No-GIL) support

Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).

MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:

```python
# python3.13t example.py
import concurrent.futures

import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint


def func(py_value):
with Context() as ctx:
module = Module.create(loc=Location.file("foo.txt", 0, 0))

dtype = IntegerType.get_signless(64)
with InsertionPoint(module.body), Location.name("a"):
arith.constant(dtype, py_value)

return module


num_workers = 8
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for i in range(num_workers):
futures.append(executor.submit(func, i))
assert len(list(f.result() for f in futures)) == num_workers
```

The exceptions to the free-threading compatibility:
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
#include "Standalone-c/Dialects.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

namespace py = pybind11;

using namespace mlir::python::adaptors;

PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
PYBIND11_MODULE(_standaloneDialectsPybind11, m, py::mod_gil_not_used()) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
Expand Down
12 changes: 11 additions & 1 deletion mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace mlir {
namespace python {

/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
class PyGlobals {
public:
PyGlobals();
Expand All @@ -37,12 +38,18 @@ class PyGlobals {

/// Get and set the list of parent modules to search for dialect
/// implementation classes.
std::vector<std::string> &getDialectSearchPrefixes() {
std::vector<std::string> getDialectSearchPrefixes() {
nanobind::ft_lock_guard lock(mutex);
return dialectSearchPrefixes;
}
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.swap(newValues);
}
void addDialectSearchPrefix(std::string value) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.push_back(std::move(value));
}

/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
Expand Down Expand Up @@ -109,6 +116,9 @@ class PyGlobals {

private:
static PyGlobals *instance;

nanobind::ft_mutex mutex;

/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
Expand Down
31 changes: 27 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,

/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
static void set(nb::object &o, bool enable) {
nb::ft_lock_guard lock(mutex);
mlirEnableGlobalDebug(enable);
}

static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
static bool get(const nb::object &) {
nb::ft_lock_guard lock(mutex);
return mlirIsGlobalDebugEnabled();
}

static void bind(nb::module_ &m) {
// Debug flags.
Expand All @@ -255,6 +261,7 @@ struct PyGlobalDebugFlag {
.def_static(
"set_types",
[](const std::string &type) {
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
"types"_a, "Sets specific debug types to be produced by LLVM")
Expand All @@ -263,11 +270,17 @@ struct PyGlobalDebugFlag {
pointers.reserve(types.size());
for (const std::string &str : types)
pointers.push_back(str.c_str());
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
});
}

private:
static nb::ft_mutex mutex;
};

nb::ft_mutex PyGlobalDebugFlag::mutex;

struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
Expand Down Expand Up @@ -606,6 +619,7 @@ class PyOpOperandIterator {

PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
Expand All @@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
nb::gil_scoped_acquire acquire;
getLiveContexts().erase(context.ptr);
{
nb::ft_lock_guard lock(live_contexts_mutex);
getLiveContexts().erase(context.ptr);
}
mlirContextDestroy(context);
}

Expand All @@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {

PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
Expand All @@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
return PyMlirContextRef(it->second, std::move(pyRef));
}

nb::ft_mutex PyMlirContext::live_contexts_mutex;

PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}

size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveCount() {
nb::ft_lock_guard lock(live_contexts_mutex);
return getLiveContexts().size();
}

size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }

Expand Down
18 changes: 16 additions & 2 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }

bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
if (loadedDialectModules.contains(dialectNamespace))
return true;
{
nb::ft_lock_guard lock(mutex);
if (loadedDialectModules.contains(dialectNamespace))
return true;
}
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
nb::object loaded = nb::none();
Expand All @@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
nb::ft_lock_guard lock(mutex);
loadedDialectModules.insert(dialectNamespace);
return true;
}

void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
nb::callable pyFunc, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
Expand All @@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,

void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
nb::callable typeCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
Expand All @@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,

void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
nb::callable valueCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
Expand All @@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
nb::object pyClass) {
nb::ft_lock_guard lock(mutex);
nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
Expand All @@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,

void PyGlobals::registerOperationImpl(const std::string &operationName,
nb::object pyClass, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
Expand All @@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,

std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
nb::ft_lock_guard lock(mutex);
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
assert(foundIt->second && "attribute builder is defined");
Expand All @@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
assert(foundIt->second && "type caster is defined");
Expand All @@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
Expand All @@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
nb::ft_lock_guard lock(mutex);
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
assert(foundIt->second && "dialect class is defined");
Expand All @@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
if (!loadDialectModule(dialectNamespace))
return std::nullopt;

nb::ft_lock_guard lock(mutex);
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
assert(foundIt->second && "OpView is defined");
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class PyMlirContext {
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();

// Interns all live modules associated with this context. Modules tracked
Expand Down
9 changes: 2 additions & 7 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) {
.def_prop_rw("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
.def(
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
},
"module_name"_a)
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
"module_name"_a)
.def(
"_check_dialect_module_loaded",
[](PyGlobals &self, const std::string &dialectNamespace) {
Expand Down Expand Up @@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) {
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);

// Dict-stuff the new opClass by name onto the dialect class.
nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
Expand Down
3 changes: 2 additions & 1 deletion mlir/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
Loading
Loading