Skip to content

Commit dc43c03

Browse files
committed
[mlir] Enable decoupling two kinds of greedy behavior.
The greedy rewriter is used in many different flows and it has a lot of convenience (work list management, debugging actions, tracing, etc). But it combines two kinds of greedy behavior 1) wrt how ops are matched, 2) folding wherever it can. These are independent forms of greedy and leads to inefficiency. E.g., cases where one need to create different phases in lowering, one is required to applying patterns in specific order/different passes. But if using the driver one ends up needlessly retrying folding or having multiple rounds of folding attempts, where one final run would have sufficed. It also is rather confusing to users that just want to apply some patterns while having all the convenience and structure to have unrelated changes to IR. Of course folks can locally avoid this behavior by just building their own, but this is also a common requested feature that folks keep on working around locally in suboptimal ways.
1 parent 953b07f commit dc43c03

File tree

2 files changed

+57
-18
lines changed

2 files changed

+57
-18
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,15 @@ class GreedyRewriteConfig {
9191

9292
/// An optional listener that should be notified about IR modifications.
9393
RewriterBase::Listener *listener = nullptr;
94+
95+
// Whether this should fold while greedily rewriting.
96+
//
97+
// Note: greedy here generally refers to two forms, 1) greedily applying
98+
// patterns based purely on benefit and applying without backtracking using
99+
// default cost model, 2) greedily folding where possible while attempting to
100+
// match and rewrite using the provided patterns. With this option set to
101+
// false it only does the former.
102+
bool fold = true;
94103
};
95104

96105
//===----------------------------------------------------------------------===//
@@ -104,8 +113,8 @@ class GreedyRewriteConfig {
104113
/// The greedy rewrite may prematurely stop after a maximum number of
105114
/// iterations, which can be configured in the configuration parameter.
106115
///
107-
/// Also performs folding and simple dead-code elimination before attempting to
108-
/// match any of the provided patterns.
116+
/// Also performs simple dead-code elimination before attempting to match any of
117+
/// the provided patterns.
109118
///
110119
/// A region scope can be set in the configuration parameter. By default, the
111120
/// scope is set to the specified region. Only in-scope ops are added to the
@@ -117,10 +126,18 @@ class GreedyRewriteConfig {
117126
///
118127
/// Note: This method does not apply patterns to the region's parent operation.
119128
LogicalResult
129+
applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
130+
GreedyRewriteConfig config = GreedyRewriteConfig(),
131+
bool *changed = nullptr);
132+
/// Same as `applyPatternsAndGreedily` above with folding.
133+
inline LogicalResult
120134
applyPatternsAndFoldGreedily(Region &region,
121135
const FrozenRewritePatternSet &patterns,
122136
GreedyRewriteConfig config = GreedyRewriteConfig(),
123-
bool *changed = nullptr);
137+
bool *changed = nullptr) {
138+
config.fold = true;
139+
return applyPatternsGreedily(region, patterns, config, changed);
140+
}
124141

125142
/// Rewrite ops nested under the given operation, which must be isolated from
126143
/// above, by repeatedly applying the highest benefit patterns in a greedy
@@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region &region,
129146
/// The greedy rewrite may prematurely stop after a maximum number of
130147
/// iterations, which can be configured in the configuration parameter.
131148
///
132-
/// Also performs folding and simple dead-code elimination before attempting to
133-
/// match any of the provided patterns.
149+
/// Also performs simple dead-code elimination before attempting to match any of
150+
/// the provided patterns.
134151
///
135152
/// This overload runs a separate greedy rewrite for each region of the
136153
/// specified op. A region scope can be set in the configuration parameter. By
@@ -147,10 +164,9 @@ applyPatternsAndFoldGreedily(Region &region,
147164
///
148165
/// Note: This method does not apply patterns to the given operation itself.
149166
inline LogicalResult
150-
applyPatternsAndFoldGreedily(Operation *op,
151-
const FrozenRewritePatternSet &patterns,
152-
GreedyRewriteConfig config = GreedyRewriteConfig(),
153-
bool *changed = nullptr) {
167+
applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
168+
GreedyRewriteConfig config = GreedyRewriteConfig(),
169+
bool *changed = nullptr) {
154170
bool anyRegionChanged = false;
155171
bool failed = false;
156172
for (Region &region : op->getRegions()) {
@@ -164,15 +180,24 @@ applyPatternsAndFoldGreedily(Operation *op,
164180
*changed = anyRegionChanged;
165181
return failure(failed);
166182
}
183+
/// Same as `applyPatternsGreedily` above with folding.
184+
inline LogicalResult
185+
applyPatternsAndFoldGreedily(Operation *op,
186+
const FrozenRewritePatternSet &patterns,
187+
GreedyRewriteConfig config = GreedyRewriteConfig(),
188+
bool *changed = nullptr) {
189+
config.fold = true;
190+
return applyPatternsGreedily(op, patterns, config, changed);
191+
}
167192

168193
/// Rewrite the specified ops by repeatedly applying the highest benefit
169194
/// patterns in a greedy worklist driven manner until a fixpoint is reached.
170195
///
171196
/// The greedy rewrite may prematurely stop after a maximum number of
172197
/// iterations, which can be configured in the configuration parameter.
173198
///
174-
/// Also performs folding and simple dead-code elimination before attempting to
175-
/// match any of the provided patterns.
199+
/// Also performs simple dead-code elimination before attempting to match any of
200+
/// the provided patterns.
176201
///
177202
/// Newly created ops and other pre-existing ops that use results of rewritten
178203
/// ops or supply operands to such ops are also processed, unless such ops are
@@ -194,10 +219,19 @@ applyPatternsAndFoldGreedily(Operation *op,
194219
/// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
195220
/// `ops` were erased.
196221
LogicalResult
222+
applyOpPatternsGreedily(ArrayRef<Operation *> ops,
223+
const FrozenRewritePatternSet &patterns,
224+
GreedyRewriteConfig config = GreedyRewriteConfig(),
225+
bool *changed = nullptr, bool *allErased = nullptr);
226+
/// Same as `applyOpPatternsGreedily` with folding.
227+
inline LogicalResult
197228
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
198229
const FrozenRewritePatternSet &patterns,
199230
GreedyRewriteConfig config = GreedyRewriteConfig(),
200-
bool *changed = nullptr, bool *allErased = nullptr);
231+
bool *changed = nullptr, bool *allErased = nullptr) {
232+
config.fold = true;
233+
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
234+
}
201235

202236
} // namespace mlir
203237

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements mlir::applyPatternsAndFoldGreedily.
9+
// This file implements mlir::applyPatternsGreedily.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

@@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
488488
// infinite folding loop, as every constant op would be folded to an
489489
// Attribute and then immediately be rematerialized as a constant op, which
490490
// is then put on the worklist.
491-
if (!op->hasTrait<OpTrait::ConstantLike>()) {
491+
if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
492492
SmallVector<OpFoldResult> foldResults;
493493
if (succeeded(op->fold(foldResults))) {
494494
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -840,6 +840,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
840840
// regions to enable more aggressive CSE'ing).
841841
OperationFolder folder(ctx, this);
842842
auto insertKnownConstant = [&](Operation *op) {
843+
// This hoisting is to enable more folding, so skip checking if known
844+
// constant, updating dense map etc if not doing folding.
845+
if (!config.fold)
846+
return false;
847+
843848
// Check for existing constants when populating the worklist. This avoids
844849
// accidentally reversing the constant order during processing.
845850
Attribute constValue;
@@ -894,9 +899,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
894899
}
895900

896901
LogicalResult
897-
mlir::applyPatternsAndFoldGreedily(Region &region,
898-
const FrozenRewritePatternSet &patterns,
899-
GreedyRewriteConfig config, bool *changed) {
902+
mlir::applyPatternsGreedily(Region &region,
903+
const FrozenRewritePatternSet &patterns,
904+
GreedyRewriteConfig config, bool *changed) {
900905
// The top-level operation must be known to be isolated from above to
901906
// prevent performing canonicalizations on operations defined at or above
902907
// the region containing 'op'.
@@ -1012,7 +1017,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
10121017
return region;
10131018
}
10141019

1015-
LogicalResult mlir::applyOpPatternsAndFold(
1020+
LogicalResult mlir::applyOpPatternsGreedily(
10161021
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
10171022
GreedyRewriteConfig config, bool *changed, bool *allErased) {
10181023
if (ops.empty()) {

0 commit comments

Comments
 (0)