-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][Transforms][NFC] Turn op/block arg replacements into `IRRewrite`s
#81757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Turn op/block arg replacements into `IRRewrite`s
#81757
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes: This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). Until now, op replacements and block argument replacements were kept track in separate data structures inside the dialect conversion. This commit turns them into `IRRewrite`s, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion. Overview of changes:
Patch is 21.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81757.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b2baa88879b6e9..a07c8a56822de5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,14 +153,12 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
- unsigned numReplacements, unsigned numArgReplacements,
unsigned numRewrites, unsigned numIgnoredOperations,
unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
- numReplacements(numReplacements),
- numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
+ numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -168,12 +166,6 @@ struct RewriterState {
/// The current number of unresolved materializations.
unsigned numUnresolvedMaterializations;
- /// The current number of replacements queued.
- unsigned numReplacements;
-
- /// The current number of argument replacements queued.
- unsigned numArgReplacements;
-
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -184,20 +176,6 @@ struct RewriterState {
unsigned numErased;
};
-//===----------------------------------------------------------------------===//
-// OpReplacement
-
-/// This class represents one requested operation replacement via 'replaceOp' or
-/// 'eraseOp`.
-struct OpReplacement {
- OpReplacement(const TypeConverter *converter = nullptr)
- : converter(converter) {}
-
- /// An optional type converter that can be used to materialize conversions
- /// between the new and old values if necessary.
- const TypeConverter *converter;
-};
-
//===----------------------------------------------------------------------===//
// UnresolvedMaterialization
@@ -318,8 +296,10 @@ class IRRewrite {
MoveBlock,
SplitBlock,
BlockTypeConversion,
+ ReplaceBlockArg,
MoveOperation,
- ModifyOperation
+ ModifyOperation,
+ ReplaceOperation
};
virtual ~IRRewrite() = default;
@@ -330,6 +310,12 @@ class IRRewrite {
/// Commit the rewrite.
virtual void commit() {}
+ /// Cleanup operations. Operations may be unlinked from their blocks during
+ /// the commit/rollback phase, but they must not be erased yet. This is
+ /// because internal dialect conversion state (such as `mapping`) may still
+ /// be using them. Operations must be erased during cleanup.
+ virtual void cleanup() {}
+
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op);
@@ -356,7 +342,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::BlockTypeConversion;
+ rewrite->getKind() <= Kind::ReplaceBlockArg;
}
protected:
@@ -424,6 +410,8 @@ class EraseBlockRewrite : public BlockRewrite {
void commit() override {
// Erase the block.
assert(block && "expected block");
+ assert(block->empty() && "expected empty block");
+ block->dropAllDefinedValueUses();
delete block;
block = nullptr;
}
@@ -585,6 +573,27 @@ class BlockTypeConversionRewrite : public BlockRewrite {
const TypeConverter *converter;
};
+/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but the actual replacement is delayed
+/// until the rewrite is committed.
+class ReplaceBlockArgRewrite : public BlockRewrite {
+public:
+ ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block, BlockArgument arg)
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceBlockArg;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+private:
+ BlockArgument arg;
+};
+
/// An operation rewrite.
class OperationRewrite : public IRRewrite {
public:
@@ -593,7 +602,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::ModifyOperation;
+ rewrite->getKind() <= Kind::ReplaceOperation;
}
protected:
@@ -664,6 +673,41 @@ class ModifyOperationRewrite : public OperationRewrite {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
};
+
+/// Replacing an operation. Erasing an operation is treated as a special case
+/// with "null" replacements. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but values are not replaced and the
+/// original op is not erased until the rewrite is committed.
+class ReplaceOperationRewrite : public OperationRewrite {
+public:
+ ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, const TypeConverter *converter,
+ bool changedResults)
+ : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
+ converter(converter), changedResults(changedResults) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceOperation;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+ void cleanup() override;
+
+private:
+ friend struct OperationConverter;
+
+ /// An optional type converter that can be used to materialize conversions
+ /// between the new and old values if necessary.
+ const TypeConverter *converter;
+
+ /// A flag indicating whether the operation was replaced with values whose
+ /// types differ from the original result types, e.g. 1->N conversion of
+ /// some kind.
+ bool changedResults;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -856,6 +900,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void eraseBlock(Block *block) override {
if (erased.contains(block))
return;
+ assert(block->empty() && "expected empty block");
block->dropAllDefinedValueUses();
RewriterBase::eraseBlock(block);
}
@@ -887,12 +932,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// conversion.
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
- /// Ordered map of requested operation replacements.
- llvm::MapVector<Operation *, OpReplacement> replacements;
-
- /// Ordered vector of any requested block argument replacements.
- SmallVector<BlockArgument, 4> argReplacements;
-
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -907,11 +946,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A vector of indices into `replacements` of operations that were replaced
- /// with values with different result types than the original operation, e.g.
- /// 1->N conversion of some kind.
- SmallVector<unsigned, 4> operationsWithChangedResults;
-
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -923,6 +957,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;
+ DenseSet<Operation *> *trackedOps = nullptr;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -969,6 +1005,8 @@ void BlockTypeConversionRewrite::commit() {
}
}
+ assert(origBlock->empty() && "expected empty block");
+ origBlock->dropAllDefinedValueUses();
delete origBlock;
origBlock = nullptr;
}
@@ -1031,6 +1069,47 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
return success();
}
+void ReplaceBlockArgRewrite::commit() {
+ Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
+ if (!repl)
+ return;
+
+ if (isa<BlockArgument>(repl)) {
+ arg.replaceAllUsesWith(repl);
+ return;
+ }
+
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = cast<OpResult>(repl).getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+}
+
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+
+void ReplaceOperationRewrite::commit() {
+ for (OpResult result : op->getResults())
+ if (Value newValue =
+ rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+ result.replaceAllUsesWith(newValue);
+ if (rewriterImpl.trackedOps)
+ rewriterImpl.trackedOps->erase(op);
+ // Do not erase the operation yet. It may still be referenced in `mapping`.
+ op->getBlock()->getOperations().remove(op);
+}
+
+void ReplaceOperationRewrite::rollback() {
+ for (auto result : op->getResults())
+ rewriterImpl.mapping.erase(result);
+}
+
+void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
for (Region ®ion : op->getRegions()) {
for (Block &block : region.getBlocks()) {
@@ -1053,51 +1132,16 @@ void ConversionPatternRewriterImpl::discardRewrites() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Apply all of the rewrites replacements requested during conversion.
- for (auto &repl : replacements) {
- for (OpResult result : repl.first->getResults())
- if (Value newValue = mapping.lookupOrNull(result, result.getType()))
- result.replaceAllUsesWith(newValue);
- }
-
- // Apply all of the requested argument replacements.
- for (BlockArgument arg : argReplacements) {
- Value repl = mapping.lookupOrNull(arg, arg.getType());
- if (!repl)
- continue;
-
- if (isa<BlockArgument>(repl)) {
- arg.replaceAllUsesWith(repl);
- continue;
- }
-
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
- Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
- Operation *user = operand.getOwner();
- return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- });
- }
+ // Commit all rewrites.
+ for (auto &rewrite : rewrites)
+ rewrite->commit();
+ for (auto &rewrite : rewrites)
+ rewrite->cleanup();
// Drop all of the unresolved materialization operations created during
// conversion.
for (auto &mat : unresolvedMaterializations)
eraseRewriter.eraseOp(mat.getOp());
-
- // In a second pass, erase all of the replaced operations in reverse. This
- // allows processing nested operations before their parent region is
- // destroyed. Because we process in reverse order, producers may be deleted
- // before their users (a pattern deleting a producer and then the consumer)
- // so we first drop all uses explicitly.
- for (auto &repl : llvm::reverse(replacements))
- eraseRewriter.eraseOp(repl.first);
-
- // Commit all rewrites.
- for (auto &rewrite : rewrites)
- rewrite->commit();
}
//===----------------------------------------------------------------------===//
@@ -1105,28 +1149,14 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
- replacements.size(), argReplacements.size(),
rewrites.size(), ignoredOps.size(),
eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any replaced arguments.
- for (BlockArgument replacedArg :
- llvm::drop_begin(argReplacements, state.numArgReplacements))
- mapping.erase(replacedArg);
- argReplacements.resize(state.numArgReplacements);
-
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Reset any replaced operations and undo any saved mappings.
- for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
- for (auto result : repl.first->getResults())
- mapping.erase(result);
- while (replacements.size() != state.numReplacements)
- replacements.pop_back();
-
// Pop all of the newly inserted materializations.
while (unresolvedMaterializations.size() !=
state.numUnresolvedMaterializations) {
@@ -1151,11 +1181,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
- // Reset operations with changed results.
- while (!operationsWithChangedResults.empty() &&
- operationsWithChangedResults.back() >= state.numReplacements)
- operationsWithChangedResults.pop_back();
-
while (eraseRewriter.erased.size() != state.numErased)
eraseRewriter.erased.pop_back();
}
@@ -1224,7 +1249,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation was replaced or its parent ignored.
- return replacements.count(op) || ignoredOps.count(op->getParentOp());
+ return ignoredOps.count(op->getParentOp()) ||
+ llvm::any_of(rewrites, [&](auto &rewrite) {
+ auto *opReplacement =
+ dyn_cast<ReplaceOperationRewrite>(rewrite.get());
+ if (!opReplacement)
+ return false;
+ return opReplacement->getOperation() == op;
+ });
}
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1374,7 +1406,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1408,7 +1440,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
@@ -1440,7 +1472,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!replacements.count(op) && "operation was already replaced");
+#ifndef NDEBUG
+ for (auto &rewrite : rewrites)
+ if (auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewrite.get()))
+ assert(opReplacement->getOperation() != op &&
+ "operation was already replaced");
+#endif // NDEBUG
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1455,11 +1492,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
mapping.map(result, newValue);
resultChanged |= (newValue.getType() != result.getType());
}
- if (resultChanged)
- operationsWithChangedResults.push_back(replacements.size());
- // Record the requested operation replacement.
- replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
+ appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
+ resultChanged);
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
@@ -1554,8 +1589,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- impl->notifyBlockIsBeingErased(block);
-
// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);
@@ -1564,6 +1597,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
// object and will be actually destroyed when rewrites are applied. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
+ impl->notifyBlockIsBeingErased(block);
block->getParent()->getBlocks().remove(block);
}
@@ -1593,7 +1627,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
- impl->argReplacements.push_back(from);
+ impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
@@ -2015,16 +2049,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#ifndef NDEBUG
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-
// Check that the root was either replaced or updated in place.
+ auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
auto replacedRoot = [&] {
- return llvm::any_of(
- llvm::drop_begin(impl.replacements, curState.numReplacements),
- [op](auto &it) { return it.first == op; });
+ return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
};
auto updatedRootInPlace = [&] {
- return hasRewrite<ModifyOperationRewrite>(
- llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
+ return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
assert((replacedRoot() || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
@@ -2057,7 +2088,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
if (!rewrite)
continue;
Block *block = rewrite->getBlock();
- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
+ if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
+ ReplaceBlockArgRewrite>(rewrite))
continue;
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
@@ -2452,6 +2484,7 @@ LogicalResult OperationConverter::convertOperations(
ConversionPatternRewriter rewriter(ops.front()->getContext());
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
rewriterImpl.notifyCallback = notifyCallback;
+ rewriterImpl.trackedOps = trackedOps;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
@@ -2469,13 +2502,6 @@ LogicalResult OperationConverter::convertOperations(
rewriterImpl.discardRewrites();
} else {
rewriterImpl.applyRewrites();
-
- // It is possible for a later pattern to erase an op that was originally
- // identified as illegal and added to the trackedOps, remove it now after
- // replacements have been computed.
- if (trackedOps)
- for (auto &repl : rewriterImpl.replacements)
- trackedOps->erase(repl.first);
}
return success();
}
@@ -2489,21 +2515,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
- if (rewriterImpl.operationsWithChangedResults.empty())
- return success();
-
// Process requ...
[truncated]
|
ae08e91
to
dcd13b8
Compare
8405efc
to
613a616
Compare
dcd13b8
to
61e82f6
Compare
613a616
to
b8d4cbd
Compare
When a `ModifyOperationRewrite` is committed, the operation may already have been erased, so `OperationName` must be cached in the rewrite object. Note: This will no longer be needed with #81757, which adds a "cleanup" method to `IRRewrite`.
When a `ModifyOperationRewrite` is committed, the operation may already have been erased, so `OperationName` must be cached in the rewrite object. Note: This will no longer be needed with #81757, which adds a "cleanup" method to `IRRewrite`.
68cb259
to
04eb7bd
Compare
b8d4cbd
to
886f558
Compare
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). Until now, op replacements and block argument replacements were kept track in separate data structures inside the dialect conversion. This commit turns them into `IRRewrite`s, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion. Overview of changes: * Add two new rewrite classes: `ReplaceBlockArgRewrite` and `ReplaceOperationRewrite`. Remove the `OpReplacement` helper class; it is now part of `ReplaceOperationRewrite`. * Simplify `RewriterState`: `numReplacements` and `numArgReplacements` are no longer needed. (Now being kept track of by `numRewrites`.) * Add `IRRewrite::cleanup`. Operations should not be erased in `commit` because they may still be referenced in other internal state of the dialect conversion (`mapping`). Detaching operations is fine.
886f558
to
6f7d3e7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks
@@ -1462,7 +1490,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( | |||
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, | |||
ValueRange newValues) { | |||
assert(newValues.size() == op->getNumResults()); | |||
assert(!replacements.count(op) && "operation was already replaced"); | |||
#ifndef NDEBUG |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this check present before this PR? I have multiple Flang end-to-end tests spending very long time in compilation. I tried to find the guilty PR, but I was using debug compiler builds. This PR increases this test compilation from 26 seconds to 137 seconds:
Character(1),Parameter :: c717(2,3,4,5,6,7,8) = Reshape([('a',i=1,Size(c717))], Shape(c717))
End
I will have to confirm if this check is causing it, but you may know it right away. If it is indeed an expensive check, should it be only enabled under `LLVM_ENABLE_EXPENSIVE_CHECKS`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, false alarm. Disabling this only brings the time to 107 seconds. I will look further, and also try with the release compiler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I reproduce this locally? Debug builds should not make the compile time that much slower. But you are right, this is likely a performance regression of this change or one of the other dialect conversion changes that I submitted recently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can confirm that this commit slows down the above FIR to MLIR LLVM IR dialect conversion pass for the Fortran program that Slava gave (*) by 3.5x in release mode too, and that #81759 adds another 1.5x slowdown on that (from 1.8s to 8.8s on my machine in release mode with the two patches).
Using perf, it seems the slowdown is caused by changes in `mlir::detail::ConversionPatternRewriterImpl::isOpIgnored` (probably related to the data structure changes).
After the two patches in release mode:
Samples: 158K of event 'cycles', Event count (approx.): 21633186643
Children Self Command Shared Object Symbol
+ 65.37% 64.59% fir-opt fir-opt [.] mlir::detail::ConversionPatternRewriterImpl::isOpIgnored
+ 65.22% 0.00% fir-opt [unknown] [.] 0x0000000000000001
+ 65.18% 0.00% fir-opt [unknown] [.] 0x480029bc01058d48
+ 65.18% 0.00% fir-opt fir-opt [.] mlir::detail::ConversionPatternRewriterImpl::~ConversionPatternRewriterImpl
+ 24.25% 0.00% fir-opt [unknown] [.] 0x0000000100000001
+ 24.19% 23.65% fir-opt fir-opt [.] mlir::detail::ConversionPatternRewriterImpl::notifyOpReplaced
+ 24.05% 0.00% fir-opt [unknown] [.] 0x26ee058d48fb8948
+ 24.05% 0.00% fir-opt fir-opt [.] mlir::RegisteredOperationName::Model<fir::InsertValueOp>::~Model
+ 24.05% 0.00% fir-opt [unknown] [.] 0x000055a74703d488
+ 0.82% 0.00% fir-opt [unknown] [k] 0000000000000000
0.78% 0.47% fir-opt fir-opt [.] mlir::Lexer::lexBareIdentifierOrKeyword
....
Before the patches, `mlir::detail::ConversionPatternRewriterImpl::isOpIgnored` was nowhere to be seen in the perf report (MLIR IO dominated the run).
You can reproduce if you have flang builds enabled, and with repro.f90 that is the Fortran source from Slava above with:
# First phase of compilation not impacted by patch
bin/bbc -emit-fir repro.f90 -o - | bin/fir-opt --cg-rewrite -o -input.fir
# FIR to LLVM dialect conversion pass impacted by patch
time bin/fir-opt --fir-to-llvm-ir -input.fir -o output.mlir
I will try to see if I can come up with a pure MLIR reproducer.
(*) about the Fortran program: this program generates a global that is a 7d array of 40320 chars with an initial value.
So far, flang generates a chain of insert_value for character types, so the operation that is impacted by the slowdown is a fir.global where the body contains a chain of 40320 fir.insert_value + 3 other ops (the value being inserted, and the terminator). We are planning to move away from this and use attributes for global initializers as much as possible. However, the slowdown will likely kick in for every function with more than a few thousand ops, and this is quite easily reached with big Fortran programs. Global constants are just an easy way to create a lot of IR with a few lines of Fortran to reproduce the issue.
fir.global internal @_QFECc717 constant : !fir.array<2x3x4x5x6x7x8x!fir.char<1>> {
%0 = fir.undefined !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
%1 = fir.string_lit "a"(1) : !fir.char<1>
%2 = fir.insert_value %0, %1, [0 : index, 0 : index, 0 : index, 0 : index, 0 : index, 0 : index, 0 : index] : (!fir.array<2x3x4x5x6x7x8x!fir.char<1>>, !fir.char<1>) -> !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
// ....
%40321 = fir.insert_value %40320, %1, [1 : index, 2 : index, 3 : index, 4 : index, 5 : index, 6 : index, 7 : index] : (!fir.array<2x3x4x5x6x7x8x!fir.char<1>>, !fir.char<1>) -> !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
fir.has_value %40321 : !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
}
This is being rewritten to very similar LLVM dialect IR:
llvm.mlir.global internal constant @_QFECc717() {addr_space = 0 : i32} : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>> {
%0 = llvm.mlir.undef : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
%1 = llvm.mlir.constant("a") : !llvm.array<1 x i8>
%2 = llvm.insertvalue %1, %0[0, 0, 0, 0, 0, 0, 0] : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
// ...
%40321 = llvm.insertvalue %1, %40320[7, 6, 5, 4, 3, 2, 1] : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
llvm.return %40321 : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I traced the issue back to the same function (`isOpIgnored`) on some other tests.
Thx for the detailed instructions. Looks like it fixes the test case:
time build/bin/fir-opt --fir-to-llvm-ir input.fir -o output.mlir
build/bin/fir-opt --fir-to-llvm-ir input.fir -o output.mlir 0.73s user 0.08s system 100% cpu 0.816 total
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still waiting for a CI run on another project that had a regression, to make sure that the issue is fixed there as well...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick fix! I confirm this fixes all the compilation time slowdown on the Fortran program.
The dialect conversion does not directly erase ops that are replaced/erased with a rewriter. Instead, the op stays in place and is erased at the end if the dialect conversion succeeds. However, ops that were replaced/erased are ignored from that point on. #81757 introduced a compile time regression that made the check whether an op is ignored or not more expensive. Whether an op is ignored or not is queried many times throughout a dialect conversion, so the check must be fast. After this change, replaced ops are stored in the `ignoredOps` set. This also simplifies the dialect conversion a bit.
…83023) The dialect conversion does not directly erase ops that are replaced/erased with a rewriter. Instead, the op stays in place and is erased at the end if the dialect conversion succeeds. However, ops that were replaced/erased are ignored from that point on. #81757 introduced a compile time regression that made the check whether an op is ignored or not more expensive. Whether an op is ignored or not is queried many times throughout a dialect conversion, so the check must be fast. After this change, replaced ops are stored in the `ignoredOps` set. This also simplifies the dialect conversion a bit.
…lvm#83023) The dialect conversion does not directly erase ops that are replaced/erased with a rewriter. Instead, the op stays in place and is erased at the end if the dialect conversion succeeds. However, ops that were replaced/erased are ignored from that point on. llvm#81757 introduced a compile time regression that made the check whether an op is ignored or not more expensive. Whether an op is ignored or not is queried many times throughout a dialect conversion, so the check must be fast. After this change, replaced ops are stored in the `ignoredOps` set. This also simplifies the dialect conversion a bit.
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
Until now, op replacements and block argument replacements were kept track in separate data structures inside the dialect conversion. This commit turns them into `IRRewrite`s, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion.
Overview of changes:
* Add two new rewrite classes: `ReplaceBlockArgRewrite` and `ReplaceOperationRewrite`. Remove the `OpReplacement` helper class; it is now part of `ReplaceOperationRewrite`.
* Simplify `RewriterState`: `numReplacements` and `numArgReplacements` are no longer needed. (Now being kept track of by `numRewrites`.)
* Add `IRRewrite::cleanup`. Operations should not be erased in `commit` because they may still be referenced in other internal state of the dialect conversion (`mapping`). Detaching operations is fine.
* `trackedOps` are now updated during the "commit" phase instead of after applying all rewrites.