diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index b58a95c3baf70..caea9e111afed 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, - MaterializationKind kind = MaterializationKind::Target) - : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), - converterAndKind(converter, kind) {} + MaterializationKind kind = MaterializationKind::Target); static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) { }); } -/// Find the single rewrite object of the specified type and block among the -/// given rewrites. In debug mode, asserts that there is mo more than one such -/// object. Return "nullptr" if no object was found. -template -static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { - RewriteTy *result = nullptr; - for (auto &rewrite : rewrites) { - auto *rewriteTy = dyn_cast(rewrite.get()); - if (rewriteTy && rewriteTy->getBlock() == block) { -#ifndef NDEBUG - assert(!result && "expected single matching rewrite"); - result = rewriteTy; -#else - return rewriteTy; -#endif // NDEBUG - } - } - return result; -} - //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasErased(void *ptr) const { return erased.contains(ptr); } - bool wasErased(OperationRewrite *rewrite) const { - return wasErased(rewrite->getOperation()); - } - void notifyOperationErased(Operation *op) override { erased.insert(op); } void notifyBlockErased(Block *block) override { erased.insert(block); } @@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// to modify/access them is invalid rewriter API usage. SetVector replacedOps; - /// A set of all unresolved materializations. - DenseSet unresolvedMaterializations; + /// A mapping of all unresolved materializations (UnrealizedConversionCastOp) + /// to the corresponding rewrite objects. + DenseMap + unresolvedMaterializations; /// The current type converter, or nullptr if no type converter is currently /// active. @@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() { op->erase(); } +UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( + ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, + const TypeConverter *converter, MaterializationKind kind) + : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), + converterAndKind(converter, kind) { + rewriterImpl.unresolvedMaterializations[op] = this; +} + void UnresolvedMaterializationRewrite::rollback() { if (getMaterializationKind() == MaterializationKind::Target) { for (Value input : op->getOperands()) rewriterImpl.mapping.erase(input); } - rewriterImpl.unresolvedMaterializations.erase(op); + rewriterImpl.unresolvedMaterializations.erase(getOperation()); op->erase(); } @@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create(loc, outputType, inputs); - unresolvedMaterializations.insert(convertOp); appendRewrite(convertOp, converter, kind); return convertOp.getResult(0); } @@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { if (!newValue) { // This result was dropped and no replacement value was provided. - if (unresolvedMaterializations.contains(op)) { - // Do not create another materializations if we are erasing a - // materialization. - continue; + if (auto castOp = dyn_cast(op)) { + if (unresolvedMaterializations.contains(castOp)) { + // Do not create another materializations if we are erasing a + // materialization. + continue; + } } // Materialize a replacement value "out of thin air". @@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Gather all unresolved materializations. SmallVector allCastOps; - DenseMap rewriteMap; - for (std::unique_ptr &rewrite : rewriterImpl.rewrites) { - auto *mat = dyn_cast(rewrite.get()); - if (!mat) - continue; - if (rewriterImpl.eraseRewriter.wasErased(mat)) + const DenseMap + &materializations = rewriterImpl.unresolvedMaterializations; + for (auto it : materializations) { + if (rewriterImpl.eraseRewriter.wasErased(it.first)) continue; - allCastOps.push_back(mat->getOperation()); - rewriteMap[mat->getOperation()] = mat; + allCastOps.push_back(it.first); } // Reconcile all UnrealizedConversionCastOps that were inserted by the @@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (config.buildMaterializations) { IRRewriter rewriter(rewriterImpl.context, config.listener); for (UnrealizedConversionCastOp castOp : remainingCastOps) { - auto it = rewriteMap.find(castOp.getOperation()); - assert(it != rewriteMap.end() && "inconsistent state"); + auto it = materializations.find(castOp); + assert(it != materializations.end() && "inconsistent state"); if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) return failure(); }