Skip to content

Commit d72b58e

Browse files
[mlir][Transforms] Dialect conversion: Align handling of dropped values
Handle dropped block arguments and dropped op results in the same way: build a source materialization (that may fold away if unused). This simplifies the code base a bit and makes it possible to merge `legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes` in a future commit. These two functions are almost doing the same thing now. This commit also fixes a bug where circular materializations were built, e.g.: ``` %0 = "builtin.unrealized_conversion_cast"(%1) : (!a) -> !b %1 = "builtin.unrealized_conversion_cast"(%0) : (!b) -> !a // No further uses of %0, %1. ``` This happened when: 1. An op was erased. (No replacement values provided.) 2. A conversion pattern for another op builds a replacement value (first cast op) during `remapValues`, but that SSA value is not used during the pattern application. 3. During the finalization phase, `legalizeConvertedOpResultTypes` thinks that the erased op is alive because of the cast op that was built in Step 2. It builds a cast from that replacement value to the original type. 4. During the commit phase, all uses of the original op are repalced with the casted value produced in Step 3. We have generated circular IR.
1 parent 57fe53c commit d72b58e

File tree

2 files changed

+24
-41
lines changed

2 files changed

+24
-41
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
930930
/// to modify/access them is invalid rewriter API usage.
931931
SetVector<Operation *> replacedOps;
932932

933+
DenseSet<Operation *> unresolvedMaterializations;
934+
933935
/// The current type converter, or nullptr if no type converter is currently
934936
/// active.
935937
const TypeConverter *currentTypeConverter = nullptr;
@@ -1055,6 +1057,7 @@ void UnresolvedMaterializationRewrite::rollback() {
10551057
for (Value input : op->getOperands())
10561058
rewriterImpl.mapping.erase(input);
10571059
}
1060+
rewriterImpl.unresolvedMaterializations.erase(op);
10581061
op->erase();
10591062
}
10601063

@@ -1341,6 +1344,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13411344
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13421345
auto convertOp =
13431346
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1347+
unresolvedMaterializations.insert(convertOp);
13441348
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13451349
return convertOp.getResult(0);
13461350
}
@@ -1379,9 +1383,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13791383
// Create mappings for each of the new result values.
13801384
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13811385
if (!newValue) {
1382-
resultChanged = true;
1383-
continue;
1386+
// This result was dropped and no replacement value was provided.
1387+
if (unresolvedMaterializations.contains(op)) {
1388+
// Do not create another materializations if we are erasing a
1389+
// materialization.
1390+
resultChanged = true;
1391+
continue;
1392+
}
1393+
1394+
// Materialize a replacement value "out of thin air".
1395+
newValue = buildUnresolvedMaterialization(
1396+
MaterializationKind::Source, computeInsertPoint(result),
1397+
result.getLoc(), /*inputs=*/ValueRange(),
1398+
/*outputType=*/result.getType(), currentTypeConverter);
13841399
}
1400+
13851401
// Remap, and check for any result type changes.
13861402
mapping.map(result, newValue);
13871403
resultChanged |= (newValue.getType() != result.getType());
@@ -2359,11 +2375,6 @@ struct OperationConverter {
23592375
ConversionPatternRewriterImpl &rewriterImpl,
23602376
DenseMap<Value, SmallVector<Value>> &inverseMapping);
23612377

2362-
/// Legalize an operation result that was marked as "erased".
2363-
LogicalResult
2364-
legalizeErasedResult(Operation *op, OpResult result,
2365-
ConversionPatternRewriterImpl &rewriterImpl);
2366-
23672378
/// Dialect conversion configuration.
23682379
ConversionConfig config;
23692380

@@ -2500,26 +2511,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
25002511
continue;
25012512
Operation *op = opReplacement->getOperation();
25022513
for (OpResult result : op->getResults()) {
2503-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2504-
2505-
// If the operation result was replaced with null, all of the uses of this
2506-
// value should be replaced.
2507-
if (!newValue) {
2508-
if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2509-
return failure();
2514+
// If the type of this op result changed and the result is still live,
2515+
// we need to materialize a conversion.
2516+
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
25102517
continue;
2511-
}
2512-
2513-
// Otherwise, check to see if the type of the result changed.
2514-
if (result.getType() == newValue.getType())
2515-
continue;
2516-
25172518
Operation *liveUser =
25182519
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
25192520
if (!liveUser)
25202521
continue;
25212522

25222523
// Legalize this result.
2524+
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2525+
assert(newValue && "replacement value not found");
25232526
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
25242527
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
25252528
/*inputs=*/newValue, /*outputType=*/result.getType(),
@@ -2850,25 +2853,6 @@ LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
28502853
return success();
28512854
}
28522855

2853-
LogicalResult OperationConverter::legalizeErasedResult(
2854-
Operation *op, OpResult result,
2855-
ConversionPatternRewriterImpl &rewriterImpl) {
2856-
// If the operation result was replaced with null, all of the uses of this
2857-
// value should be replaced.
2858-
auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2859-
return rewriterImpl.isOpIgnored(user);
2860-
});
2861-
if (liveUserIt != result.user_end()) {
2862-
InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2863-
<< op->getName() << "' marked as erased";
2864-
diag.attachNote(liveUserIt->getLoc())
2865-
<< "found live user of result #" << result.getResultNumber() << ": "
2866-
<< *liveUserIt;
2867-
return failure();
2868-
}
2869-
return success();
2870-
}
2871-
28722856
//===----------------------------------------------------------------------===//
28732857
// Reconcile Unrealized Casts
28742858
//===----------------------------------------------------------------------===//

mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
// Test that an error is emitted when an operation is marked as "erased", but
44
// has users that live across the conversion.
55
func.func @remove_all_ops(%arg0: i32) -> i32 {
6-
// expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
6+
// expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}}
77
%0 = "test.illegal_op_a"() : () -> i32
8-
// expected-note@below {{found live user of result #0: func.return %0 : i32}}
98
return %0 : i32
109
}

0 commit comments

Comments
 (0)