Skip to content

Commit c45cdc9

Browse files
[mlir][Transforms] Dialect conversion: Extra checks during replaceOp
This commit adds extra checks/assertions to the `ConversionPatternRewriterImpl::notifyOpReplaced` to improve its robustness. Replacing an `unrealized_conversion_cast` op that was created by the driver is forbidden and is now caught early during `replaceOp`. It may work in some cases, but is generally dangerous because the conversion driver keeps track of these ops. (Erasing is them is fine.) This change is also in preparation of a subsequent commit that splits the `ConversionValueMapping` into replacements and materializations (with the goal of simplifying block signature conversions). `null` replacement values are no longer registered in the `ConversionValueMapping`. This was an oversight in #106760. `null` values in the mapping could result in crashes when using the `ConversionValueMapping` API.
1 parent 8527861 commit c45cdc9

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,27 +1361,42 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13611361
assert(newValues.size() == op->getNumResults());
13621362
assert(!ignoredOps.contains(op) && "operation was already replaced");
13631363

1364+
// Check if replaced op is an unresolved materialization, i.e., an
1365+
// unrealized_conversion_cast op that was created by the conversion driver.
1366+
bool isUnresolvedMaterialization = false;
1367+
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1368+
if (unresolvedMaterializations.contains(castOp))
1369+
isUnresolvedMaterialization = true;
1370+
13641371
// Create mappings for each of the new result values.
13651372
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13661373
if (!newValue) {
13671374
// This result was dropped and no replacement value was provided.
1368-
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1369-
if (unresolvedMaterializations.contains(castOp)) {
1370-
// Do not create another materializations if we are erasing a
1371-
// materialization.
1372-
continue;
1373-
}
1375+
if (isUnresolvedMaterialization) {
1376+
// Do not create another materializations if we are erasing a
1377+
// materialization.
1378+
continue;
13741379
}
13751380

13761381
// Materialize a replacement value "out of thin air".
13771382
newValue = buildUnresolvedMaterialization(
13781383
MaterializationKind::Source, computeInsertPoint(result),
13791384
result.getLoc(), /*inputs=*/ValueRange(),
13801385
/*outputType=*/result.getType(), currentTypeConverter);
1386+
} else {
1387+
// Make sure that the user does not mess with unresolved materializations
1388+
// that were inserted by the conversion driver. We keep track of these
1389+
// ops in internal data structures. Erasing them must be allowed because
1390+
// this can happen when the user is erasing an entire block (including
1391+
// its body). But replacing them with another value should be forbidden
1392+
// to avoid problems with the `mapping`.
1393+
assert(!isUnresolvedMaterialization &&
1394+
"attempting to replace an unresolved materialization");
13811395
}
13821396

1383-
// Remap, and check for any result type changes.
1384-
mapping.map(result, newValue);
1397+
// Remap result to replacement value.
1398+
if (newValue)
1399+
mapping.map(result, newValue);
13851400
}
13861401

13871402
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);

0 commit comments

Comments
 (0)