Skip to content

[mlir][python] move transform extras #76102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 20, 2023

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Dec 20, 2023

Addresses #75073 (comment).

@makslevental makslevental requested review from martin-luecke, stellaraccident and ftynse and removed request for martin-luecke December 20, 2023 21:30
@makslevental makslevental marked this pull request as ready for review December 20, 2023 21:31
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Dec 20, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Addresses #75073 (comment).


Full diff: https://github.com/llvm/llvm-project/pull/76102.diff

4 Files Affected:

  • (modified) mlir/python/CMakeLists.txt (+1-1)
  • (modified) mlir/python/mlir/dialects/transform/init.py (+1)
  • (renamed) mlir/python/mlir/dialects/transform/extras/init.py (+22-21)
  • (modified) mlir/test/python/dialects/transform_extras.py (+1-1)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   GEN_ENUM_BINDINGS
   SOURCES
-    extras/dialects/transform/__init__.py)
+    dialects/transform/extras/__init__.py)
 
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
 from .._transform_ops_gen import *
 from .._transform_ops_gen import _Dialect
 from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
 
 try:
     from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 80%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..c715dac1ef7eb8 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,12 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from __future__ import annotations
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
 
 from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
 
 
 class Handle(ir.Value):
@@ -25,16 +24,16 @@ def __init__(
         self,
         v: ir.Value,
         *,
-        parent: Optional[Handle] = None,
-        children: Optional[Sequence[Handle]] = None,
+        parent: Optional["Handle"] = None,
+        children: Optional[Sequence["Handle"]] = None,
     ):
         super().__init__(v)
         self.parent = parent
         self.children = children if children is not None else []
 
 
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
 class OpHandle(Handle):
     """
     Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
 
     def match_ops(
         self,
-        ops: str
-        | ir.OpView
-        | structured.MatchInterfaceEnum
-        | Sequence[str | ir.OpView],
-    ) -> OpHandle:
+        ops: Union[
+            str,
+            ir.OpView,
+            structured.MatchInterfaceEnum,
+            Sequence[Union[str, ir.OpView]],
+        ],
+    ) -> "OpHandle":
         """
         Emits a `transform.structured.MatchOp`.
         Returns a handle to payload ops that match the given names, types, or
@@ -70,7 +71,7 @@ def match_ops(
             if isinstance(ops, str):
                 ops = structured.MatchInterfaceEnum[ops]
             match_op = structured.MatchOp(
-                transform.AnyOpType.get(),
+                AnyOpType.get(),
                 self,
                 interface=ops,
             )
@@ -78,15 +79,15 @@ def match_ops(
         # Handle op name(s), either given directly as string or given as op.
         else:
             if isinstance(ops, str):
-                op_type = transform.OperationType.get(ops)
+                op_type = OperationType.get(ops)
                 op_names = [ops]
             elif isinstance(ops, Sequence):
-                op_type = transform.AnyOpType.get()
+                op_type = AnyOpType.get()
                 op_names = [
                     op if isinstance(op, str) else op.OPERATION_NAME for op in ops
                 ]
             else:
-                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_type = OperationType.get(ops.OPERATION_NAME)
                 op_names = [ops.OPERATION_NAME]
             match_op = structured.MatchOp.match_op_names(
                 op_type,
@@ -100,7 +101,7 @@ def match_ops(
 
 
 def insert_transform_script(
-    block_or_insertion_point: ir.Block | ir.InsertionPoint,
+    block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
     script: Callable[[OpHandle], None],
     dump_script: bool = False,
 ) -> None:
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
 
     with context, ir.Location.unknown(context):
         with insertion_point:
-            named_sequence_op = transform.NamedSequenceOp(
-                "__transform_main", [transform.AnyOpType.get()], []
+            named_sequence_op = NamedSequenceOp(
+                "__transform_main", [AnyOpType.get()], []
             )
         with ir.InsertionPoint(named_sequence_op.body):
             script(named_sequence_op.bodyTarget)
-            transform.YieldOp([])
+            YieldOp([])
 
     if dump_script:
         print(named_sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
 
 
 def build_transform_script(script: Callable[[OpHandle], None]):

@makslevental makslevental merged commit acaff70 into llvm:main Dec 20, 2023
@makslevental makslevental deleted the move_transform_extras branch December 20, 2023 23:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants