diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 41d91cf677833..55c5973e40e52 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -172,7 +172,7 @@ declare_mlir_python_sources( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" GEN_ENUM_BINDINGS SOURCES - extras/dialects/transform/__init__.py) + dialects/transform/extras/__init__.py) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 7ae4fefbac412..175634c7d458f 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -6,6 +6,7 @@ from .._transform_ops_gen import * from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform import * +from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType try: from ...ir import * diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py similarity index 80% rename from mlir/python/mlir/extras/dialects/transform/__init__.py rename to mlir/python/mlir/dialects/transform/extras/__init__.py index 9e313324318aa..c715dac1ef7eb 100644 --- a/mlir/python/mlir/extras/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -2,12 +2,11 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from __future__ import annotations -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Union from .... import ir -from ....dialects import transform -from ....dialects.transform import structured +from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp +from .. import structured class Handle(ir.Value): @@ -25,16 +24,16 @@ def __init__( self, v: ir.Value, *, - parent: Optional[Handle] = None, - children: Optional[Sequence[Handle]] = None, + parent: Optional["Handle"] = None, + children: Optional[Sequence["Handle"]] = None, ): super().__init__(v) self.parent = parent self.children = children if children is not None else [] -@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) -@ir.register_value_caster(transform.OperationType.get_static_typeid()) +@ir.register_value_caster(AnyOpType.get_static_typeid()) +@ir.register_value_caster(OperationType.get_static_typeid()) class OpHandle(Handle): """ Wrapper around a transform operation handle with methods to chain further @@ -52,11 +51,13 @@ def __init__( def match_ops( self, - ops: str - | ir.OpView - | structured.MatchInterfaceEnum - | Sequence[str | ir.OpView], - ) -> OpHandle: + ops: Union[ + str, + ir.OpView, + structured.MatchInterfaceEnum, + Sequence[Union[str, ir.OpView]], + ], + ) -> "OpHandle": """ Emits a `transform.structured.MatchOp`. Returns a handle to payload ops that match the given names, types, or @@ -70,7 +71,7 @@ def match_ops( if isinstance(ops, str): ops = structured.MatchInterfaceEnum[ops] match_op = structured.MatchOp( - transform.AnyOpType.get(), + AnyOpType.get(), self, interface=ops, ) @@ -78,15 +79,15 @@ def match_ops( # Handle op name(s), either given directly as string or given as op. else: if isinstance(ops, str): - op_type = transform.OperationType.get(ops) + op_type = OperationType.get(ops) op_names = [ops] elif isinstance(ops, Sequence): - op_type = transform.AnyOpType.get() + op_type = AnyOpType.get() op_names = [ op if isinstance(op, str) else op.OPERATION_NAME for op in ops ] else: - op_type = transform.OperationType.get(ops.OPERATION_NAME) + op_type = OperationType.get(ops.OPERATION_NAME) op_names = [ops.OPERATION_NAME] match_op = structured.MatchOp.match_op_names( op_type, @@ -100,7 +101,7 @@ def match_ops( def insert_transform_script( - block_or_insertion_point: ir.Block | ir.InsertionPoint, + block_or_insertion_point: Union[ir.Block, ir.InsertionPoint], script: Callable[[OpHandle], None], dump_script: bool = False, ) -> None: @@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle): with context, ir.Location.unknown(context): with insertion_point: - named_sequence_op = transform.NamedSequenceOp( - "__transform_main", [transform.AnyOpType.get()], [] + named_sequence_op = NamedSequenceOp( + "__transform_main", [AnyOpType.get()], [] ) with ir.InsertionPoint(named_sequence_op.body): script(named_sequence_op.bodyTarget) - transform.YieldOp([]) + YieldOp([]) if dump_script: print(named_sequence_op) diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py index dbfa8a2dc73c4..e7b43ea63c31c 100644 --- a/mlir/test/python/dialects/transform_extras.py +++ b/mlir/test/python/dialects/transform_extras.py @@ -4,7 +4,7 @@ from mlir import ir from mlir.dialects import scf from mlir.dialects.transform import structured -from mlir.extras.dialects.transform import OpHandle, insert_transform_script +from mlir.dialects.transform.extras import OpHandle, insert_transform_script def build_transform_script(script: Callable[[OpHandle], None]):