Skip to content

Commit 01ba3d9

Browse files
[mlir][Transforms] Dialect conversion: Extra checks during replaceOp
This commit adds extra checks/assertions to the `ConversionPatternRewriterImpl::notifyOpReplaced` to improve its robustness. Replacing an `unrealized_conversion_cast` op that was created by the driver is forbidden and is now caught early during `replaceOp`. It may work in some cases, but is generally dangerous because the conversion driver keeps track of these ops. (Erasing is them is fine.) This change is also in preparation of a subsequent commit that splits the `ConversionValueMapping` into replacements and materializations (with the goal of simplifying block signature conversions). `null` replacement values are no longer registered in the `ConversionValueMapping`. This was an oversight in #106760. `null` values in the mapping could result in crashes when using the `ConversionValueMapping` API.
1 parent 6cd7979 commit 01ba3d9

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,16 +1382,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13821382
assert(newValues.size() == op->getNumResults());
13831383
assert(!ignoredOps.contains(op) && "operation was already replaced");
13841384

1385+
// Check if replaced op is an unresolved materialization, i.e., an
1386+
// unrealized_conversion_cast op that was created by the conversion driver.
1387+
bool isUnresolvedMaterialization = false;
1388+
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1389+
if (unresolvedMaterializations.contains(castOp))
1390+
isUnresolvedMaterialization = true;
1391+
13851392
// Create mappings for each of the new result values.
13861393
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13871394
if (!newValue) {
13881395
// This result was dropped and no replacement value was provided.
1389-
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1390-
if (unresolvedMaterializations.contains(castOp)) {
1391-
// Do not create another materializations if we are erasing a
1392-
// materialization.
1393-
continue;
1394-
}
1396+
if (isUnresolvedMaterialization) {
1397+
// Do not create another materializations if we are erasing a
1398+
// materialization.
1399+
continue;
13951400
}
13961401

13971402
// Materialize a replacement value "out of thin air".
@@ -1400,10 +1405,20 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
14001405
result.getLoc(), /*inputs=*/ValueRange(),
14011406
/*outputType=*/result.getType(), /*originalType=*/Type(),
14021407
currentTypeConverter);
1408+
} else {
1409+
// Make sure that the user does not mess with unresolved materializations
1410+
// that were inserted by the conversion driver. We keep track of these
1411+
// ops in internal data structures. Erasing them must be allowed because
1412+
// this can happen when the user is erasing an entire block (including
1413+
// its body). But replacing them with another value should be forbidden
1414+
// to avoid problems with the `mapping`.
1415+
assert(!isUnresolvedMaterialization &&
1416+
"attempting to replace an unresolved materialization");
14031417
}
14041418

1405-
// Remap, and check for any result type changes.
1406-
mapping.map(result, newValue);
1419+
// Remap result to replacement value.
1420+
if (newValue)
1421+
mapping.map(result, newValue);
14071422
}
14081423

14091424
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);

0 commit comments

Comments
 (0)