-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][python] move transform extras #76102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesAddresses #75073 (comment). Full diff: https://github.com/llvm/llvm-project/pull/76102.diff 4 Files Affected:
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
- extras/dialects/transform/__init__.py)
+ dialects/transform/extras/__init__.py)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
try:
from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 80%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..c715dac1ef7eb8 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,12 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from __future__ import annotations
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
class Handle(ir.Value):
@@ -25,16 +24,16 @@ def __init__(
self,
v: ir.Value,
*,
- parent: Optional[Handle] = None,
- children: Optional[Sequence[Handle]] = None,
+ parent: Optional["Handle"] = None,
+ children: Optional[Sequence["Handle"]] = None,
):
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
def match_ops(
self,
- ops: str
- | ir.OpView
- | structured.MatchInterfaceEnum
- | Sequence[str | ir.OpView],
- ) -> OpHandle:
+ ops: Union[
+ str,
+ ir.OpView,
+ structured.MatchInterfaceEnum,
+ Sequence[Union[str, ir.OpView]],
+ ],
+ ) -> "OpHandle":
"""
Emits a `transform.structured.MatchOp`.
Returns a handle to payload ops that match the given names, types, or
@@ -70,7 +71,7 @@ def match_ops(
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
match_op = structured.MatchOp(
- transform.AnyOpType.get(),
+ AnyOpType.get(),
self,
interface=ops,
)
@@ -78,15 +79,15 @@ def match_ops(
# Handle op name(s), either given directly as string or given as op.
else:
if isinstance(ops, str):
- op_type = transform.OperationType.get(ops)
+ op_type = OperationType.get(ops)
op_names = [ops]
elif isinstance(ops, Sequence):
- op_type = transform.AnyOpType.get()
+ op_type = AnyOpType.get()
op_names = [
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
]
else:
- op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_type = OperationType.get(ops.OPERATION_NAME)
op_names = [ops.OPERATION_NAME]
match_op = structured.MatchOp.match_op_names(
op_type,
@@ -100,7 +101,7 @@ def match_ops(
def insert_transform_script(
- block_or_insertion_point: ir.Block | ir.InsertionPoint,
+ block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
with context, ir.Location.unknown(context):
with insertion_point:
- named_sequence_op = transform.NamedSequenceOp(
- "__transform_main", [transform.AnyOpType.get()], []
+ named_sequence_op = NamedSequenceOp(
+ "__transform_main", [AnyOpType.get()], []
)
with ir.InsertionPoint(named_sequence_op.body):
script(named_sequence_op.bodyTarget)
- transform.YieldOp([])
+ YieldOp([])
if dump_script:
print(named_sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
def build_transform_script(script: Callable[[OpHandle], None]):
|
ftynse
approved these changes
Dec 20, 2023
martin-luecke
approved these changes
Dec 20, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Addresses #75073 (comment).