-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][transform][python] add sugared python abstractions for transform dialect #75073
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5ec81e0
86fb474
ee8f9c2
7ee6bb0
fd1ddd1
e2bc61f
bfa63bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from __future__ import annotations | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Taking this from another patch where I saw this being added: I would very much prefer that transform dialect stuff be contained under its own dialect namespace, not in a parallel dialect tree. If it is experimental and needs more bake time, note that in the name or something, but let's not sprawl dialect specific things into other parts of the codebase. |
||
from typing import Callable, Optional, Sequence | ||
|
||
from .... import ir | ||
from ....dialects import transform | ||
from ....dialects.transform import structured | ||
|
||
|
||
class Handle(ir.Value): | ||
""" | ||
Base class for wrappers around different types of transform handle with | ||
methods to chain further transforms. | ||
|
||
The fields `children` and `parent` are used to capture the relation of | ||
handles statically in order to enable further analysis. The payload | ||
operation of a child handle is nested into a region of the payload operation | ||
of the corresponding parent handle. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
v: ir.Value, | ||
*, | ||
parent: Optional[Handle] = None, | ||
children: Optional[Sequence[Handle]] = None, | ||
): | ||
super().__init__(v) | ||
self.parent = parent | ||
self.children = children if children is not None else [] | ||
|
||
|
||
@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) | ||
@ir.register_value_caster(transform.OperationType.get_static_typeid()) | ||
class OpHandle(Handle): | ||
""" | ||
Wrapper around a transform operation handle with methods to chain further | ||
transforms. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
v: ir.Value, | ||
*, | ||
parent: Optional[Handle] = None, | ||
children: Optional[Sequence[Handle]] = None, | ||
): | ||
super().__init__(v, parent=parent, children=children) | ||
|
||
def match_ops( | ||
self, | ||
ops: str | ||
| ir.OpView | ||
| structured.MatchInterfaceEnum | ||
| Sequence[str | ir.OpView], | ||
Comment on lines
+55
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This breaks compat with < 3.10. Personally I'd be happy to bump to 3.10 but prior we've tried to maintain down to 3.8 (I guess buildbots have upgraded because I've failed tests before with this). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The overall guidance is to have LLVM build on live LTS repos so it can be distributed. Ubuntu 20.04 is one of these, and it has python 3.8.2. Debian is surprisingly fresher, and I'm not aware what RHEL-based distros folks run these days. But sticking to 3.8 sounds reasonable. I don't know if we can import from future for this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The future import is available since 3.7 and the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't know this! Definitely keep the pipes! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Our policy to date has been to track the PSF's eol policy, not specific LTS. In any case, it is a distinction without a difference in this case: we support 3.8 until it is eol in Oct-2024. For the older Python versions, I am fine with conditional coding or some level of heroic if the feature is very useful, but in most cases I see, it is just easier to write it in an older-version compatible way and not deal with the mental burden of conditional stuff. Feel free to plus me in on anything that needs untangling in this vein. Happy to offer advice. |
||
) -> OpHandle: | ||
""" | ||
Emits a `transform.structured.MatchOp`. | ||
Returns a handle to payload ops that match the given names, types, or | ||
interface. If only a single type is given, the value wrapped by the | ||
resulting handle is populated with the respective type. | ||
""" | ||
# Handle interface. | ||
if isinstance(ops, structured.MatchInterfaceEnum) or ( | ||
isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__ | ||
): | ||
if isinstance(ops, str): | ||
ops = structured.MatchInterfaceEnum[ops] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm makes me think we should change the emitted enums to handle strings as well like this (not in this PR obv). |
||
match_op = structured.MatchOp( | ||
transform.AnyOpType.get(), | ||
self, | ||
interface=ops, | ||
) | ||
|
||
# Handle op name(s), either given directly as string or given as op. | ||
else: | ||
if isinstance(ops, str): | ||
op_type = transform.OperationType.get(ops) | ||
op_names = [ops] | ||
elif isinstance(ops, Sequence): | ||
op_type = transform.AnyOpType.get() | ||
op_names = [ | ||
op if isinstance(op, str) else op.OPERATION_NAME for op in ops | ||
] | ||
else: | ||
op_type = transform.OperationType.get(ops.OPERATION_NAME) | ||
op_names = [ops.OPERATION_NAME] | ||
match_op = structured.MatchOp.match_op_names( | ||
op_type, | ||
self, | ||
op_names, | ||
) | ||
|
||
handle = OpHandle(match_op.results_, parent=self) | ||
self.children.append(handle) | ||
return handle | ||
|
||
|
||
def insert_transform_script( | ||
block_or_insertion_point: ir.Block | ir.InsertionPoint, | ||
script: Callable[[OpHandle], None], | ||
dump_script: bool = False, | ||
) -> None: | ||
""" | ||
Inserts the transform script of the schedule into the module. The script | ||
should accept an instance of OpHandle as argument, which will be called with | ||
the block arg of the newly created named_sequence op. | ||
|
||
Example: | ||
This python code | ||
``` | ||
module = ir.Module.create() | ||
def test_match_ops_single(module: OpHandle): | ||
module.match_ops(scf.ForOp) | ||
insert_transform_script(module.body, script) | ||
``` | ||
generates the following IR: | ||
``` | ||
module { | ||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) { | ||
^bb0(%arg0: !transform.any_op): | ||
%0 = transform.structured.match ops{["scf.for"]} in %arg0 | ||
: (!transform.any_op) -> !transform.op<"scf.for"> | ||
} | ||
} | ||
``` | ||
""" | ||
if isinstance(block_or_insertion_point, ir.Block): | ||
context = block_or_insertion_point.owner.context | ||
insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point) | ||
else: | ||
context = block_or_insertion_point.block.owner.context | ||
insertion_point = block_or_insertion_point | ||
|
||
with context, ir.Location.unknown(context): | ||
with insertion_point: | ||
named_sequence_op = transform.NamedSequenceOp( | ||
"__transform_main", [transform.AnyOpType.get()], [] | ||
) | ||
with ir.InsertionPoint(named_sequence_op.body): | ||
script(named_sequence_op.bodyTarget) | ||
transform.YieldOp([]) | ||
|
||
if dump_script: | ||
print(named_sequence_op) |
Uh oh!
There was an error while loading. Please reload this page.