Skip to content

Commit ef8c26b

Browse files
[mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (#68330)
1 parent 7e77f19 commit ef8c26b

File tree

9 files changed

+577
-289
lines changed

9 files changed

+577
-289
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

+27-2
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {
2020

2121
let hasOperationAttrVerify = 1;
2222
let extraClassDeclaration = [{
23+
/// Symbol name for the default entry point "named sequence".
24+
constexpr const static ::llvm::StringLiteral
25+
kTransformEntryPointSymbolName = "__transform_main";
26+
2327
/// Name of the attribute attachable to the symbol table operation
2428
/// containing named sequences. This is used to trigger verification.
25-
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
26-
"transform.with_named_sequence";
29+
constexpr const static ::llvm::StringLiteral
30+
kWithNamedSequenceAttrName = "transform.with_named_sequence";
2731

2832
/// Name of the attribute attachable to an operation so it can be
2933
/// identified as root by the default interpreter pass.
@@ -74,6 +78,22 @@ def Transform_Dialect : Dialect {
7478
using ExtensionTypePrintingHook =
7579
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
7680

81+
/// Appends the given module as a transform symbol library available to
82+
/// all dialect users.
83+
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
84+
library) {
85+
libraryModules.push_back(std::move(library));
86+
}
87+
88+
/// Returns a range of registered library modules.
89+
auto getLibraryModules() const {
90+
return ::llvm::map_range(
91+
libraryModules,
92+
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
93+
return library.get();
94+
});
95+
}
96+
7797
private:
7898
/// Registers operations specified as template parameters with this
7999
/// dialect. Checks that they implement the required interfaces.
@@ -132,6 +152,11 @@ def Transform_Dialect : Dialect {
132152
/// lookups when the type is fully constructed.
133153
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
134154
typePrintingHooks;
155+
156+
/// Modules containing symbols, e.g. named sequences, that will be
157+
/// resolved by the interpreter when used.
158+
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
159+
libraryModules;
135160
}];
136161
}
137162

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ class TransformOptions {
111111
LogicalResult
112112
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
113113
const RaggedArray<MappedValue> &extraMapping = {},
114-
const TransformOptions &options = TransformOptions());
114+
const TransformOptions &options = TransformOptions(),
115+
bool enforceToplevelTransformOp = true);
115116

116117
/// The state maintained across applications of various ops implementing the
117118
/// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,7 @@ class TransformState {
193194

194195
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
195196
const RaggedArray<MappedValue> &,
196-
const TransformOptions &);
197+
const TransformOptions &, bool);
197198

198199
friend TransformState
199200
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
10+
#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
13+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Support/LLVM.h"
16+
#include <memory>
17+
18+
namespace mlir {
19+
struct LogicalResult;
20+
class MLIRContext;
21+
class ModuleOp;
22+
class Operation;
23+
template <typename>
24+
class OwningOpRef;
25+
class Region;
26+
27+
namespace transform {
28+
namespace detail {
29+
/// Utility to parse and verify the content of a `transformFileName` MLIR file
30+
/// containing a transform dialect specification.
31+
LogicalResult
32+
parseTransformModuleFromFile(MLIRContext *context,
33+
llvm::StringRef transformFileName,
34+
OwningOpRef<ModuleOp> &transformModule);
35+
36+
/// Utility to load a transform interpreter `module` from a module that has
37+
/// already been preloaded in the context.
38+
/// This mode is useful in cases where explicit parsing of a transform library
39+
/// from file is expected to be prohibitively expensive.
40+
/// In such cases, the transform module is expected to be found in the preloaded
41+
/// library modules of the transform dialect.
42+
/// Returns null if the module is not found.
43+
ModuleOp getPreloadedTransformModule(MLIRContext *context);
44+
45+
/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName`
46+
/// that is either:
47+
/// 1. nested under `root` (takes precedence).
48+
/// 2. nested under `module`, if not found in `root`.
49+
/// Reports errors and returns null if no such operation found.
50+
TransformOpInterface findTransformEntryPoint(
51+
Operation *root, ModuleOp module,
52+
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
53+
54+
/// Merge all symbols from `other` into `target`. Both ops need to implement the
55+
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
56+
/// modified by this function and might not verify after the function returns.
57+
/// Upon merging, private symbols may be renamed in order to avoid collisions in
58+
/// the result. Public symbols may not collide, with the exception of
59+
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
60+
/// one of the two is external, in which case the other op preserved (or any one
61+
/// of the two if both are external).
62+
// TODO: Reconsider cloning individual ops rather than forcing users of the
63+
// function to clone (or move) `other` in order to improve efficiency.
64+
// This might primarily make sense if we can also prune the symbols that
65+
// are merged to a subset (such as those that are actually used).
66+
LogicalResult mergeSymbolsInto(Operation *target,
67+
OwningOpRef<Operation *> other);
68+
} // namespace detail
69+
70+
/// Standalone util to apply the named sequence `entryPoint` to the payload.
71+
/// This is done in 3 steps:
72+
/// 1. lookup the `entryPoint` symbol in `{payload, sharedTransformModule}` by
73+
/// calling detail::findTransformEntryPoint.
74+
/// 2. if the entry point is found and not nested under
75+
/// `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in
76+
/// the `sharedTransformModule`. Note: this may modify the transform IR
77+
/// embedded with the payload IR.
78+
/// 3. apply the transform IR to the payload IR, relaxing the requirement that
79+
/// the transform IR is a top-level transform op. We are applying a named
80+
/// sequence anyway.
81+
LogicalResult applyTransformNamedSequence(
82+
Operation *payload, ModuleOp transformModule,
83+
const TransformOptions &options,
84+
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
85+
86+
} // namespace transform
87+
} // namespace mlir
88+
89+
#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

+13-13
Original file line numberDiff line numberDiff line change
@@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
20792079
// Entry point.
20802080
//===----------------------------------------------------------------------===//
20812081

2082-
LogicalResult
2083-
transform::applyTransforms(Operation *payloadRoot,
2084-
TransformOpInterface transform,
2085-
const RaggedArray<MappedValue> &extraMapping,
2086-
const TransformOptions &options) {
2087-
#ifndef NDEBUG
2088-
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2089-
transform->getNumOperands() != 0) {
2090-
transform->emitError()
2091-
<< "expected transform to start at the top-level transform op";
2092-
llvm::report_fatal_error("could not run transforms",
2093-
/*gen_crash_diag=*/false);
2082+
LogicalResult transform::applyTransforms(
2083+
Operation *payloadRoot, TransformOpInterface transform,
2084+
const RaggedArray<MappedValue> &extraMapping,
2085+
const TransformOptions &options, bool enforceToplevelTransformOp) {
2086+
if (enforceToplevelTransformOp) {
2087+
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2088+
transform->getNumOperands() != 0) {
2089+
return transform->emitError()
2090+
<< "expected transform to start at the top-level transform op";
2091+
}
2092+
} else if (failed(
2093+
detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
2094+
return failure();
20942095
}
2095-
#endif // NDEBUG
20962096

20972097
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
20982098
options);

mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
22
CheckUses.cpp
33
InferEffects.cpp
44
TransformInterpreterPassBase.cpp
5+
TransformInterpreterUtils.cpp
56

67
DEPENDS
78
MLIRTransformDialectTransformsIncGen

0 commit comments

Comments
 (0)