-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][transform][python] add sugared python abstractions for transform dialect #75073
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: None (martin-luecke) ChangesThis adds Python abstractions for the different handle types of the transform dialect The abstractions allow for straightforward chaining of transforms by calling their member functions. def script(module: OpHandle):
module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32]) to generate the following IR: %0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32] These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains. Full diff: https://github.com/llvm/llvm-project/pull/75073.diff 3 Files Affected:
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc2633..8013b49dbf9d6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
dialects/rocdl.py
DIALECT_NAME rocdl)
+declare_mlir_python_sources(
+ MLIRPythonSources.Dialects.transform.extras
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ GEN_ENUM_BINDINGS
+ SOURCES
+ dialects/transform/extras/__init__.py)
+
declare_mlir_python_sources(
MLIRPythonSources.Dialects.quant
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 0000000000000..9f1f752bd7dba
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence
+
+try:
+ from .... import ir
+ from ....dialects import transform
+ from ....dialects.transform import structured
+except ImportError as e:
+ raise RuntimeError("Error loading imports") from e
+
+
+@dataclass
+class Value(abc.ABC):
+ """Wrapper around a transform value handle with methods to chain further transforms."""
+
+ _mlir_value: ir.Value
+ children: list[Value] = field(default_factory=list)
+ parent: Optional[Value] = None
+
+ @property
+ def mlir_value(self) -> ir.Value:
+ return self._mlir_value
+
+
+@dataclass
+class Param(Value):
+ """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+@dataclass
+class OpHandle(Value):
+ """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+ def match_ops(
+ self,
+ ops: str
+ | ir.OpView
+ | structured.MatchInterfaceEnum
+ | Sequence[str | ir.OpView],
+ ) -> OpHandle:
+ """
+ Returns a handle to ops that match the given names, types, or interface.
+ If only a single type is given, the value wrapped by the resulting
+ handle is populated with the respective type.
+ """
+ # Handle interface.
+ if isinstance(ops, structured.MatchInterfaceEnum) or (
+ isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+ ):
+ if isinstance(ops, str):
+ ops = structured.MatchInterfaceEnum[ops]
+ match_op = structured.MatchOp(
+ transform.AnyOpType.get(),
+ self.mlir_value,
+ interface=ops,
+ )
+
+ # Handle op name(s), either given directly as string or given as op.
+ else:
+ if isinstance(ops, str):
+ op_type = transform.OperationType.get(ops)
+ op_names = [ops]
+ elif isinstance(ops, Sequence):
+ op_type = transform.AnyOpType.get()
+ op_names = [
+ op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+ ]
+ else:
+ op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_names = [ops.OPERATION_NAME]
+ match_op = structured.MatchOp.match_op_names(
+ op_type,
+ self.mlir_value,
+ op_names,
+ )
+
+ handle = OpHandle(match_op.results_, parent=self)
+ self.children.append(handle)
+ return handle
+
+
+def insert_transform_script(
+ module: ir.Module,
+ script: Callable[[OpHandle], None],
+ dump_script: bool = False,
+) -> None:
+ """
+ Inserts the transform script of the schedule into the module. The script
+ should accept an instance of OpHandle as argument, which will be called with
+ the block arg of the newly created sequence op.
+
+ Example:
+ This python code
+ ```
+ module = ir.Module.create()
+ def test_match_ops_single(module: OpHandle):
+ module.match_ops(scf.ForOp)
+ insert_transform_script(module, script)
+ ```
+ generates the following IR:
+ ```
+ module {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+ }
+ }
+ ```
+ """
+
+ with module.context, ir.Location.unknown(module.context):
+ with ir.InsertionPoint.at_block_begin(module.body):
+ sequence_op = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ (),
+ transform.AnyOpType.get(),
+ )
+ with ir.InsertionPoint(sequence_op.body):
+ script(OpHandle(sequence_op.bodyTarget))
+ transform.YieldOp([])
+
+ if dump_script:
+ print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 0000000000000..08d853d8c2bc7
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,78 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+
+
+def build_transform_script(script: Callable[[OpHandle], None]):
+ print("\nTEST:", script.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ insert_transform_script(module, script=script, dump_script=True)
+ module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+@build_transform_script
+def test_match_ops_single(op: OpHandle):
+ op.match_ops(scf.ForOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+ # CHECK-SAME: in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+@build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+ op.match_ops("linalg.matmul")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+@build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+ op.match_ops("LinalgOp")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+@build_transform_script
+def test_match_ops_iface(op: OpHandle):
+ op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+@build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+ op.match_ops([scf.ForOp, scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+@build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+ op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is consistent with other recently added usability extras. LGTM when transitioned to using named sequences. Please give Maks a change to review.
I asked for changes but I'm not opinionated on them - I think the changes I proposed are good but maybe you guys have already considered and discounted them - so if you guys don't agree I don't have any problem approving. |
9bce456
to
7ee6bb0
Compare
- expose `get_static_typeid` for Python mlir type subclasses - expose typeid getters for transform dialect types to CAPI - sequence -> named_sequence - more general `insert_transform_script` - use `value_caster` automation instead of composition - clarifying documentation - cleaning up imports - add missing license header
isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__ | ||
): | ||
if isinstance(ops, str): | ||
ops = structured.MatchInterfaceEnum[ops] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm makes me think we should change the emitted enums to handle strings as well like this (not in this PR obv).
ops: str | ||
| ir.OpView | ||
| structured.MatchInterfaceEnum | ||
| Sequence[str | ir.OpView], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This breaks compat with < 3.10. Personally I'd be happy to bump to 3.10 but prior we've tried to maintain down to 3.8 (I guess buildbots have upgraded because I've failed tests before with this).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The overall guidance is to have LLVM build on live LTS repos so it can be distributed. Ubuntu 20.04 is one of these, and it has python 3.8.2. Debian is surprisingly fresher, and I'm not aware what RHEL-based distros folks run these days. But sticking to 3.8 sounds reasonable. I don't know if we can import from future for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The future import is available since 3.7 and the from __future__ import annotations
import is exactly why this does not break the build in Python 3.8.
With this import the annotations here are not evaluated by the Python interpreter at module import time, but only when e.g. typing.get_type_hints
is called on it (which we do not). The annotations are still understood by tools like MyPy, see here.
This should not be an issue for us, but I can still move this over to the Union
syntax if you think this prudent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't know this! Definitely keep the pipes!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our policy to date has been to track the PSF's eol policy, not specific LTS. In any case, it is a distinction without a difference in this case: we support 3.8 until it is eol in Oct-2024.
For the older Python versions, I am fine with conditional coding or some level of heroic if the feature is very useful, but in most cases I see, it is just easier to write it in an older-version compatible way and not deal with the mental burden of conditional stuff.
Feel free to plus me in on anything that needs untangling in this vein. Happy to offer advice.
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from __future__ import annotations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Taking this from another patch where I saw this being added: I would very much prefer that transform dialect stuff be contained under its own dialect namespace, not in a parallel dialect tree. If it is experimental and needs more bake time, note that in the name or something, but let's not sprawl dialect specific things into other parts of the codebase.
This adds Python abstractions for the different handle types of the transform dialect
The abstractions allow for straightforward chaining of transforms by calling their member functions.
As an initial PR for this infrastructure, only a single transform is included:
transform.structured.match
.With a future
tile
transform abstraction an example of the usage is:to generate the following IR:
These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains.