Skip to content

Commit c0cba25

Browse files
[mlir][Transforms] Dialect conversion: Hardening replaceOp (#109540)
This commit adds extra checks/assertions to the `ConversionPatternRewriterImpl::notifyOpReplaced` to improve its robustness. 1. Replacing an `unrealized_conversion_cast` op that was created by the driver is now forbidden and caught early during `replaceOp`. It may work in some cases, but it is generally dangerous because the conversion driver keeps track of these ops and performs some extra legalization steps during the "finalize" phase. (Erasing is them is fine.) 2. `null` replacement values are no longer registered in the `ConversionValueMapping`. This was an oversight in #106760. There is no benefit in having `null` values in the `ConversionValueMapping`. (It may even cause problems.) This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers.
1 parent e19a5fc commit c0cba25

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,16 +1382,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13821382
assert(newValues.size() == op->getNumResults());
13831383
assert(!ignoredOps.contains(op) && "operation was already replaced");
13841384

1385+
// Check if replaced op is an unresolved materialization, i.e., an
1386+
// unrealized_conversion_cast op that was created by the conversion driver.
1387+
bool isUnresolvedMaterialization = false;
1388+
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1389+
if (unresolvedMaterializations.contains(castOp))
1390+
isUnresolvedMaterialization = true;
1391+
13851392
// Create mappings for each of the new result values.
13861393
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13871394
if (!newValue) {
13881395
// This result was dropped and no replacement value was provided.
1389-
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1390-
if (unresolvedMaterializations.contains(castOp)) {
1391-
// Do not create another materializations if we are erasing a
1392-
// materialization.
1393-
continue;
1394-
}
1396+
if (isUnresolvedMaterialization) {
1397+
// Do not create another materializations if we are erasing a
1398+
// materialization.
1399+
continue;
13951400
}
13961401

13971402
// Materialize a replacement value "out of thin air".
@@ -1400,10 +1405,20 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
14001405
result.getLoc(), /*inputs=*/ValueRange(),
14011406
/*outputType=*/result.getType(), /*originalType=*/Type(),
14021407
currentTypeConverter);
1408+
} else {
1409+
// Make sure that the user does not mess with unresolved materializations
1410+
// that were inserted by the conversion driver. We keep track of these
1411+
// ops in internal data structures. Erasing them must be allowed because
1412+
// this can happen when the user is erasing an entire block (including
1413+
// its body). But replacing them with another value should be forbidden
1414+
// to avoid problems with the `mapping`.
1415+
assert(!isUnresolvedMaterialization &&
1416+
"attempting to replace an unresolved materialization");
14031417
}
14041418

1405-
// Remap, and check for any result type changes.
1406-
mapping.map(result, newValue);
1419+
// Remap result to replacement value.
1420+
if (newValue)
1421+
mapping.map(result, newValue);
14071422
}
14081423

14091424
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);

0 commit comments

Comments
 (0)