Skip to content

[MLIR][transform][python] add sugared python abstractions for transform dialect #75073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir-c/Dialect/Transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
Expand All @@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
Expand All @@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
Expand All @@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
MlirType type);

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,13 @@ class mlir_type_subclass : public pure_subclass {
.attr("replace")(superCls.attr("__name__"), captureTypeName);
});
if (getTypeIDFunction) {
// 'get_static_typeid' method.
// This is modeled as a static method instead of a static property because
// `def_property_readonly_static` is not available in `pure_subclass` and
// we do not want to introduce the complexity that pybind uses to
// implement it.
def_staticmethod("get_static_typeid",
[getTypeIDFunction]() { return getTypeIDFunction(); });
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
getTypeIDFunction())(pybind11::cpp_function(
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Bindings/Python/DialectTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto anyOpType =
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
mlirTransformAnyOpTypeGetTypeID);
anyOpType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
Expand All @@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto anyParamType =
mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
mlirTransformAnyParamTypeGetTypeID);
anyParamType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
Expand All @@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto anyValueType =
mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
mlirTransformAnyValueTypeGetTypeID);
anyValueType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto paramType =
mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
mlirTransformParamTypeGetTypeID);
paramType.def_classmethod(
"get",
[](py::object cls, MlirType type, MlirContext ctx) {
Expand Down
18 changes: 17 additions & 1 deletion mlir/lib/CAPI/Dialect/Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) {
return isa<transform::AnyOpType>(unwrap(type));
}

MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
return wrap(transform::AnyOpType::getTypeID());
}

MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
Expand All @@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) {
return isa<transform::AnyParamType>(unwrap(type));
}

MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
return wrap(transform::AnyParamType::getTypeID());
}

MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
return wrap(transform::AnyParamType::get(unwrap(ctx)));
}
Expand All @@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) {
return isa<transform::AnyValueType>(unwrap(type));
}

MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
return wrap(transform::AnyValueType::getTypeID());
}

MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
return wrap(transform::AnyValueType::get(unwrap(ctx)));
}
Expand Down Expand Up @@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
}

//===---------------------------------------------------------------------===//
// AnyOpType
// ParamType
//===---------------------------------------------------------------------===//

bool mlirTypeIsATransformParamType(MlirType type) {
return isa<transform::ParamType>(unwrap(type));
}

MlirTypeID mlirTransformParamTypeGetTypeID(void) {
return wrap(transform::ParamType::getTypeID());
}

MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings(
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
)

declare_mlir_python_sources(
MLIRPythonSources.Dialects.transform.extras
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
extras/dialects/transform/__init__.py)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
148 changes: 148 additions & 0 deletions mlir/python/mlir/extras/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking this from another patch where I saw this being added: I would very much prefer that transform dialect stuff be contained under its own dialect namespace, not in a parallel dialect tree. If it is experimental and needs more bake time, note that in the name or something, but let's not sprawl dialect specific things into other parts of the codebase.

from typing import Callable, Optional, Sequence

from .... import ir
from ....dialects import transform
from ....dialects.transform import structured


class Handle(ir.Value):
"""
Base class for wrappers around different types of transform handle with
methods to chain further transforms.

The fields `children` and `parent` are used to capture the relation of
handles statically in order to enable further analysis. The payload
operation of a child handle is nested into a region of the payload operation
of the corresponding parent handle.
"""

def __init__(
self,
v: ir.Value,
*,
parent: Optional[Handle] = None,
children: Optional[Sequence[Handle]] = None,
):
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []


@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
@ir.register_value_caster(transform.OperationType.get_static_typeid())
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
transforms.
"""

def __init__(
self,
v: ir.Value,
*,
parent: Optional[Handle] = None,
children: Optional[Sequence[Handle]] = None,
):
super().__init__(v, parent=parent, children=children)

def match_ops(
self,
ops: str
| ir.OpView
| structured.MatchInterfaceEnum
| Sequence[str | ir.OpView],
Comment on lines +55 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks compat with < 3.10. Personally I'd be happy to bump to 3.10 but prior we've tried to maintain down to 3.8 (I guess buildbots have upgraded because I've failed tests before with this).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overall guidance is to have LLVM build on live LTS repos so it can be distributed. Ubuntu 20.04 is one of these, and it has python 3.8.2. Debian is surprisingly fresher, and I'm not aware what RHEL-based distros folks run these days. But sticking to 3.8 sounds reasonable. I don't know if we can import from future for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The future import is available since 3.7 and the from __future__ import annotations import is exactly why this does not break the build in Python 3.8.
With this import the annotations here are not evaluated by the Python interpreter at module import time, but only when e.g. typing.get_type_hints is called on it (which we do not). The annotations are still understood by tools like MyPy, see here.
This should not be an issue for us, but I can still move this over to the Union syntax if you think this prudent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know this! Definitely keep the pipes!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our policy to date has been to track the PSF's eol policy, not specific LTS. In any case, it is a distinction without a difference in this case: we support 3.8 until it is eol in Oct-2024.

For the older Python versions, I am fine with conditional coding or some level of heroic if the feature is very useful, but in most cases I see, it is just easier to write it in an older-version compatible way and not deal with the mental burden of conditional stuff.

Feel free to plus me in on anything that needs untangling in this vein. Happy to offer advice.

) -> OpHandle:
"""
Emits a `transform.structured.MatchOp`.
Returns a handle to payload ops that match the given names, types, or
interface. If only a single type is given, the value wrapped by the
resulting handle is populated with the respective type.
"""
# Handle interface.
if isinstance(ops, structured.MatchInterfaceEnum) or (
isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
):
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm makes me think we should change the emitted enums to handle strings as well like this (not in this PR obv).

match_op = structured.MatchOp(
transform.AnyOpType.get(),
self,
interface=ops,
)

# Handle op name(s), either given directly as string or given as op.
else:
if isinstance(ops, str):
op_type = transform.OperationType.get(ops)
op_names = [ops]
elif isinstance(ops, Sequence):
op_type = transform.AnyOpType.get()
op_names = [
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
]
else:
op_type = transform.OperationType.get(ops.OPERATION_NAME)
op_names = [ops.OPERATION_NAME]
match_op = structured.MatchOp.match_op_names(
op_type,
self,
op_names,
)

handle = OpHandle(match_op.results_, parent=self)
self.children.append(handle)
return handle


def insert_transform_script(
block_or_insertion_point: ir.Block | ir.InsertionPoint,
script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
"""
Inserts the transform script of the schedule into the module. The script
should accept an instance of OpHandle as argument, which will be called with
the block arg of the newly created named_sequence op.

Example:
This python code
```
module = ir.Module.create()
def test_match_ops_single(module: OpHandle):
module.match_ops(scf.ForOp)
insert_transform_script(module.body, script)
```
generates the following IR:
```
module {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
^bb0(%arg0: !transform.any_op):
%0 = transform.structured.match ops{["scf.for"]} in %arg0
: (!transform.any_op) -> !transform.op<"scf.for">
}
}
```
"""
if isinstance(block_or_insertion_point, ir.Block):
context = block_or_insertion_point.owner.context
insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
else:
context = block_or_insertion_point.block.owner.context
insertion_point = block_or_insertion_point

with context, ir.Location.unknown(context):
with insertion_point:
named_sequence_op = transform.NamedSequenceOp(
"__transform_main", [transform.AnyOpType.get()], []
)
with ir.InsertionPoint(named_sequence_op.body):
script(named_sequence_op.bodyTarget)
transform.YieldOp([])

if dump_script:
print(named_sequence_op)
Loading