-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][Transforms][NFC] Turn in-place op modification into IRRewrite
#81245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Turn in-place op modification into IRRewrite
#81245
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed. Full diff: https://github.com/llvm/llvm-project/pull/81245.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ffdb069f6e9b8..d0114a148cd37 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,15 +154,13 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewriteActions, unsigned numIgnoredOperations,
- unsigned numRootUpdates)
+ unsigned numRewriteActions, unsigned numIgnoredOperations)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements),
numRewriteActions(numRewriteActions),
- numIgnoredOperations(numIgnoredOperations),
- numRootUpdates(numRootUpdates) {}
+ numIgnoredOperations(numIgnoredOperations) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -181,44 +179,6 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
-
- /// The current number of operations that were updated in place.
- unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
- OperationTransactionState() = default;
- OperationTransactionState(Operation *op)
- : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
- operands(op->operand_begin(), op->operand_end()),
- successors(op->successor_begin(), op->successor_end()) {}
-
- /// Discard the transaction state and reset the state of the original
- /// operation.
- void resetOperation() const {
- op->setLoc(loc);
- op->setAttrs(attrs);
- op->setOperands(operands);
- for (const auto &it : llvm::enumerate(successors))
- op->setSuccessor(it.value(), it.index());
- }
-
- /// Return the original operation of this state.
- Operation *getOperation() const { return op; }
-
-private:
- Operation *op;
- LocationAttr loc;
- DictionaryAttr attrs;
- SmallVector<Value, 8> operands;
- SmallVector<Block *, 2> successors;
};
//===----------------------------------------------------------------------===//
@@ -758,7 +718,8 @@ class RewriteAction {
MoveBlock,
SplitBlock,
BlockTypeConversion,
- MoveOperation
+ MoveOperation,
+ ModifyOperation
};
virtual ~RewriteAction() = default;
@@ -980,7 +941,7 @@ class OperationAction : public RewriteAction {
static bool classof(const RewriteAction *action) {
return action->getKind() >= Kind::MoveOperation &&
- action->getKind() <= Kind::MoveOperation;
+ action->getKind() <= Kind::ModifyOperation;
}
protected:
@@ -1019,6 +980,34 @@ class MoveOperationAction : public OperationAction {
// this operation was the only operation in the region.
Operation *insertBeforeOp;
};
+
+/// Rewrite action that represents the in-place modification of an operation.
+/// The previous state of the operation is stored in this action.
+class ModifyOperationAction : public OperationAction {
+public:
+ ModifyOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : OperationAction(Kind::ModifyOperation, rewriterImpl, op),
+ loc(op->getLoc()), attrs(op->getAttrDictionary()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void rollback() override {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (const auto &it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+private:
+ LocationAttr loc;
+ DictionaryAttr attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1172,9 +1161,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A transaction state for each of operations that were updated in-place.
- SmallVector<OperationTransactionState, 4> rootUpdates;
-
/// A vector of indices into `replacements` of operations that were replaced
/// with values with different result types than the original operation, e.g.
/// 1->N conversion of some kind.
@@ -1226,10 +1212,6 @@ static void detachNestedAndErase(Operation *op) {
}
void ConversionPatternRewriterImpl::discardRewrites() {
- // Reset any operations that were updated in place.
- for (auto &state : rootUpdates)
- state.resetOperation();
-
undoRewriteActions();
// Remove any newly created ops.
@@ -1304,16 +1286,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
replacements.size(), argReplacements.size(),
- rewriteActions.size(), ignoredOps.size(),
- rootUpdates.size());
+ rewriteActions.size(), ignoredOps.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any operations that were updated in place.
- for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
- rootUpdates[i].resetOperation();
- rootUpdates.resize(state.numRootUpdates);
-
// Reset any replaced arguments.
for (BlockArgument replacedArg :
llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1740,7 +1716,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
- impl->rootUpdates.emplace_back(op);
+ impl->appendRewriteAction<ModifyOperationAction>(op);
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1759,13 +1735,17 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
"operation did not have a pending in-place update");
#endif
// Erase the last update for this operation.
- auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
- auto &rootUpdates = impl->rootUpdates;
- auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
- assert(it != rootUpdates.rend() && "no root update started on op");
- (*it).resetOperation();
- int updateIdx = std::prev(rootUpdates.rend()) - it;
- rootUpdates.erase(rootUpdates.begin() + updateIdx);
+ auto it =
+ llvm::find_if(llvm::reverse(impl->rewriteActions),
+ [&](std::unique_ptr<RewriteAction> &action) {
+ auto *modifyAction =
+ dynamic_cast<ModifyOperationAction *>(action.get());
+ return modifyAction && modifyAction->getOperation() == op;
+ });
+ assert(it != impl->rewriteActions.rend() && "no root update started on op");
+ (*it)->rollback();
+ int updateIdx = std::prev(impl->rewriteActions.rend()) - it;
+ impl->rewriteActions.erase(impl->rewriteActions.begin() + updateIdx);
}
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2118,8 +2098,11 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
};
auto updatedRootInPlace = [&] {
return llvm::any_of(
- llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
- [op](auto &state) { return state.getOperation() == op; });
+ llvm::drop_begin(impl.rewriteActions, curState.numRewriteActions),
+ [op](auto &action) {
+ auto *modifyAction = dyn_cast<ModifyOperationAction>(action.get());
+ return modifyAction && modifyAction->getOperation() == op;
+ });
};
(void)replacedRoot;
(void)updatedRootInPlace;
@@ -2213,8 +2196,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
- for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
- Operation *op = impl.rootUpdates[i].getOperation();
+ for (int i = state.numRewriteActions, e = newState.numRewriteActions; i != e;
+ ++i) {
+ auto *action =
+ dyn_cast<ModifyOperationAction>(impl.rewriteActions[i].get());
+ if (!action)
+ continue;
+ Operation *op = action->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "failed to legalize operation updated in-place '{0}'",
|
7503c0c
to
ebfaca6
Compare
cdbd927
to
f7010ea
Compare
RewriteAction
sIRRewrite
ebfaca6
to
1d17c76
Compare
f7010ea
to
4bb6521
Compare
1d17c76
to
5e261de
Compare
820bcdd
to
1c69f42
Compare
|
||
private: | ||
LocationAttr loc; | ||
DictionaryAttr attrs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is properties needed too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made it a separate PR (#82474), so that this PR can stay NFC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG for keeping this a pure refactor
1c69f42
to
8a8a79d
Compare
…ction`s This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
8a8a79d
to
5cabd6c
Compare
This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.