Skip to content

Commit f47ea8d

Browse files
[mlir][Transforms] Dialect conversion: Align handling of dropped values
Handle dropped block arguments and dropped op results in the same way: build a source materialization (that may fold away if unused). This simplifies the code base a bit and makes it possible to merge `legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes` in a future commit. These two functions are almost doing the same thing now. This commit also fixes a bug where circular materializations were built, e.g.: ``` %0 = "builtin.unrealized_conversion_cast"(%1) : (!a) -> !b %1 = "builtin.unrealized_conversion_cast"(%0) : (!b) -> !a // No further uses of %0, %1. ``` This happened when: 1. An op was erased. (No replacement values provided.) 2. A conversion pattern for another op builds a replacement value (first cast op) during `remapValues`, but that SSA value is not used during the pattern application. 3. During the finalization phase, `legalizeConvertedOpResultTypes` thinks that the erased op is alive because of the cast op that was built in Step 2. It builds a cast from that replacement value to the original type. 4. During the commit phase, all uses of the original op are repalced with the casted value produced in Step 3. We have generated circular IR.
1 parent c2e53b2 commit f47ea8d

File tree

2 files changed

+28
-117
lines changed

2 files changed

+28
-117
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 26 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
941941
/// to modify/access them is invalid rewriter API usage.
942942
SetVector<Operation *> replacedOps;
943943

944+
DenseSet<Operation *> unresolvedMaterializations;
945+
944946
/// The current type converter, or nullptr if no type converter is currently
945947
/// active.
946948
const TypeConverter *currentTypeConverter = nullptr;
@@ -1066,6 +1068,7 @@ void UnresolvedMaterializationRewrite::rollback() {
10661068
for (Value input : op->getOperands())
10671069
rewriterImpl.mapping.erase(input);
10681070
}
1071+
rewriterImpl.unresolvedMaterializations.erase(op);
10691072
op->erase();
10701073
}
10711074

@@ -1347,6 +1350,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13471350
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13481351
auto convertOp =
13491352
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1353+
unresolvedMaterializations.insert(convertOp);
13501354
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13511355
return convertOp.getResult(0);
13521356
}
@@ -1385,9 +1389,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13851389
// Create mappings for each of the new result values.
13861390
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13871391
if (!newValue) {
1388-
resultChanged = true;
1389-
continue;
1392+
// This result was dropped and no replacement value was provided.
1393+
if (unresolvedMaterializations.contains(op)) {
1394+
// Do not create another materializations if we are erasing a
1395+
// materialization.
1396+
resultChanged = true;
1397+
continue;
1398+
}
1399+
1400+
// Materialize a replacement value "out of thin air".
1401+
newValue = buildUnresolvedMaterialization(
1402+
MaterializationKind::Source, computeInsertPoint(result),
1403+
result.getLoc(), /*inputs=*/ValueRange(),
1404+
/*outputType=*/result.getType(), currentTypeConverter);
13901405
}
1406+
13911407
// Remap, and check for any result type changes.
13921408
mapping.map(result, newValue);
13931409
resultChanged |= (newValue.getType() != result.getType());
@@ -2359,11 +2375,6 @@ struct OperationConverter {
23592375
ConversionPatternRewriterImpl &rewriterImpl,
23602376
DenseMap<Value, SmallVector<Value>> &inverseMapping);
23612377

2362-
/// Legalize an operation result that was marked as "erased".
2363-
LogicalResult
2364-
legalizeErasedResult(Operation *op, OpResult result,
2365-
ConversionPatternRewriterImpl &rewriterImpl);
2366-
23672378
/// Dialect conversion configuration.
23682379
ConversionConfig config;
23692380

@@ -2455,78 +2466,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
24552466
return failure();
24562467
}
24572468

2458-
/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results
2459-
/// are not used (transitively) by any op that is not in the given list of
2460-
/// cast ops.
2461-
///
2462-
/// In particular, this function erases cyclic casts that may be inserted
2463-
/// during the dialect conversion process. E.g.:
2464-
/// %0 = unrealized_conversion_cast(%1)
2465-
/// %1 = unrealized_conversion_cast(%0)
2466-
// Note: This step will become unnecessary when
2467-
// https://github.com/llvm/llvm-project/pull/106760 has been merged.
2468-
static void eraseDeadUnrealizedCasts(
2469-
ArrayRef<UnrealizedConversionCastOp> castOps,
2470-
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2471-
// Ops that have already been visited or are currently being visited.
2472-
DenseSet<Operation *> visited;
2473-
// Set of all cast ops for faster lookups.
2474-
DenseSet<Operation *> castOpSet;
2475-
// Set of all cast ops that have been determined to be alive.
2476-
DenseSet<Operation *> live;
2477-
2478-
for (UnrealizedConversionCastOp op : castOps)
2479-
castOpSet.insert(op);
2480-
2481-
// Visit a cast operation. Return "true" if the operation is live.
2482-
std::function<bool(Operation *)> visit = [&](Operation *op) -> bool {
2483-
// No need to traverse any IR if the op was already marked as live.
2484-
if (live.contains(op))
2485-
return true;
2486-
2487-
// Do not visit ops multiple times. If we find a circle, no live user was
2488-
// found on the current path.
2489-
if (visited.contains(op))
2490-
return false;
2491-
visited.insert(op);
2492-
2493-
// Visit all users.
2494-
for (Operation *user : op->getUsers()) {
2495-
// If the user is not an unrealized_conversion_cast op, then the given op
2496-
// is live.
2497-
if (!castOpSet.contains(user)) {
2498-
live.insert(op);
2499-
return true;
2500-
}
2501-
// Otherwise, it is live if a live op can be reached from one of its
2502-
// users (which must all be unrealized_conversion_cast ops).
2503-
if (visit(user)) {
2504-
live.insert(op);
2505-
return true;
2506-
}
2507-
}
2508-
2509-
return false;
2510-
};
2511-
2512-
// Visit all cast ops.
2513-
for (UnrealizedConversionCastOp op : castOps) {
2514-
visit(op);
2515-
visited.clear();
2516-
}
2517-
2518-
// Erase all cast ops that are dead.
2519-
for (UnrealizedConversionCastOp op : castOps) {
2520-
if (live.contains(op)) {
2521-
if (remainingCastOps)
2522-
remainingCastOps->push_back(op);
2523-
continue;
2524-
}
2525-
op->dropAllUses();
2526-
op->erase();
2527-
}
2528-
}
2529-
25302469
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25312470
if (ops.empty())
25322471
return success();
@@ -2585,14 +2524,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25852524
// Reconcile all UnrealizedConversionCastOps that were inserted by the
25862525
// dialect conversion frameworks. (Not the one that were inserted by
25872526
// patterns.)
2588-
SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
2589-
eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1);
2590-
reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2);
2527+
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2528+
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
25912529

25922530
// Try to legalize all unresolved materializations.
25932531
if (config.buildMaterializations) {
25942532
IRRewriter rewriter(rewriterImpl.context, config.listener);
2595-
for (UnrealizedConversionCastOp castOp : remainingCastOps2) {
2533+
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
25962534
auto it = rewriteMap.find(castOp.getOperation());
25972535
assert(it != rewriteMap.end() && "inconsistent state");
25982536
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
@@ -2651,26 +2589,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
26512589
continue;
26522590
Operation *op = opReplacement->getOperation();
26532591
for (OpResult result : op->getResults()) {
2654-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2655-
2656-
// If the operation result was replaced with null, all of the uses of this
2657-
// value should be replaced.
2658-
if (!newValue) {
2659-
if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2660-
return failure();
2661-
continue;
2662-
}
2663-
2664-
// Otherwise, check to see if the type of the result changed.
2665-
if (result.getType() == newValue.getType())
2592+
// If the type of this op result changed and the result is still live,
2593+
// we need to materialize a conversion.
2594+
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
26662595
continue;
2667-
26682596
Operation *liveUser =
26692597
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
26702598
if (!liveUser)
26712599
continue;
26722600

26732601
// Legalize this result.
2602+
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2603+
assert(newValue && "replacement value not found");
26742604
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
26752605
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
26762606
/*inputs=*/newValue, /*outputType=*/result.getType(),
@@ -2728,25 +2658,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
27282658
return success();
27292659
}
27302660

2731-
LogicalResult OperationConverter::legalizeErasedResult(
2732-
Operation *op, OpResult result,
2733-
ConversionPatternRewriterImpl &rewriterImpl) {
2734-
// If the operation result was replaced with null, all of the uses of this
2735-
// value should be replaced.
2736-
auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2737-
return rewriterImpl.isOpIgnored(user);
2738-
});
2739-
if (liveUserIt != result.user_end()) {
2740-
InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2741-
<< op->getName() << "' marked as erased";
2742-
diag.attachNote(liveUserIt->getLoc())
2743-
<< "found live user of result #" << result.getResultNumber() << ": "
2744-
<< *liveUserIt;
2745-
return failure();
2746-
}
2747-
return success();
2748-
}
2749-
27502661
//===----------------------------------------------------------------------===//
27512662
// Reconcile Unrealized Casts
27522663
//===----------------------------------------------------------------------===//

mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
// Test that an error is emitted when an operation is marked as "erased", but
44
// has users that live across the conversion.
55
func.func @remove_all_ops(%arg0: i32) -> i32 {
6-
// expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
6+
// expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}}
77
%0 = "test.illegal_op_a"() : () -> i32
8-
// expected-note@below {{found live user of result #0: func.return %0 : i32}}
8+
// expected-note@below {{see existing live user here}}
99
return %0 : i32
1010
}

0 commit comments

Comments
 (0)