Skip to content

Commit e018af0

Browse files
committed
Offset the new arguments by the number of old arguments
1 parent 746f507 commit e018af0

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -679,12 +679,15 @@ static bool ableToUpdatePredOperands(Block *block) {
679679
return true;
680680
}
681681

682-
/// Prunes the redundant list of arguments. E.g., if we are passing an argument
683-
/// list like [x, y, z, x] this would return [x, y, z] and it would update the
684-
/// `block` (to whom the argument are passed to) accordingly.
682+
/// Prunes the redundant list of new arguments. E.g., if we are passing an
683+
/// argument list like [x, y, z, x] this would return [x, y, z] and it would
684+
/// update the `block` (to whom the argument are passed to) accordingly. The new
685+
/// arguments are passed as arguments at the back of the block, hence we need to
686+
/// know how many `numOldArguments` were before, in order to correctly replace
687+
/// the new arguments in the block
685688
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
686689
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
687-
RewriterBase &rewriter, Block *block) {
690+
RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
688691

689692
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
690693
newArguments.size(), SmallVector<Value, 8>());
@@ -751,10 +754,11 @@ static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
751754
SmallVector<unsigned> toErase;
752755
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
753756
if (idxToReplacement.contains(idx)) {
754-
Value oldArg = block->getArgument(idx);
755-
Value newArg = block->getArgument(idxToReplacement[idx]);
757+
Value oldArg = block->getArgument(numOldArguments + idx);
758+
Value newArg =
759+
block->getArgument(numOldArguments + idxToReplacement[idx]);
756760
rewriter.replaceAllUsesWith(oldArg, newArg);
757-
toErase.push_back(idx);
761+
toErase.push_back(numOldArguments + idx);
758762
}
759763
}
760764

@@ -793,6 +797,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
793797
1 + blocksToMerge.size(),
794798
SmallVector<Value, 8>(operandsToMerge.size()));
795799
unsigned curOpIndex = 0;
800+
unsigned numOldArguments = leaderBlock->getNumArguments();
796801
for (const auto &it : llvm::enumerate(operandsToMerge)) {
797802
unsigned nextOpOffset = it.value().first - curOpIndex;
798803
curOpIndex = it.value().first;
@@ -814,7 +819,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
814819
}
815820

816821
// Prune redundant arguments and update the leader block argument list
817-
newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
822+
newArguments = pruneRedundantArguments(newArguments, rewriter,
823+
numOldArguments, leaderBlock);
818824

819825
// Update the predecessors for each of the blocks.
820826
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {

mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,33 @@ llvm.func @redundant_args_complex(%cond : i1) {
160160
^bb3:
161161
llvm.return
162162
}
163+
164+
llvm.func @blocks_with_args() {
165+
%0 = llvm.mlir.zero : !llvm.ptr
166+
%1 = llvm.call @rand() : () -> i1
167+
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : i64)
168+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64)
169+
// CHECK: %[[cond:.*]] = llvm.call @rand
170+
%3 = llvm.mlir.constant(0) : i64
171+
%4 = llvm.mlir.constant(1) : i64
172+
// CHECK: llvm.cond_br %[[cond]], ^bb1(%[[c1]] : i64), ^bb1(%[[c0]] : i64)
173+
// CHECK: ^bb1(%{{.*}}: i64):
174+
// CHECK ^bb2:
175+
// CHECK ^bb3:
176+
// CHECK llvm.return
177+
llvm.cond_br %1, ^bb7(%0 : !llvm.ptr), ^bb1(%0 : !llvm.ptr)
178+
^bb1(%5: !llvm.ptr):
179+
llvm.store %5, %0 : !llvm.ptr, !llvm.ptr
180+
llvm.cond_br %1, ^bb2(%3 : i64), ^bb4(%3 : i64)
181+
^bb7(%6: !llvm.ptr):
182+
llvm.store %6, %0 : !llvm.ptr, !llvm.ptr
183+
llvm.cond_br %1, ^bb2(%4 : i64), ^bb4(%4 : i64)
184+
^bb2(%7: i64):
185+
llvm.call @foo(%7) : (i64) -> ()
186+
llvm.br ^bb8
187+
^bb4(%8: i64):
188+
llvm.call @foo(%8) : (i64) -> ()
189+
llvm.br ^bb8
190+
^bb8:
191+
llvm.return
192+
}

0 commit comments

Comments
 (0)