diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h index 91c99b1f869f2..02c99b5921882 100644 --- a/mlir/include/mlir-c/Dialect/Transform.h +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type); diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 5e0e56fc00a67..66cf20e1c136f 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -495,6 +495,13 @@ class mlir_type_subclass : public pure_subclass { .attr("replace")(superCls.attr("__name__"), captureTypeName); }); if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( getTypeIDFunction())(pybind11::cpp_function( diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index c7764f4e7aeca..6b57e652aa9d8 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto anyOpType = - mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType); + mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType, + mlirTransformAnyOpTypeGetTypeID); anyOpType.def_classmethod( "get", [](py::object cls, MlirContext ctx) { @@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto anyParamType = - mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType); + mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType, + mlirTransformAnyParamTypeGetTypeID); anyParamType.def_classmethod( "get", [](py::object cls, MlirContext ctx) { @@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto anyValueType = - mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType); + mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType, + mlirTransformAnyValueTypeGetTypeID); anyValueType.def_classmethod( "get", [](py::object cls, MlirContext ctx) { @@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto paramType = - mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType); + mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType, + mlirTransformParamTypeGetTypeID); paramType.def_classmethod( "get", [](py::object cls, MlirType type, MlirContext ctx) { diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp index 3f7f8b8e2113f..5fd773572bd3c 100644 --- a/mlir/lib/CAPI/Dialect/Transform.cpp +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) { + return wrap(transform::AnyOpType::getTypeID()); +} + MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { return wrap(transform::AnyOpType::get(unwrap(ctx))); } @@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) { + return wrap(transform::AnyParamType::getTypeID()); +} + MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) { return wrap(transform::AnyParamType::get(unwrap(ctx))); } @@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) { + return wrap(transform::AnyValueType::getTypeID()); +} + MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) { return wrap(transform::AnyValueType::get(unwrap(ctx))); } @@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { } //===---------------------------------------------------------------------===// -// AnyOpType +// ParamType //===---------------------------------------------------------------------===// bool mlirTypeIsATransformParamType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformParamTypeGetTypeID(void) { + return wrap(transform::ParamType::getTypeID()); +} + MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) { return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type))); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 585918afc2633..41d91cf677833 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings( "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td" ) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.extras + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + GEN_ENUM_BINDINGS + SOURCES + extras/dialects/transform/__init__.py) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/extras/dialects/transform/__init__.py new file mode 100644 index 0000000000000..9e313324318aa --- /dev/null +++ b/mlir/python/mlir/extras/dialects/transform/__init__.py @@ -0,0 +1,148 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from __future__ import annotations +from typing import Callable, Optional, Sequence + +from .... import ir +from ....dialects import transform +from ....dialects.transform import structured + + +class Handle(ir.Value): + """ + Base class for wrappers around different types of transform handle with + methods to chain further transforms. + + The fields `children` and `parent` are used to capture the relation of + handles statically in order to enable further analysis. The payload + operation of a child handle is nested into a region of the payload operation + of the corresponding parent handle. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v) + self.parent = parent + self.children = children if children is not None else [] + + +@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) +@ir.register_value_caster(transform.OperationType.get_static_typeid()) +class OpHandle(Handle): + """ + Wrapper around a transform operation handle with methods to chain further + transforms. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + def match_ops( + self, + ops: str + | ir.OpView + | structured.MatchInterfaceEnum + | Sequence[str | ir.OpView], + ) -> OpHandle: + """ + Emits a `transform.structured.MatchOp`. + Returns a handle to payload ops that match the given names, types, or + interface. If only a single type is given, the value wrapped by the + resulting handle is populated with the respective type. + """ + # Handle interface. + if isinstance(ops, structured.MatchInterfaceEnum) or ( + isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__ + ): + if isinstance(ops, str): + ops = structured.MatchInterfaceEnum[ops] + match_op = structured.MatchOp( + transform.AnyOpType.get(), + self, + interface=ops, + ) + + # Handle op name(s), either given directly as string or given as op. + else: + if isinstance(ops, str): + op_type = transform.OperationType.get(ops) + op_names = [ops] + elif isinstance(ops, Sequence): + op_type = transform.AnyOpType.get() + op_names = [ + op if isinstance(op, str) else op.OPERATION_NAME for op in ops + ] + else: + op_type = transform.OperationType.get(ops.OPERATION_NAME) + op_names = [ops.OPERATION_NAME] + match_op = structured.MatchOp.match_op_names( + op_type, + self, + op_names, + ) + + handle = OpHandle(match_op.results_, parent=self) + self.children.append(handle) + return handle + + +def insert_transform_script( + block_or_insertion_point: ir.Block | ir.InsertionPoint, + script: Callable[[OpHandle], None], + dump_script: bool = False, +) -> None: + """ + Inserts the transform script of the schedule into the module. The script + should accept an instance of OpHandle as argument, which will be called with + the block arg of the newly created named_sequence op. + + Example: + This python code + ``` + module = ir.Module.create() + def test_match_ops_single(module: OpHandle): + module.match_ops(scf.ForOp) + insert_transform_script(module.body, script) + ``` + generates the following IR: + ``` + module { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + ^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["scf.for"]} in %arg0 + : (!transform.any_op) -> !transform.op<"scf.for"> + } + } + ``` + """ + if isinstance(block_or_insertion_point, ir.Block): + context = block_or_insertion_point.owner.context + insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point) + else: + context = block_or_insertion_point.block.owner.context + insertion_point = block_or_insertion_point + + with context, ir.Location.unknown(context): + with insertion_point: + named_sequence_op = transform.NamedSequenceOp( + "__transform_main", [transform.AnyOpType.get()], [] + ) + with ir.InsertionPoint(named_sequence_op.body): + script(named_sequence_op.bodyTarget) + transform.YieldOp([]) + + if dump_script: + print(named_sequence_op) diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py new file mode 100644 index 0000000000000..dbfa8a2dc73c4 --- /dev/null +++ b/mlir/test/python/dialects/transform_extras.py @@ -0,0 +1,95 @@ +# RUN: %PYTHON %s | FileCheck %s + +from typing import Callable +from mlir import ir +from mlir.dialects import scf +from mlir.dialects.transform import structured +from mlir.extras.dialects.transform import OpHandle, insert_transform_script + + +def build_transform_script(script: Callable[[OpHandle], None]): + print("\nTEST:", script.__name__) + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + insert_transform_script(module.body, script=script, dump_script=True) + module.operation.verify() + + +def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]): + print("\nTEST:", script.__name__) + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + insert_transform_script( + ir.InsertionPoint.at_block_begin(module.body), + script=script, + dump_script=True, + ) + module.operation.verify() + + +# CHECK-LABEL: TEST: test_build_script_at_insertion_point +@build_transform_script_at_insertion_point +def test_build_script_at_insertion_point(op: OpHandle): + pass + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: transform.yield + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_match_ops_single +@build_transform_script +def test_match_ops_single(op: OpHandle): + op.match_ops(scf.ForOp) + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[VAL_0]] + # CHECK-SAME: -> !transform.op<"scf.for"> + + +# CHECK-LABEL: TEST: test_match_ops_string_name +@build_transform_script +def test_match_ops_string_name(op: OpHandle): + op.match_ops("linalg.matmul") + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_match_ops_string_iface +@build_transform_script +def test_match_ops_string_iface(op: OpHandle): + op.match_ops("LinalgOp") + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_match_ops_iface +@build_transform_script +def test_match_ops_iface(op: OpHandle): + op.match_ops(structured.MatchInterfaceEnum.LinalgOp) + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_match_ops_multiple +@build_transform_script +def test_match_ops_multiple(op: OpHandle): + op.match_ops([scf.ForOp, scf.ForallOp]) + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]] + # CHECK-SAME: -> !transform.any_op + + +# CHECK-LABEL: TEST: test_match_ops_mixed +@build_transform_script +def test_match_ops_mixed(op: OpHandle): + op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp]) + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]] + # CHECK-SAME: -> !transform.any_op