Skip to content

[MLIR][transform][python] add sugared python abstractions for transform dialect #75073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 15, 2023

Conversation

martin-luecke
Copy link
Contributor

This adds Python abstractions for the different handle types of the transform dialect

The abstractions allow for straightforward chaining of transforms by calling their member functions.
As an initial PR for this infrastructure, only a single transform is included: transform.structured.match.
With a future tile transform abstraction an example of the usage is:

def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])

to generate the following IR:

%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]

These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains.

@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Dec 11, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2023

@llvm/pr-subscribers-mlir

Author: None (martin-luecke)

Changes

This adds Python abstractions for the different handle types of the transform dialect

The abstractions allow for straightforward chaining of transforms by calling their member functions.
As an initial PR for this infrastructure, only a single transform is included: transform.structured.match.
With a future tile transform abstraction an example of the usage is:

def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])

to generate the following IR:

%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]

These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains.


Full diff: https://github.com/llvm/llvm-project/pull/75073.diff

3 Files Affected:

  • (modified) mlir/python/CMakeLists.txt (+8)
  • (added) mlir/python/mlir/dialects/transform/extras/init.py (+126)
  • (added) mlir/test/python/dialects/transform_extras.py (+78)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc2633..8013b49dbf9d6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
     dialects/rocdl.py
   DIALECT_NAME rocdl)
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS
+  SOURCES
+    dialects/transform/extras/__init__.py)
+
 declare_mlir_python_sources(
   MLIRPythonSources.Dialects.quant
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 0000000000000..9f1f752bd7dba
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence
+
+try:
+    from .... import ir
+    from ....dialects import transform
+    from ....dialects.transform import structured
+except ImportError as e:
+    raise RuntimeError("Error loading imports") from e
+
+
+@dataclass
+class Value(abc.ABC):
+    """Wrapper around a transform value handle with methods to chain further transforms."""
+
+    _mlir_value: ir.Value
+    children: list[Value] = field(default_factory=list)
+    parent: Optional[Value] = None
+
+    @property
+    def mlir_value(self) -> ir.Value:
+        return self._mlir_value
+
+
+@dataclass
+class Param(Value):
+    """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+@dataclass
+class OpHandle(Value):
+    """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """
+        Returns a handle to ops that match the given names, types, or interface.
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self.mlir_value,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self.mlir_value,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+def insert_transform_script(
+    module: ir.Module,
+    script: Callable[[OpHandle], None],
+    dump_script: bool = False,
+) -> None:
+    """
+    Inserts the transform script of the schedule into the module. The script
+    should accept an instance of OpHandle as argument, which will be called with
+    the block arg of the newly created sequence op.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def test_match_ops_single(module: OpHandle):
+        module.match_ops(scf.ForOp)
+    insert_transform_script(module, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.sequence failures(propagate) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
+
+    with module.context, ir.Location.unknown(module.context):
+        with ir.InsertionPoint.at_block_begin(module.body):
+            sequence_op = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                (),
+                transform.AnyOpType.get(),
+            )
+        with ir.InsertionPoint(sequence_op.body):
+            script(OpHandle(sequence_op.bodyTarget))
+            transform.YieldOp([])
+
+    if dump_script:
+        print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 0000000000000..08d853d8c2bc7
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,78 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+
+
+def build_transform_script(script: Callable[[OpHandle], None]):
+    print("\nTEST:", script.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        insert_transform_script(module, script=script, dump_script=True)
+        module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+@build_transform_script
+def test_match_ops_single(op: OpHandle):
+    op.match_ops(scf.ForOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+    # CHECK-SAME:    in %[[VAL_0]]
+    # CHECK-SAME:      -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+@build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+    op.match_ops("linalg.matmul")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+@build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+    op.match_ops("LinalgOp")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+@build_transform_script
+def test_match_ops_iface(op: OpHandle):
+    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+@build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+    op.match_ops([scf.ForOp, scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+@build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is consistent with other recently added usability extras. LGTM when transitioned to using named sequences. Please give Maks a change to review.

@makslevental
Copy link
Contributor

makslevental commented Dec 11, 2023

@ftynse @martin-luecke

I asked for changes but I'm not opinionated on them - I think the changes I proposed are good but maybe you guys have already considered and discounted them - so if you guys don't agree I don't have any problem approving.

- expose `get_static_typeid` for Python mlir type subclasses
- expose typeid getters for transform dialect types to CAPI
- sequence -> named_sequence
- more general `insert_transform_script`
- use `value_caster` automation instead of composition
- clarifying documentation
- cleaning up imports
- add missing license header
isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
):
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm makes me think we should change the emitted enums to handle strings as well like this (not in this PR obv).

Comment on lines +55 to +58
ops: str
| ir.OpView
| structured.MatchInterfaceEnum
| Sequence[str | ir.OpView],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks compat with < 3.10. Personally I'd be happy to bump to 3.10 but prior we've tried to maintain down to 3.8 (I guess buildbots have upgraded because I've failed tests before with this).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overall guidance is to have LLVM build on live LTS repos so it can be distributed. Ubuntu 20.04 is one of these, and it has python 3.8.2. Debian is surprisingly fresher, and I'm not aware what RHEL-based distros folks run these days. But sticking to 3.8 sounds reasonable. I don't know if we can import from future for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The future import is available since 3.7 and the from __future__ import annotations import is exactly why this does not break the build in Python 3.8.
With this import the annotations here are not evaluated by the Python interpreter at module import time, but only when e.g. typing.get_type_hints is called on it (which we do not). The annotations are still understood by tools like MyPy, see here.
This should not be an issue for us, but I can still move this over to the Union syntax if you think this prudent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know this! Definitely keep the pipes!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our policy to date has been to track the PSF's eol policy, not specific LTS. In any case, it is a distinction without a difference in this case: we support 3.8 until it is eol in Oct-2024.

For the older Python versions, I am fine with conditional coding or some level of heroic if the feature is very useful, but in most cases I see, it is just easier to write it in an older-version compatible way and not deal with the mental burden of conditional stuff.

Feel free to plus me in on anything that needs untangling in this vein. Happy to offer advice.

@martin-luecke martin-luecke merged commit 681eacc into llvm:main Dec 15, 2023
@martin-luecke martin-luecke deleted the transform_py_extras branch December 15, 2023 12:48
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking this from another patch where I saw this being added: I would very much prefer that transform dialect stuff be contained under its own dialect namespace, not in a parallel dialect tree. If it is experimental and needs more bake time, note that in the name or something, but let's not sprawl dialect specific things into other parts of the codebase.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants