Skip to content

Commit 28a11cc

Browse files
authored
Revert "Fix block merging" (#97460)
Reverts #96871 Bots are broken.
1 parent b8eaa5b commit 28a11cc

12 files changed

+93
-289
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,15 +463,10 @@ struct BufferDeallocationSimplificationPass
463463
SplitDeallocWhenNotAliasingAnyOther,
464464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465465
analysis);
466-
// We don't want that the block structure changes invalidating the
467-
// `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
468-
// region simplification
469-
GreedyRewriteConfig config;
470-
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
471466
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
472467

473-
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
474-
config)))
468+
if (failed(
469+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
475470
signalPassFailure();
476471
}
477472
};

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 12 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,18 @@
99
#include "mlir/Transforms/RegionUtils.h"
1010
#include "mlir/Analysis/TopologicalSortUtils.h"
1111
#include "mlir/IR/Block.h"
12-
#include "mlir/IR/BuiltinOps.h"
1312
#include "mlir/IR/IRMapping.h"
1413
#include "mlir/IR/Operation.h"
1514
#include "mlir/IR/PatternMatch.h"
1615
#include "mlir/IR/RegionGraphTraits.h"
1716
#include "mlir/IR/Value.h"
1817
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1918
#include "mlir/Interfaces/SideEffectInterfaces.h"
20-
#include "mlir/Support/LogicalResult.h"
2119

2220
#include "llvm/ADT/DepthFirstIterator.h"
2321
#include "llvm/ADT/PostOrderIterator.h"
24-
#include "llvm/ADT/STLExtras.h"
25-
#include "llvm/ADT/SmallSet.h"
2622

2723
#include <deque>
28-
#include <iterator>
2924

3025
using namespace mlir;
3126

@@ -704,8 +699,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
704699
blockIterators.push_back(mergeBlock->begin());
705700

706701
// Update each of the predecessor terminators with the new arguments.
707-
SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
708-
SmallVector<Value, 8>());
702+
SmallVector<SmallVector<Value, 8>, 2> newArguments(
703+
1 + blocksToMerge.size(),
704+
SmallVector<Value, 8>(operandsToMerge.size()));
709705
unsigned curOpIndex = 0;
710706
for (const auto &it : llvm::enumerate(operandsToMerge)) {
711707
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -716,22 +712,13 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
716712
Block::iterator &blockIter = blockIterators[i];
717713
std::advance(blockIter, nextOpOffset);
718714
auto &operand = blockIter->getOpOperand(it.value().second);
719-
Value operandVal = operand.get();
720-
Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
721-
operandVal);
722-
if (it == newArguments[i].end()) {
723-
newArguments[i].push_back(operandVal);
724-
// Update the operand and insert an argument if this is the leader.
725-
if (i == 0) {
726-
operand.set(leaderBlock->addArgument(operandVal.getType(),
727-
operandVal.getLoc()));
728-
}
729-
} else if (i == 0) {
730-
// If this is the leader, update the operand but do not insert a new
731-
// argument. Instead, the opearand should point to one of the
732-
// arguments we already passed (and that contained `operandVal`)
733-
operand.set(leaderBlock->getArgument(
734-
std::distance(newArguments[i].begin(), it)));
715+
newArguments[i][it.index()] = operand.get();
716+
717+
// Update the operand and insert an argument if this is the leader.
718+
if (i == 0) {
719+
Value operandVal = operand.get();
720+
operand.set(leaderBlock->addArgument(operandVal.getType(),
721+
operandVal.getLoc()));
735722
}
736723
}
737724
}
@@ -831,109 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
831818
return success(anyChanged);
832819
}
833820

834-
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
835-
Block &block) {
836-
SmallVector<size_t> argsToErase;
837-
838-
// Go through the arguments of the block
839-
for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
840-
bool sameArg = true;
841-
Value commonValue;
842-
843-
// Go through the block predecessor and flag if they pass to the block
844-
// different values for the same argument
845-
for (auto predIt = block.pred_begin(), predE = block.pred_end();
846-
predIt != predE; ++predIt) {
847-
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
848-
if (!branch) {
849-
sameArg = false;
850-
break;
851-
}
852-
unsigned succIndex = predIt.getSuccessorIndex();
853-
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
854-
auto operands = succOperands.getForwardedOperands();
855-
if (!commonValue) {
856-
commonValue = operands[argIdx];
857-
} else {
858-
if (operands[argIdx] != commonValue) {
859-
sameArg = false;
860-
break;
861-
}
862-
}
863-
}
864-
865-
// If they are passing the same value, drop the argument
866-
if (commonValue && sameArg) {
867-
argsToErase.push_back(argIdx);
868-
869-
// Remove the argument from the block
870-
Value argVal = block.getArgument(argIdx);
871-
rewriter.replaceAllUsesWith(argVal, commonValue);
872-
}
873-
}
874-
875-
// Remove the arguments
876-
for (auto argIdx : llvm::reverse(argsToErase)) {
877-
block.eraseArgument(argIdx);
878-
879-
// Remove the argument from the branch ops
880-
for (auto predIt = block.pred_begin(), predE = block.pred_end();
881-
predIt != predE; ++predIt) {
882-
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
883-
unsigned succIndex = predIt.getSuccessorIndex();
884-
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
885-
succOperands.erase(argIdx);
886-
}
887-
}
888-
return success(!argsToErase.empty());
889-
}
890-
891-
/// This optimization drops redundant argument to blocks. I.e., if a given
892-
/// argument to a block receives the same value from each of the block
893-
/// predecessors, we can remove the argument from the block and use directly the
894-
/// original value. This is a simple example:
895-
///
896-
/// %cond = llvm.call @rand() : () -> i1
897-
/// %val0 = llvm.mlir.constant(1 : i64) : i64
898-
/// %val1 = llvm.mlir.constant(2 : i64) : i64
899-
/// %val2 = llvm.mlir.constant(3 : i64) : i64
900-
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901-
/// : i64)
902-
///
903-
/// ^bb1(%arg0 : i64, %arg1 : i64):
904-
/// llvm.call @foo(%arg0, %arg1)
905-
///
906-
/// The previous IR can be rewritten as:
907-
/// %cond = llvm.call @rand() : () -> i1
908-
/// %val0 = llvm.mlir.constant(1 : i64) : i64
909-
/// %val1 = llvm.mlir.constant(2 : i64) : i64
910-
/// %val2 = llvm.mlir.constant(3 : i64) : i64
911-
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
912-
///
913-
/// ^bb1(%arg0 : i64):
914-
/// llvm.call @foo(%val0, %arg0)
915-
///
916-
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
917-
MutableArrayRef<Region> regions) {
918-
llvm::SmallSetVector<Region *, 1> worklist;
919-
for (auto &region : regions)
920-
worklist.insert(&region);
921-
bool anyChanged = false;
922-
while (!worklist.empty()) {
923-
Region *region = worklist.pop_back_val();
924-
925-
// Add any nested regions to the worklist.
926-
for (Block &block : *region) {
927-
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
928-
929-
for (auto &op : block)
930-
for (auto &nestedRegion : op.getRegions())
931-
worklist.insert(&nestedRegion);
932-
}
933-
}
934-
return success(anyChanged);
935-
}
936-
937821
//===----------------------------------------------------------------------===//
938822
// Region Simplification
939823
//===----------------------------------------------------------------------===//
@@ -948,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
948832
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
949833
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
950834
bool mergedIdenticalBlocks = false;
951-
bool droppedRedundantArguments = false;
952-
if (mergeBlocks) {
835+
if (mergeBlocks)
953836
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
954-
droppedRedundantArguments =
955-
succeeded(dropRedundantArguments(rewriter, regions));
956-
}
957837
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
958-
mergedIdenticalBlocks || droppedRedundantArguments);
838+
mergedIdenticalBlocks);
959839
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -178,32 +178,28 @@ func.func @condBranchDynamicTypeNested(
178178
// CHECK-NEXT: ^bb1
179179
// CHECK-NOT: bufferization.dealloc
180180
// CHECK-NOT: bufferization.clone
181-
// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
181+
// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
182182
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
183183
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
184184
// CHECK-NEXT: test.buffer_based
185185
// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
186186
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
187187
// CHECK-NOT: bufferization.dealloc
188188
// CHECK-NOT: bufferization.clone
189-
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
189+
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
190190
// CHECK-NEXT: ^bb3:
191191
// CHECK-NOT: bufferization.dealloc
192192
// CHECK-NOT: bufferization.clone
193-
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
194-
// CHECK-NEXT: ^bb4:
193+
// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
194+
// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
195195
// CHECK-NOT: bufferization.dealloc
196196
// CHECK-NOT: bufferization.clone
197-
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
198-
// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
199-
// CHECK-NOT: bufferization.dealloc
200-
// CHECK-NOT: bufferization.clone
201-
// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
202-
// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
197+
// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
198+
// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
203199
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
204200
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
205-
// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
206-
// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
201+
// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
202+
// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
207203
// CHECK: test.copy
208204
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
209205
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])

mlir/test/Dialect/Linalg/detensorize_entry_block.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
1515
// CHECK-LABEL: @main
1616
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
1717
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
18-
// CHECK: cf.br ^{{.*}}
19-
// CHECK: ^{{.*}}:
20-
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
18+
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
19+
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
20+
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
2121
// CHECK: return %[[ELEMENTS]] : tensor<f32>

mlir/test/Dialect/Linalg/detensorize_if.mlir

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ func.func @main() -> (tensor<i32>) attributes {} {
4242
}
4343

4444
// CHECK-LABEL: func @main()
45-
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
46-
// CHECK-DAG: arith.constant true
47-
// CHECK: cf.br
48-
// CHECK-NEXT: ^[[bb1:.*]]:
49-
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
50-
// CHECK-NEXT: ^[[bb2]]
51-
// CHECK-NEXT: cf.br ^[[bb3:.*]]
52-
// CHECK-NEXT: ^[[bb3]]
53-
// CHECK-NEXT: return %[[cst]]
45+
// CHECK-DAG: arith.constant 0
46+
// CHECK-DAG: arith.constant 10
47+
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
48+
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
49+
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
50+
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
51+
// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
52+
// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
53+
// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
54+
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
55+
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
56+
// CHECK-NEXT: return %{{.*}}
5457
// CHECK-NEXT: }
5558

5659
// -----
@@ -103,17 +106,20 @@ func.func @main() -> (tensor<i32>) attributes {} {
103106
}
104107

105108
// CHECK-LABEL: func @main()
106-
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
107-
// CHECK-DAG: arith.constant true
108-
// CHECK: cf.br ^[[bb1:.*]]
109-
// CHECK-NEXT: ^[[bb1:.*]]:
110-
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
111-
// CHECK-NEXT: ^[[bb2]]:
112-
// CHECK-NEXT: cf.br ^[[bb3:.*]]
113-
// CHECK-NEXT: ^[[bb3]]:
114-
// CHECK-NEXT: cf.br ^[[bb4:.*]]
115-
// CHECK-NEXT: ^[[bb4]]:
116-
// CHECK-NEXT: return %[[cst]]
109+
// CHECK-DAG: arith.constant 0
110+
// CHECK-DAG: arith.constant 10
111+
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
112+
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
113+
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
114+
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
115+
// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
116+
// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
117+
// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
118+
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
119+
// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32)
120+
// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
121+
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
122+
// CHECK-NEXT: return %{{.*}}
117123
// CHECK-NEXT: }
118124

119125
// -----
@@ -165,13 +171,16 @@ func.func @main() -> (tensor<i32>) attributes {} {
165171
}
166172

167173
// CHECK-LABEL: func @main()
168-
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10>
169-
// CHECK-DAG: arith.constant true
170-
// CHECK: cf.br ^[[bb1:.*]]
171-
// CHECK-NEXT: ^[[bb1]]:
172-
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
173-
// CHECK-NEXT: ^[[bb2]]
174-
// CHECK-NEXT: cf.br ^[[bb3:.*]]
175-
// CHECK-NEXT: ^[[bb3]]
176-
// CHECK-NEXT: return %[[cst]]
174+
// CHECK-DAG: arith.constant 0
175+
// CHECK-DAG: arith.constant 10
176+
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
177+
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
178+
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
179+
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
180+
// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
181+
// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
182+
// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
183+
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
184+
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
185+
// CHECK-NEXT: return %{{.*}}
177186
// CHECK-NEXT: }

mlir/test/Dialect/Linalg/detensorize_while.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
4646
// DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32)
4747
// DET-ALL: ^[[bb1]](%{{.*}}: i32)
4848
// DET-ALL: arith.cmpi slt, {{.*}}
49-
// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
50-
// DET-ALL: ^[[bb2]]
49+
// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
50+
// DET-ALL: ^[[bb2]](%{{.*}}: i32)
5151
// DET-ALL: arith.addi {{.*}}
5252
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32)
53-
// DET-ALL: ^[[bb3]]:
53+
// DET-ALL: ^[[bb3]](%{{.*}}: i32)
5454
// DET-ALL: tensor.from_elements {{.*}}
5555
// DET-ALL: return %{{.*}} : tensor<i32>
5656

@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
6262
// DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32)
6363
// DET-CF: ^[[bb1]](%{{.*}}: i32)
6464
// DET-CF: arith.cmpi slt, {{.*}}
65-
// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
66-
// DET-CF: ^[[bb2]]:
65+
// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
66+
// DET-CF: ^[[bb2]](%{{.*}}: i32)
6767
// DET-CF: arith.addi {{.*}}
6868
// DET-CF: cf.br ^[[bb1]](%{{.*}} : i32)
69-
// DET-CF: ^[[bb3]]:
69+
// DET-CF: ^[[bb3]](%{{.*}}: i32)
7070
// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
7171
// DET-CF: return %{{.*}} : tensor<i32>

0 commit comments

Comments
 (0)