Skip to content

Commit e67bd4e

Browse files
[mlir][Transforms] Dialect conversion: Unify materialization of value replacements
PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`. This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterates over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed.
1 parent 97aa8cc commit e67bd4e

File tree

1 file changed

+42
-92
lines changed

1 file changed

+42
-92
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 42 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,17 +2338,6 @@ struct OperationConverter {
23382338
/// remaining artifacts and complete the conversion.
23392339
LogicalResult finalize(ConversionPatternRewriter &rewriter);
23402340

2341-
/// Legalize the types of converted block arguments.
2342-
LogicalResult
2343-
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2344-
ConversionPatternRewriterImpl &rewriterImpl);
2345-
2346-
/// Legalize the types of converted op results.
2347-
LogicalResult legalizeConvertedOpResultTypes(
2348-
ConversionPatternRewriter &rewriter,
2349-
ConversionPatternRewriterImpl &rewriterImpl,
2350-
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2351-
23522341
/// Dialect conversion configuration.
23532342
ConversionConfig config;
23542343

@@ -2512,19 +2501,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25122501
return success();
25132502
}
25142503

2515-
LogicalResult
2516-
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2517-
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2518-
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2519-
return failure();
2520-
DenseMap<Value, SmallVector<Value>> inverseMapping =
2521-
rewriterImpl.mapping.getInverse();
2522-
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
2523-
inverseMapping)))
2524-
return failure();
2525-
return success();
2526-
}
2527-
25282504
/// Finds a user of the given value, or of any other value that the given value
25292505
/// replaced, that was not replaced in the conversion process.
25302506
static Operation *findLiveUserOfReplaced(
@@ -2548,87 +2524,61 @@ static Operation *findLiveUserOfReplaced(
25482524
return nullptr;
25492525
}
25502526

2551-
LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2552-
ConversionPatternRewriter &rewriter,
2553-
ConversionPatternRewriterImpl &rewriterImpl,
2554-
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2555-
// Process requested operation replacements.
2556-
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2557-
auto *opReplacement =
2558-
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2559-
if (!opReplacement)
2560-
continue;
2561-
Operation *op = opReplacement->getOperation();
2562-
for (OpResult result : op->getResults()) {
2563-
// If the type of this op result changed and the result is still live,
2564-
// we need to materialize a conversion.
2565-
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
2527+
/// Helper function that returns the replaced values and the type converter if
2528+
/// the given rewrite object is an "operation replacement" or a "block type
2529+
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
2530+
/// an empty ValueRange and a null type converter pointer.
2531+
static std::pair<ValueRange, const TypeConverter *>
2532+
getReplacedValues(IRRewrite *rewrite) {
2533+
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2534+
return std::make_pair(opRewrite->getOperation()->getResults(),
2535+
opRewrite->getConverter());
2536+
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2537+
return std::make_pair(blockRewrite->getOrigBlock()->getArguments(),
2538+
blockRewrite->getConverter());
2539+
return std::make_pair(ValueRange(), nullptr);
2540+
}
2541+
2542+
LogicalResult
2543+
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2544+
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2545+
DenseMap<Value, SmallVector<Value>> inverseMapping =
2546+
rewriterImpl.mapping.getInverse();
2547+
2548+
// Process requested value replacements.
2549+
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2550+
ValueRange replacedValues;
2551+
const TypeConverter *converter;
2552+
std::tie(replacedValues, converter) =
2553+
getReplacedValues(rewriterImpl.rewrites[i].get());
2554+
for (Value originalValue : replacedValues) {
2555+
// If the type of this value changed and the value is still live, we need
2556+
// to materialize a conversion.
2557+
if (rewriterImpl.mapping.lookupOrNull(originalValue,
2558+
originalValue.getType()))
25662559
continue;
25672560
Operation *liveUser =
2568-
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2561+
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
25692562
if (!liveUser)
25702563
continue;
25712564

2572-
// Legalize this result.
2573-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2565+
// Legalize this value replacement.
2566+
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
25742567
assert(newValue && "replacement value not found");
25752568
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2576-
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
2577-
/*inputs=*/newValue, /*outputType=*/result.getType(),
2578-
opReplacement->getConverter());
2579-
rewriterImpl.mapping.map(result, castValue);
2580-
inverseMapping[castValue].push_back(result);
2581-
llvm::erase(inverseMapping[newValue], result);
2569+
MaterializationKind::Source, computeInsertPoint(newValue),
2570+
originalValue.getLoc(),
2571+
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2572+
converter);
2573+
rewriterImpl.mapping.map(originalValue, castValue);
2574+
inverseMapping[castValue].push_back(originalValue);
2575+
llvm::erase(inverseMapping[newValue], originalValue);
25822576
}
25832577
}
25842578

25852579
return success();
25862580
}
25872581

2588-
LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2589-
ConversionPatternRewriter &rewriter,
2590-
ConversionPatternRewriterImpl &rewriterImpl) {
2591-
// Functor used to check if all users of a value will be dead after
2592-
// conversion.
2593-
// TODO: This should probably query the inverse mapping, same as in
2594-
// `legalizeConvertedOpResultTypes`.
2595-
auto findLiveUser = [&](Value val) {
2596-
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2597-
return rewriterImpl.isOpIgnored(user);
2598-
});
2599-
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2600-
};
2601-
// Note: `rewrites` may be reallocated as the loop is running.
2602-
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2603-
++i) {
2604-
auto &rewrite = rewriterImpl.rewrites[i];
2605-
if (auto *blockTypeConversionRewrite =
2606-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2607-
// Process the remapping for each of the original arguments.
2608-
for (Value origArg :
2609-
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2610-
// If the type of this argument changed and the argument is still live,
2611-
// we need to materialize a conversion.
2612-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2613-
continue;
2614-
Operation *liveUser = findLiveUser(origArg);
2615-
if (!liveUser)
2616-
continue;
2617-
2618-
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2619-
assert(replacementValue && "replacement value not found");
2620-
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2621-
MaterializationKind::Source, computeInsertPoint(replacementValue),
2622-
origArg.getLoc(), /*inputs=*/replacementValue,
2623-
/*outputType=*/origArg.getType(),
2624-
blockTypeConversionRewrite->getConverter());
2625-
rewriterImpl.mapping.map(origArg, repl);
2626-
}
2627-
}
2628-
}
2629-
return success();
2630-
}
2631-
26322582
//===----------------------------------------------------------------------===//
26332583
// Reconcile Unrealized Casts
26342584
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)