9
9
#include " mlir/Transforms/RegionUtils.h"
10
10
#include " mlir/Analysis/TopologicalSortUtils.h"
11
11
#include " mlir/IR/Block.h"
12
- #include " mlir/IR/BuiltinOps.h"
13
12
#include " mlir/IR/IRMapping.h"
14
13
#include " mlir/IR/Operation.h"
15
14
#include " mlir/IR/PatternMatch.h"
16
15
#include " mlir/IR/RegionGraphTraits.h"
17
16
#include " mlir/IR/Value.h"
18
17
#include " mlir/Interfaces/ControlFlowInterfaces.h"
19
18
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
- #include " mlir/Support/LogicalResult.h"
21
19
22
20
#include " llvm/ADT/DepthFirstIterator.h"
23
21
#include " llvm/ADT/PostOrderIterator.h"
24
- #include " llvm/ADT/STLExtras.h"
25
- #include " llvm/ADT/SmallSet.h"
26
22
27
23
#include < deque>
28
- #include < iterator>
29
24
30
25
using namespace mlir ;
31
26
@@ -704,8 +699,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
704
699
blockIterators.push_back (mergeBlock->begin ());
705
700
706
701
// Update each of the predecessor terminators with the new arguments.
707
- SmallVector<SmallVector<Value, 8 >, 2 > newArguments (1 + blocksToMerge.size (),
708
- SmallVector<Value, 8 >());
702
+ SmallVector<SmallVector<Value, 8 >, 2 > newArguments (
703
+ 1 + blocksToMerge.size (),
704
+ SmallVector<Value, 8 >(operandsToMerge.size ()));
709
705
unsigned curOpIndex = 0 ;
710
706
for (const auto &it : llvm::enumerate (operandsToMerge)) {
711
707
unsigned nextOpOffset = it.value ().first - curOpIndex;
@@ -716,22 +712,13 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
716
712
Block::iterator &blockIter = blockIterators[i];
717
713
std::advance (blockIter, nextOpOffset);
718
714
auto &operand = blockIter->getOpOperand (it.value ().second );
719
- Value operandVal = operand.get ();
720
- Value *it = std::find (newArguments[i].begin (), newArguments[i].end (),
721
- operandVal);
722
- if (it == newArguments[i].end ()) {
723
- newArguments[i].push_back (operandVal);
724
- // Update the operand and insert an argument if this is the leader.
725
- if (i == 0 ) {
726
- operand.set (leaderBlock->addArgument (operandVal.getType (),
727
- operandVal.getLoc ()));
728
- }
729
- } else if (i == 0 ) {
730
- // If this is the leader, update the operand but do not insert a new
731
- // argument. Instead, the operand should point to one of the
732
- // arguments we already passed (and that contained `operandVal`)
733
- operand.set (leaderBlock->getArgument (
734
- std::distance (newArguments[i].begin (), it)));
715
+ newArguments[i][it.index ()] = operand.get ();
716
+
717
+ // Update the operand and insert an argument if this is the leader.
718
+ if (i == 0 ) {
719
+ Value operandVal = operand.get ();
720
+ operand.set (leaderBlock->addArgument (operandVal.getType (),
721
+ operandVal.getLoc ()));
735
722
}
736
723
}
737
724
}
@@ -831,109 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
831
818
return success (anyChanged);
832
819
}
833
820
834
- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
835
- Block &block) {
836
- SmallVector<size_t > argsToErase;
837
-
838
- // Go through the arguments of the block
839
- for (size_t argIdx = 0 ; argIdx < block.getNumArguments (); argIdx++) {
840
- bool sameArg = true ;
841
- Value commonValue;
842
-
843
- // Go through the block predecessor and flag if they pass to the block
844
- // different values for the same argument
845
- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
846
- predIt != predE; ++predIt) {
847
- auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
848
- if (!branch) {
849
- sameArg = false ;
850
- break ;
851
- }
852
- unsigned succIndex = predIt.getSuccessorIndex ();
853
- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
854
- auto operands = succOperands.getForwardedOperands ();
855
- if (!commonValue) {
856
- commonValue = operands[argIdx];
857
- } else {
858
- if (operands[argIdx] != commonValue) {
859
- sameArg = false ;
860
- break ;
861
- }
862
- }
863
- }
864
-
865
- // If they are passing the same value, drop the argument
866
- if (commonValue && sameArg) {
867
- argsToErase.push_back (argIdx);
868
-
869
- // Remove the argument from the block
870
- Value argVal = block.getArgument (argIdx);
871
- rewriter.replaceAllUsesWith (argVal, commonValue);
872
- }
873
- }
874
-
875
- // Remove the arguments
876
- for (auto argIdx : llvm::reverse (argsToErase)) {
877
- block.eraseArgument (argIdx);
878
-
879
- // Remove the argument from the branch ops
880
- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
881
- predIt != predE; ++predIt) {
882
- auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
883
- unsigned succIndex = predIt.getSuccessorIndex ();
884
- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
885
- succOperands.erase (argIdx);
886
- }
887
- }
888
- return success (!argsToErase.empty ());
889
- }
890
-
891
- // / This optimization drops redundant argument to blocks. I.e., if a given
892
- // / argument to a block receives the same value from each of the block
893
- // / predecessors, we can remove the argument from the block and use directly the
894
- // / original value. This is a simple example:
895
- // /
896
- // / %cond = llvm.call @rand() : () -> i1
897
- // / %val0 = llvm.mlir.constant(1 : i64) : i64
898
- // / %val1 = llvm.mlir.constant(2 : i64) : i64
899
- // / %val2 = llvm.mlir.constant(3 : i64) : i64
900
- // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901
- // / : i64)
902
- // /
903
- // / ^bb1(%arg0 : i64, %arg1 : i64):
904
- // / llvm.call @foo(%arg0, %arg1)
905
- // /
906
- // / The previous IR can be rewritten as:
907
- // / %cond = llvm.call @rand() : () -> i1
908
- // / %val0 = llvm.mlir.constant(1 : i64) : i64
909
- // / %val1 = llvm.mlir.constant(2 : i64) : i64
910
- // / %val2 = llvm.mlir.constant(3 : i64) : i64
911
- // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
912
- // /
913
- // / ^bb1(%arg0 : i64):
914
- // / llvm.call @foo(%val0, %arg0)
915
- // /
916
- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
917
- MutableArrayRef<Region> regions) {
918
- llvm::SmallSetVector<Region *, 1 > worklist;
919
- for (auto ®ion : regions)
920
- worklist.insert (®ion);
921
- bool anyChanged = false ;
922
- while (!worklist.empty ()) {
923
- Region *region = worklist.pop_back_val ();
924
-
925
- // Add any nested regions to the worklist.
926
- for (Block &block : *region) {
927
- anyChanged = succeeded (dropRedundantArguments (rewriter, block));
928
-
929
- for (auto &op : block)
930
- for (auto &nestedRegion : op.getRegions ())
931
- worklist.insert (&nestedRegion);
932
- }
933
- }
934
- return success (anyChanged);
935
- }
936
-
937
821
// ===----------------------------------------------------------------------===//
938
822
// Region Simplification
939
823
// ===----------------------------------------------------------------------===//
@@ -948,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
948
832
bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
949
833
bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
950
834
bool mergedIdenticalBlocks = false ;
951
- bool droppedRedundantArguments = false ;
952
- if (mergeBlocks) {
835
+ if (mergeBlocks)
953
836
mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
954
- droppedRedundantArguments =
955
- succeeded (dropRedundantArguments (rewriter, regions));
956
- }
957
837
return success (eliminatedBlocks || eliminatedOpsOrArgs ||
958
- mergedIdenticalBlocks || droppedRedundantArguments );
838
+ mergedIdenticalBlocks);
959
839
}
0 commit comments