diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index f79c10cb93838..0a36e97c2ae68 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -118,13 +118,28 @@ /** Attribute on main C extension module (_mlir) that corresponds to the * type caster registration binding. The signature of the function is: - * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster, - * bool replace) - * where replace indicates the typeCaster should replace any existing registered - * type casters (such as those for upstream ConcreteTypes). + * def register_type_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a typeCaster (register_type_caster is meant to be used as a + * decorator from python), and where replace indicates the typeCaster should + * replace any existing registered type casters (such as those for upstream + * ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type) + * -> SubClassTypeT where SubClassTypeT indicates the result should be a + * subclass (inherit from) ir.Type. */ #define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster" +/** Attribute on main C extension module (_mlir) that corresponds to the + * value caster registration binding. The signature of the function is: + * def register_value_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a valueCaster (register_value_caster is meant to be used as + * a decorator, from python), and where replace indicates the valueCaster should + * replace any existing registered value casters. The interface of the + * valueCaster is: def value_caster(ir.Value) -> SubClassValueT where + * SubClassValueT indicates the result should be a subclass (inherit from) + * ir.Value. + */ +#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster" + /// Gets a void* from a wrapped struct. Needed because const cast is different /// between C/C++. #ifdef __cplusplus diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 49680c8b79b13..5e0e56fc00a67 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -234,6 +234,7 @@ struct type_caster { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Value") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); }; }; @@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass { if (getTypeIDFunction) { py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction(), - pybind11::cpp_function( - [thisClass = thisClass](const py::object &mlirType) { - return thisClass(mlirType); - })); + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirType) { + return thisClass(mlirType); + })); } } }; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 976297257ced0..a022067f5c7e5 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -66,6 +66,13 @@ class PyGlobals { void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, bool replace = false); + /// Adds a user-friendly value caster. Raises an exception if the mapping + /// already exists and replace == false. This is intended to be called by + /// implementation code. + void registerValueCaster(MlirTypeID mlirTypeID, + pybind11::function valueCaster, + bool replace = false); + /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. @@ -86,6 +93,10 @@ class PyGlobals { std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); + /// Returns the custom value caster for MlirTypeID mlirTypeID. + std::optional lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. std::optional @@ -109,7 +120,8 @@ class PyGlobals { llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. llvm::DenseMap typeCasterMap; - + /// Map of MlirTypeID to custom value caster. + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7cfea31dbb2e8..0f2ca666ccc05 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const { } //------------------------------------------------------------------------------ -// PyValue and subclases. +// PyValue and subclasses. //------------------------------------------------------------------------------ pybind11::object PyValue::getCapsule() { return py::reinterpret_steal(mlirPythonValueToCapsule(get())); } +pybind11::object PyValue::maybeDownCast() { + MlirType type = mlirValueGetType(get()); + MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional valueCaster = + PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); + // py::return_value_policy::move means use std::move to move the return value + // contents into a new instance that will be owned by Python. + py::object thisObj = py::cast(this, py::return_value_policy::move); + if (!valueCaster) + return thisObj; + return valueCaster.value()(thisObj); +} + PyValue PyValue::createFromCapsule(pybind11::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) @@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue { return DerivedTy::isaFunction(otherValue); }, py::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); } @@ -2193,6 +2210,7 @@ class PyBlockArgumentList : public Sliceable { public: static constexpr const char *pyClassName = "BlockArgumentList"; + using SliceableT = Sliceable; PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex = 0, intptr_t length = -1, @@ -2241,6 +2259,7 @@ class PyBlockArgumentList class PyOpOperandList : public Sliceable { public: static constexpr const char *pyClassName = "OpOperandList"; + using SliceableT = Sliceable; PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) @@ -2296,6 +2315,7 @@ class PyOpOperandList : public Sliceable { class PyOpResultList : public Sliceable { public: static constexpr const char *pyClassName = "OpResultList"; + using SliceableT = Sliceable; PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) @@ -2303,7 +2323,7 @@ class PyOpResultList : public Sliceable { length == -1 ? mlirOperationGetNumResults(operation->get()) : length, step), - operation(operation) {} + operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { c.def_property_readonly("types", [](PyOpResultList &self) { @@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) { .str()); } return PyOpResult(operation.getRef(), - mlirOperationGetResult(operation, 0)); + mlirOperationGetResult(operation, 0)) + .maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") @@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) { [](PyValue &self, PyValue &with) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, - kValueReplaceAllUsesWithDocstring); + kValueReplaceAllUsesWithDocstring) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) { return self.maybeDownCast(); }); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 6c5cde86236ce..5538924d24818 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, found = std::move(typeCaster); } +void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, + pybind11::function valueCaster, + bool replace) { + pybind11::object &found = valueCasterMap[mlirTypeID]; + if (found && !replace) + throw std::runtime_error("Value caster is already registered: " + + py::repr(found).cast()); + found = std::move(valueCaster); +} + void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -134,6 +144,17 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect) { + loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + const auto foundIt = valueCasterMap.find(mlirTypeID); + if (foundIt != valueCasterMap.end()) { + assert(foundIt->second && "value caster is defined"); + return foundIt->second; + } + return std::nullopt; +} + std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 01ee4975d0e9a..af55693f18fbb 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -761,7 +761,7 @@ class PyRegion { /// Wrapper around an MlirAsmState. class PyAsmState { - public: +public: PyAsmState(MlirValue value, bool useLocalScope) { flags = mlirOpPrintingFlagsCreate(); // The OpPrintingFlags are not exposed Python side, create locally and @@ -780,16 +780,14 @@ class PyAsmState { state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); } - ~PyAsmState() { - mlirOpPrintingFlagsDestroy(flags); - } + ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } // Delete copy constructors. PyAsmState(PyAsmState &other) = delete; PyAsmState(const PyAsmState &other) = delete; MlirAsmState get() { return state; } - private: +private: MlirAsmState state; MlirOpPrintingFlags flags; }; @@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy { /// bindings so such operation always exists). class PyValue { public: + // The virtual here is "load bearing" in that it enables RTTI + // for PyConcreteValue CRTP classes that support maybeDownCast. + // See PyValue::maybeDownCast. + virtual ~PyValue() = default; PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(std::move(parentOperation)), value(value) {} operator MlirValue() const { return value; } @@ -1124,6 +1126,8 @@ class PyValue { /// Gets a capsule wrapping the void* within the MlirValue. pybind11::object getCapsule(); + pybind11::object maybeDownCast(); + /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. static PyValue createFromCapsule(pybind11::object capsule); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 2ba3a3677198c..17272472ccca4 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -12,8 +12,6 @@ #include "IRModule.h" #include "Pass.h" -#include - namespace py = pybind11; using namespace mlir; using namespace py::literals; @@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, "replace"_a = false, + "operation_name"_a, "operation_class"_a, py::kw_only(), + "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - "dialect_class"_a, "replace"_a = false, + "dialect_class"_a, py::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) { - PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster), - replace); + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function([mlirTypeID, + replace](py::object typeCaster) -> py::object { + PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); + return typeCaster; + }); }, - "typeid"_a, "type_caster"_a, "replace"_a = false, + "typeid"_a, py::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); + m.def( + MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function( + [mlirTypeID, replace](py::object valueCaster) -> py::object { + PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, + replace); + return valueCaster; + }); + }, + "typeid"_a, py::kw_only(), "replace"_a = false, + "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 2a8da20bee049..38462ac8ba6db 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -10,6 +10,7 @@ #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #include "mlir-c/Support.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" @@ -228,6 +229,11 @@ class Sliceable { return linearIndex; } + /// Trait to check if T provides a `maybeDownCast` method. + /// Note, you need the & to detect inherited members. + template + using has_maybe_downcast = decltype(&T::maybeDownCast); + /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. @@ -239,8 +245,13 @@ class Sliceable { return {}; } - return pybind11::cast( - static_cast(this)->getRawElement(linearizeIndex(index))); + if constexpr (llvm::is_detected::value) + return static_cast(this) + ->getRawElement(linearizeIndex(index)) + .maybeDownCast(); + else + return pybind11::cast( + static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 9cca7d659ec8c..60ce83c09f171 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -5,7 +5,12 @@ # Provide a convenient name for sub-packages to resolve the main C-extension # with a relative import. from .._mlir_libs import _mlir as _cext -from typing import Sequence as _Sequence, Union as _Union +from typing import ( + Sequence as _Sequence, + Type as _Type, + TypeVar as _TypeVar, + Union as _Union, +) __all__ = [ "equally_sized_accessor", @@ -123,3 +128,9 @@ def get_op_result_or_op_results( if len(op.results) > 0 else op ) + + +# This is the standard way to indicate subclass/inheritance relationship +# see the typing.Type doc string. +_U = _TypeVar("_U", bound=_cext.ir.Value) +SubClassValueT = _Type[_U] diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index cf4228c2a63a9..18526ab8c3c02 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,7 +4,7 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster +from ._mlir_libs._mlir import register_type_caster, register_value_caster # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 63dad1cc901fe..f7df8ba2df0ae 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> { } // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)) \ No newline at end of file +// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)) diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index 6d1c5eab75898..f80f2c084a0f3 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -1,8 +1,9 @@ # RUN: %PYTHON %s | FileCheck %s +from functools import partialmethod from mlir.ir import * -import mlir.dialects.func as func import mlir.dialects.arith as arith +import mlir.dialects.func as func def run(f): @@ -35,14 +36,59 @@ def testFastMathFlags(): print(r) -# CHECK-LABEL: TEST: testArithValueBuilder +# CHECK-LABEL: TEST: testArithValue @run -def testArithValueBuilder(): +def testArithValue(): + def _binary_op(lhs, rhs, op: str) -> "ArithValue": + op = op.capitalize() + if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type): + op += "F" + elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type( + lhs.type + ): + op += "I" + else: + raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}") + + op = getattr(arith, f"{op}Op") + return op(lhs, rhs).result + + @register_value_caster(F16Type.static_typeid) + @register_value_caster(F32Type.static_typeid) + @register_value_caster(F64Type.static_typeid) + @register_value_caster(IntegerType.static_typeid) + class ArithValue(Value): + def __init__(self, v): + super().__init__(v) + + __add__ = partialmethod(_binary_op, op="add") + __sub__ = partialmethod(_binary_op, op="sub") + __mul__ = partialmethod(_binary_op, op="mul") + + def __str__(self): + return super().__str__().replace(Value.__name__, ArithValue.__name__) + with Context() as ctx, Location.unknown(): module = Module.create() + f16_t = F16Type.get() f32_t = F32Type.get() + f64_t = F64Type.get() with InsertionPoint(module.body): - a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) - # CHECK: %cst = arith.constant 4.242000e+01 : f32 + a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) + # CHECK: ArithValue(%cst = arith.constant 4.240 print(a) + + b = a + a + # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16) + print(b) + + a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) + b = a - a + # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32) + print(b) + + a = arith.constant(value=FloatAttr.get(f64_t, 42.42)) + b = a * a + # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) + print(b) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 472db7e5124db..f313a400b73c0 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -425,6 +425,12 @@ def __str__(self): # And it should be equal to the in-tree concrete type assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid + d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result + # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>) + print(d) + # CHECK: TestTensorValue + print(repr(d)) + # CHECK-LABEL: TEST: inferReturnTypeComponents @run @@ -502,19 +508,18 @@ def testCustomTypeTypeCaster(): # CHECK: Type caster is already registered try: + @register_type_caster(c.typeid) def type_caster(pytype): return test.TestIntegerRankedTensorType(pytype) - register_type_caster(c.typeid, type_caster) except RuntimeError as e: print(e) - def type_caster(pytype): - return RankedTensorType(pytype) - # python_test dialect registers a caster for RankedTensorType in its extension (pybind) module. # So this one replaces that one (successfully). And then just to be sure we restore the original caster below. - register_type_caster(c.typeid, type_caster, replace=True) + @register_type_caster(c.typeid, replace=True) + def type_caster(pytype): + return RankedTensorType(pytype) d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result # CHECK: tensor<10x10xi5> @@ -522,11 +527,10 @@ def type_caster(pytype): # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>) print("ranked tensor type", repr(d.type)) + @register_type_caster(c.typeid, replace=True) def type_caster(pytype): return test.TestIntegerRankedTensorType(pytype) - register_type_caster(c.typeid, type_caster, replace=True) - d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result # CHECK: tensor<10x10xi5> print(d.type) diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index ddf653dcce278..acbf463113a6d 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -3,6 +3,7 @@ import gc from mlir.ir import * from mlir.dialects import func +from mlir.dialects._ods_common import SubClassValueT def run(f): @@ -270,3 +271,120 @@ def testValueSetType(): # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64 print(value.owner) + + +# CHECK-LABEL: TEST: testValueCasters +@run +def testValueCasters(): + class NOPResult(OpResult): + def __init__(self, v): + super().__init__(v) + + def __str__(self): + return super().__str__().replace(Value.__name__, NOPResult.__name__) + + class NOPValue(Value): + def __init__(self, v): + super().__init__(v) + + def __str__(self): + return super().__str__().replace(Value.__name__, NOPValue.__name__) + + class NOPBlockArg(BlockArgument): + def __init__(self, v): + super().__init__(v) + + def __str__(self): + return super().__str__().replace(Value.__name__, NOPBlockArg.__name__) + + @register_value_caster(IntegerType.static_typeid) + def cast_int(v) -> SubClassValueT: + print("in caster", v.__class__.__name__) + if isinstance(v, OpResult): + return NOPResult(v) + if isinstance(v, BlockArgument): + return NOPBlockArg(v) + elif isinstance(v, Value): + return NOPValue(v) + + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + values = Operation.create("custom.op1", results=[i32, i32]).results + # CHECK: in caster OpResult + # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", values[0].result_number, values[0]) + # CHECK: in caster OpResult + # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", values[1].result_number, values[1]) + + # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("results slice", values[:1][0].result_number, values[:1][0]) + + value0, value1 = values + # CHECK: in caster OpResult + # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", value0.result_number, values[0]) + # CHECK: in caster OpResult + # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", value1.result_number, values[1]) + + op1 = Operation.create("custom.op2", operands=[value0, value1]) + # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> () + print(op1) + + # CHECK: in caster Value + # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("operand 0", op1.operands[0]) + # CHECK: in caster Value + # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("operand 1", op1.operands[1]) + + # CHECK: in caster BlockArgument + # CHECK: in caster BlockArgument + @func.FuncOp.from_py_func(i32, i32) + def reduction(arg0, arg1): + # CHECK: as func arg 0 NOPBlockArg + print("as func arg", arg0.arg_number, arg0.__class__.__name__) + # CHECK: as func arg 1 NOPBlockArg + print("as func arg", arg1.arg_number, arg1.__class__.__name__) + + # CHECK: args slice 0 NOPBlockArg( of type 'i32' at index: 0) + print( + "args slice", + reduction.func_op.arguments[:1][0].arg_number, + reduction.func_op.arguments[:1][0], + ) + + try: + + @register_value_caster(IntegerType.static_typeid) + def dont_cast_int_shouldnt_register(v): + ... + + except RuntimeError as e: + # CHECK: Value caster is already registered: {{.*}}cast_int + print(e) + + @register_value_caster(IntegerType.static_typeid, replace=True) + def dont_cast_int(v) -> OpResult: + assert isinstance(v, OpResult) + print("don't cast", v.result_number, v) + return v + + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32) + new_value = Operation.create("custom.op1", results=[i32]).result + # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32) + print("result", new_value.result_number, new_value) + + # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32) + new_value = Operation.create("custom.op2", results=[i32]).results[0] + # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32) + print("result", new_value.result_number, new_value) diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index f533082a0a147..aff414894cb82 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -42,6 +42,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) { return cls(mlirPythonTestTestAttributeGet(ctx)); }, py::arg("cls"), py::arg("context") = py::none()); + mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, mlirPythonTestTestTypeGetTypeID) .def_classmethod( @@ -50,7 +51,8 @@ PYBIND11_MODULE(_mlirPythonTest, m) { return cls(mlirPythonTestTestTypeGet(ctx)); }, py::arg("cls"), py::arg("context") = py::none()); - auto cls = + + auto typeCls = mlir_type_subclass(m, "TestIntegerRankedTensorType", mlirTypeIsARankedIntegerTensor, py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) @@ -65,16 +67,40 @@ PYBIND11_MODULE(_mlirPythonTest, m) { encoding)); }, "cls"_a, "shape"_a, "width"_a, "context"_a = py::none()); - assert(py::hasattr(cls.get_class(), "static_typeid") && + + assert(py::hasattr(typeCls.get_class(), "static_typeid") && "TestIntegerRankedTensorType has no static_typeid"); - MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID(); + + MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); + + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID, + "replace"_a = true)( + pybind11::cpp_function([typeCls](const py::object &mlirType) { + return typeCls.get_class()(mlirType); + })); + + auto valueCls = mlir_value_subclass(m, "TestTensorValue", + mlirTypeIsAPythonTestTestTensorValue) + .def("is_null", [](MlirValue &self) { + return mlirValueIsNull(self); + }); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) { - return cls.get_class()(mlirType); - }), - /*replace=*/true); - mlir_value_subclass(m, "TestTensorValue", - mlirTypeIsAPythonTestTestTensorValue) - .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); }); + .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( + mlirRankedTensorTypeID)( + pybind11::cpp_function([valueCls](const py::object &valueObj) { + py::object capsule = mlirApiObjectToCapsule(valueObj); + MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); + MlirType t = mlirValueGetType(v); + // This is hyper-specific in order to exercise/test registering a + // value caster from cpp (but only for a single test case; see + // testTensorValue python_test.py). + if (mlirShapedTypeHasStaticShape(t) && + mlirShapedTypeGetDimSize(t, 0) == 1 && + mlirShapedTypeGetDimSize(t, 1) == 2 && + mlirShapedTypeGetDimSize(t, 2) == 3) + return valueCls.get_class()(valueObj); + return valueObj; + })); } diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index c8ef84721090a..0c0ad2cfeffdc 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -30,7 +30,15 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from ._ods_common import _cext as _ods_cext -from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results +from ._ods_common import ( + SubClassValueT as _SubClassValueT, + equally_sized_accessor as _ods_equally_sized_accessor, + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_op_results as _get_op_result_or_op_results, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + segmented_accessor as _ods_segmented_accessor, +) _ods_ir = _ods_cext.ir import builtins @@ -1004,8 +1012,8 @@ static void emitValueBuilder(const Operator &op, llvm::join(valueBuilderParams, ", "), llvm::join(opBuilderArgs, ", "), (op.getNumResults() > 1 - ? "_Sequence[_ods_ir.OpResult]" - : (op.getNumResults() > 0 ? "_ods_ir.OpResult" + ? "_Sequence[_SubClassValueT]" + : (op.getNumResults() > 0 ? "_SubClassValueT" : "_ods_ir.Operation"))); }