@@ -941,6 +941,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
941
941
// / to modify/access them is invalid rewriter API usage.
942
942
SetVector<Operation *> replacedOps;
943
943
944
+ DenseSet<Operation *> unresolvedMaterializations;
945
+
944
946
// / The current type converter, or nullptr if no type converter is currently
945
947
// / active.
946
948
const TypeConverter *currentTypeConverter = nullptr ;
@@ -1066,6 +1068,7 @@ void UnresolvedMaterializationRewrite::rollback() {
1066
1068
for (Value input : op->getOperands ())
1067
1069
rewriterImpl.mapping .erase (input);
1068
1070
}
1071
+ rewriterImpl.unresolvedMaterializations .erase (op);
1069
1072
op->erase ();
1070
1073
}
1071
1074
@@ -1347,6 +1350,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1347
1350
builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
1348
1351
auto convertOp =
1349
1352
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1353
+ unresolvedMaterializations.insert (convertOp);
1350
1354
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1351
1355
return convertOp.getResult (0 );
1352
1356
}
@@ -1385,9 +1389,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1385
1389
// Create mappings for each of the new result values.
1386
1390
for (auto [newValue, result] : llvm::zip (newValues, op->getResults ())) {
1387
1391
if (!newValue) {
1388
- resultChanged = true ;
1389
- continue ;
1392
+ // This result was dropped and no replacement value was provided.
1393
+ if (unresolvedMaterializations.contains (op)) {
1394
+ // Do not create another materializations if we are erasing a
1395
+ // materialization.
1396
+ resultChanged = true ;
1397
+ continue ;
1398
+ }
1399
+
1400
+ // Materialize a replacement value "out of thin air".
1401
+ newValue = buildUnresolvedMaterialization (
1402
+ MaterializationKind::Source, computeInsertPoint (result),
1403
+ result.getLoc (), /* inputs=*/ ValueRange (),
1404
+ /* outputType=*/ result.getType (), currentTypeConverter);
1390
1405
}
1406
+
1391
1407
// Remap, and check for any result type changes.
1392
1408
mapping.map (result, newValue);
1393
1409
resultChanged |= (newValue.getType () != result.getType ());
@@ -2359,11 +2375,6 @@ struct OperationConverter {
2359
2375
ConversionPatternRewriterImpl &rewriterImpl,
2360
2376
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2361
2377
2362
- // / Legalize an operation result that was marked as "erased".
2363
- LogicalResult
2364
- legalizeErasedResult (Operation *op, OpResult result,
2365
- ConversionPatternRewriterImpl &rewriterImpl);
2366
-
2367
2378
// / Dialect conversion configuration.
2368
2379
ConversionConfig config;
2369
2380
@@ -2455,78 +2466,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2455
2466
return failure ();
2456
2467
}
2457
2468
2458
- // / Erase all dead unrealized_conversion_cast ops. An op is dead if its results
2459
- // / are not used (transitively) by any op that is not in the given list of
2460
- // / cast ops.
2461
- // /
2462
- // / In particular, this function erases cyclic casts that may be inserted
2463
- // / during the dialect conversion process. E.g.:
2464
- // / %0 = unrealized_conversion_cast(%1)
2465
- // / %1 = unrealized_conversion_cast(%0)
2466
- // Note: This step will become unnecessary when
2467
- // https://github.com/llvm/llvm-project/pull/106760 has been merged.
2468
- static void eraseDeadUnrealizedCasts (
2469
- ArrayRef<UnrealizedConversionCastOp> castOps,
2470
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2471
- // Ops that have already been visited or are currently being visited.
2472
- DenseSet<Operation *> visited;
2473
- // Set of all cast ops for faster lookups.
2474
- DenseSet<Operation *> castOpSet;
2475
- // Set of all cast ops that have been determined to be alive.
2476
- DenseSet<Operation *> live;
2477
-
2478
- for (UnrealizedConversionCastOp op : castOps)
2479
- castOpSet.insert (op);
2480
-
2481
- // Visit a cast operation. Return "true" if the operation is live.
2482
- std::function<bool (Operation *)> visit = [&](Operation *op) -> bool {
2483
- // No need to traverse any IR if the op was already marked as live.
2484
- if (live.contains (op))
2485
- return true ;
2486
-
2487
- // Do not visit ops multiple times. If we find a circle, no live user was
2488
- // found on the current path.
2489
- if (visited.contains (op))
2490
- return false ;
2491
- visited.insert (op);
2492
-
2493
- // Visit all users.
2494
- for (Operation *user : op->getUsers ()) {
2495
- // If the user is not an unrealized_conversion_cast op, then the given op
2496
- // is live.
2497
- if (!castOpSet.contains (user)) {
2498
- live.insert (op);
2499
- return true ;
2500
- }
2501
- // Otherwise, it is live if a live op can be reached from one of its
2502
- // users (which must all be unrealized_conversion_cast ops).
2503
- if (visit (user)) {
2504
- live.insert (op);
2505
- return true ;
2506
- }
2507
- }
2508
-
2509
- return false ;
2510
- };
2511
-
2512
- // Visit all cast ops.
2513
- for (UnrealizedConversionCastOp op : castOps) {
2514
- visit (op);
2515
- visited.clear ();
2516
- }
2517
-
2518
- // Erase all cast ops that are dead.
2519
- for (UnrealizedConversionCastOp op : castOps) {
2520
- if (live.contains (op)) {
2521
- if (remainingCastOps)
2522
- remainingCastOps->push_back (op);
2523
- continue ;
2524
- }
2525
- op->dropAllUses ();
2526
- op->erase ();
2527
- }
2528
- }
2529
-
2530
2469
LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2531
2470
if (ops.empty ())
2532
2471
return success ();
@@ -2585,14 +2524,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2585
2524
// Reconcile all UnrealizedConversionCastOps that were inserted by the
2586
2525
// dialect conversion frameworks. (Not the one that were inserted by
2587
2526
// patterns.)
2588
- SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
2589
- eraseDeadUnrealizedCasts (allCastOps, &remainingCastOps1);
2590
- reconcileUnrealizedCasts (remainingCastOps1, &remainingCastOps2);
2527
+ SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2528
+ reconcileUnrealizedCasts (allCastOps, &remainingCastOps);
2591
2529
2592
2530
// Try to legalize all unresolved materializations.
2593
2531
if (config.buildMaterializations ) {
2594
2532
IRRewriter rewriter (rewriterImpl.context , config.listener );
2595
- for (UnrealizedConversionCastOp castOp : remainingCastOps2 ) {
2533
+ for (UnrealizedConversionCastOp castOp : remainingCastOps ) {
2596
2534
auto it = rewriteMap.find (castOp.getOperation ());
2597
2535
assert (it != rewriteMap.end () && " inconsistent state" );
2598
2536
if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
@@ -2651,26 +2589,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2651
2589
continue ;
2652
2590
Operation *op = opReplacement->getOperation ();
2653
2591
for (OpResult result : op->getResults ()) {
2654
- Value newValue = rewriterImpl.mapping .lookupOrNull (result);
2655
-
2656
- // If the operation result was replaced with null, all of the uses of this
2657
- // value should be replaced.
2658
- if (!newValue) {
2659
- if (failed (legalizeErasedResult (op, result, rewriterImpl)))
2660
- return failure ();
2661
- continue ;
2662
- }
2663
-
2664
- // Otherwise, check to see if the type of the result changed.
2665
- if (result.getType () == newValue.getType ())
2592
+ // If the type of this op result changed and the result is still live,
2593
+ // we need to materialize a conversion.
2594
+ if (rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
2666
2595
continue ;
2667
-
2668
2596
Operation *liveUser =
2669
2597
findLiveUserOfReplaced (result, rewriterImpl, inverseMapping);
2670
2598
if (!liveUser)
2671
2599
continue ;
2672
2600
2673
2601
// Legalize this result.
2602
+ Value newValue = rewriterImpl.mapping .lookupOrNull (result);
2603
+ assert (newValue && " replacement value not found" );
2674
2604
Value castValue = rewriterImpl.buildUnresolvedMaterialization (
2675
2605
MaterializationKind::Source, computeInsertPoint (result), op->getLoc (),
2676
2606
/* inputs=*/ newValue, /* outputType=*/ result.getType (),
@@ -2728,25 +2658,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2728
2658
return success ();
2729
2659
}
2730
2660
2731
- LogicalResult OperationConverter::legalizeErasedResult (
2732
- Operation *op, OpResult result,
2733
- ConversionPatternRewriterImpl &rewriterImpl) {
2734
- // If the operation result was replaced with null, all of the uses of this
2735
- // value should be replaced.
2736
- auto liveUserIt = llvm::find_if_not (result.getUsers (), [&](Operation *user) {
2737
- return rewriterImpl.isOpIgnored (user);
2738
- });
2739
- if (liveUserIt != result.user_end ()) {
2740
- InFlightDiagnostic diag = op->emitError (" failed to legalize operation '" )
2741
- << op->getName () << " ' marked as erased" ;
2742
- diag.attachNote (liveUserIt->getLoc ())
2743
- << " found live user of result #" << result.getResultNumber () << " : "
2744
- << *liveUserIt;
2745
- return failure ();
2746
- }
2747
- return success ();
2748
- }
2749
-
2750
2661
// ===----------------------------------------------------------------------===//
2751
2662
// Reconcile Unrealized Casts
2752
2663
// ===----------------------------------------------------------------------===//
0 commit comments