Skip to content

[mlir][Transforms][NFC] Turn in-place op modification into IRRewrite #81245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {

/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require
/// notification through other more specific hooks above.
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
void startOpModification(Operation *op) override;

/// PatternRewriter hook for updating the given operation in-place.
Expand Down
146 changes: 72 additions & 74 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,12 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
unsigned numRewrites, unsigned numIgnoredOperations,
unsigned numRootUpdates)
unsigned numRewrites, unsigned numIgnoredOperations)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
numIgnoredOperations(numIgnoredOperations),
numRootUpdates(numRootUpdates) {}
numIgnoredOperations(numIgnoredOperations) {}

/// The current number of created operations.
unsigned numCreatedOps;
Expand All @@ -180,44 +178,6 @@ struct RewriterState {

/// The current number of ignored operations.
unsigned numIgnoredOperations;

/// The current number of operations that were updated in place.
unsigned numRootUpdates;
};

//===----------------------------------------------------------------------===//
// OperationTransactionState

/// The state of an operation that was updated by a pattern in-place. This
/// contains all of the necessary information to reconstruct an operation that
/// was updated in place.
class OperationTransactionState {
public:
OperationTransactionState() = default;
OperationTransactionState(Operation *op)
: op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
operands(op->operand_begin(), op->operand_end()),
successors(op->successor_begin(), op->successor_end()) {}

/// Discard the transaction state and reset the state of the original
/// operation.
void resetOperation() const {
op->setLoc(loc);
op->setAttrs(attrs);
op->setOperands(operands);
for (const auto &it : llvm::enumerate(successors))
op->setSuccessor(it.value(), it.index());
}

/// Return the original operation of this state.
Operation *getOperation() const { return op; }

private:
Operation *op;
LocationAttr loc;
DictionaryAttr attrs;
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -754,14 +714,19 @@ namespace {
class IRRewrite {
public:
/// The kind of the rewrite. Rewrites can be undone if the conversion fails.
/// Enum values are ordered, so that they can be used in `classof`: first all
/// block rewrites, then all operation rewrites.
enum class Kind {
// Block rewrites
CreateBlock,
EraseBlock,
InlineBlock,
MoveBlock,
SplitBlock,
BlockTypeConversion,
MoveOperation
// Operation rewrites
MoveOperation,
ModifyOperation
};

virtual ~IRRewrite() = default;
Expand Down Expand Up @@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite {

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
rewrite->getKind() <= Kind::MoveOperation;
rewrite->getKind() <= Kind::ModifyOperation;
}

protected:
Expand Down Expand Up @@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite {
// this operation was the only operation in the region.
Operation *insertBeforeOp;
};

/// In-place modification of an op. This rewrite is immediately reflected in
/// the IR. The previous state of the operation is stored in this object.
class ModifyOperationRewrite : public OperationRewrite {
public:
ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Operation *op)
: OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
loc(op->getLoc()), attrs(op->getAttrDictionary()),
operands(op->operand_begin(), op->operand_end()),
successors(op->successor_begin(), op->successor_end()) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::ModifyOperation;
}

void rollback() override {
op->setLoc(loc);
op->setAttrs(attrs);
op->setOperands(operands);
for (const auto &it : llvm::enumerate(successors))
op->setSuccessor(it.value(), it.index());
}

private:
LocationAttr loc;
DictionaryAttr attrs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is properties needed too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it a separate PR (#82474), so that this PR can stay NFC.

SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
};
} // namespace

/// Return "true" if there is an operation rewrite that matches the specified
/// rewrite type and operation among the given rewrites.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
return any_of(std::move(rewrites), [&](auto &rewrite) {
auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
return rewriteTy && rewriteTy->getOperation() == op;
});
}

//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;

/// A transaction state for each of operations that were updated in-place.
SmallVector<OperationTransactionState, 4> rootUpdates;

/// A vector of indices into `replacements` of operations that were replaced
/// with values with different result types than the original operation, e.g.
/// 1->N conversion of some kind.
Expand Down Expand Up @@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) {
}

void ConversionPatternRewriterImpl::discardRewrites() {
// Reset any operations that were updated in place.
for (auto &state : rootUpdates)
state.resetOperation();

undoRewrites();

// Remove any newly created ops.
Expand Down Expand Up @@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
replacements.size(), argReplacements.size(),
rewrites.size(), ignoredOps.size(), rootUpdates.size());
rewrites.size(), ignoredOps.size());
}

void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Reset any operations that were updated in place.
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
rootUpdates[i].resetOperation();
rootUpdates.resize(state.numRootUpdates);

// Reset any replaced arguments.
for (BlockArgument replacedArg :
llvm::drop_begin(argReplacements, state.numArgReplacements))
Expand Down Expand Up @@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
impl->rootUpdates.emplace_back(op);
impl->appendRewrite<ModifyOperationRewrite>(op);
}

void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
Expand All @@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
"operation did not have a pending in-place update");
#endif
// Erase the last update for this operation.
auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
auto &rootUpdates = impl->rootUpdates;
auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
assert(it != rootUpdates.rend() && "no root update started on op");
(*it).resetOperation();
int updateIdx = std::prev(rootUpdates.rend()) - it;
rootUpdates.erase(rootUpdates.begin() + updateIdx);
auto it = llvm::find_if(
llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
return modifyRewrite && modifyRewrite->getOperation() == op;
});
assert(it != impl->rewrites.rend() && "no root update started on op");
(*it)->rollback();
int updateIdx = std::prev(impl->rewrites.rend()) - it;
impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
}

detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
Expand Down Expand Up @@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that cleans up the rewriter state after a pattern failed to match.
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.notifyCallback) {
Expand All @@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that performs additional legalization when a pattern is
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
auto result = legalizePatternResult(op, pattern, rewriter, curState);
appliedPatterns.erase(&pattern);
if (failed(result))
Expand Down Expand Up @@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,

#ifndef NDEBUG
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#endif

// Check that the root was either replaced or updated in place.
auto replacedRoot = [&] {
Expand All @@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
[op](auto &it) { return it.first == op; });
};
auto updatedRootInPlace = [&] {
return llvm::any_of(
llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
[op](auto &state) { return state.getOperation() == op; });
return hasRewrite<ModifyOperationRewrite>(
llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
};
(void)replacedRoot;
(void)updatedRootInPlace;
assert((replacedRoot() || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
#endif // NDEBUG

// Legalize each of the actions registered during application.
RewriterState newState = impl.getCurrentState();
Expand Down Expand Up @@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
Operation *op = impl.rootUpdates[i].getOperation();
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
if (!rewrite)
continue;
Operation *op = rewrite->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "failed to legalize operation updated in-place '{0}'",
Expand Down Expand Up @@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
// Full Conversion

LogicalResult
mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns) {
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
return opConverter.convertOperations(ops);
Expand Down