-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Fix block merging #96871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix block merging #96871
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Giuseppe Rossini (giuseros) ChangesWith this PR I am trying to address: #63230. What changed:
Note: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design. Full diff: https://github.com/llvm/llvm-project/pull/96871.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
+ // We don't want that the block structure changes invalidating the
+ // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+ // region simplification
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..31c4bfe1f6b2f 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include <deque>
+#include <iterator>
using namespace mlir;
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());
// Update each of the predecessor terminators with the new arguments.
- SmallVector<SmallVector<Value, 8>, 2> newArguments(
- 1 + blocksToMerge.size(),
- SmallVector<Value, 8>(operandsToMerge.size()));
+ SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+ SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
- newArguments[i][it.index()] = operand.get();
-
- // Update the operand and insert an argument if this is the leader.
- if (i == 0) {
- Value operandVal = operand.get();
- operand.set(leaderBlock->addArgument(operandVal.getType(),
- operandVal.getLoc()));
+ Value operandVal = operand.get();
+ Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+ operandVal);
+ if (it == newArguments[i].end()) {
+ newArguments[i].push_back(operandVal);
+ // Update the operand and insert an argument if this is the leader.
+ if (i == 0) {
+ operand.set(leaderBlock->addArgument(operandVal.getType(),
+ operandVal.getLoc()));
+ }
+ } else if (i == 0) {
+ // If this is the leader, update the operand but do not insert a new
+ // argument. Instead, the opearand should point to one of the
+ // arguments we already passed (and that contained `operandVal`)
+ operand.set(leaderBlock->getArgument(
+ std::distance(newArguments[i].begin(), it)));
}
}
}
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ Block &block) {
+ SmallVector<size_t> argsToErase;
+
+ // Go through the arguments of the block
+ for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+ bool sameArg = true;
+ Value commonValue;
+
+ // Go through the block predecessor and flag if they pass to the block
+ // different values for the same argument
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+ if (!branch) {
+ sameArg = false;
+ break;
+ }
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ auto operands = succOperands.getForwardedOperands();
+ if (!commonValue) {
+ commonValue = operands[argIdx];
+ } else {
+ if (operands[argIdx] != commonValue) {
+ sameArg = false;
+ break;
+ }
+ }
+ }
+
+ // If they are passing the same value, drop the argument
+ if (commonValue && sameArg) {
+ argsToErase.push_back(argIdx);
+
+ // Remove the argument from the block
+ Value argVal = block.getArgument(argIdx);
+ rewriter.replaceAllUsesWith(argVal, commonValue);
+ }
+ }
+
+ // Remove the arguments
+ for (auto argIdx : llvm::reverse(argsToErase)) {
+ block.eraseArgument(argIdx);
+
+ // Remove the argument from the branch ops
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ succOperands.erase(argIdx);
+ }
+ }
+ return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(2 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
+/// llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val = llvm.mlir.constant(1 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+/// ^bb1(%arg0 : i64):
+/// llvm.call @foo(%val0, %arg1)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ MutableArrayRef<Region> regions) {
+ llvm::SmallSetVector<Region *, 1> worklist;
+ for (auto ®ion : regions)
+ worklist.insert(®ion);
+ bool anyChanged = false;
+ while (!worklist.empty()) {
+ Region *region = worklist.pop_back_val();
+
+ // Add any nested regions to the worklist.
+ for (Block &block : *region) {
+ anyChanged = succeeded(dropRedundantArguments(rewriter, block));
+
+ for (auto &op : block)
+ for (auto &nestedRegion : op.getRegions())
+ worklist.insert(&nestedRegion);
+ }
+ }
+ return success(anyChanged);
+}
+
//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
- if (mergeBlocks)
+ bool droppedRedundantArguments = false;
+ if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+ droppedRedundantArguments =
+ succeeded(dropRedundantArguments(rewriter, regions));
+ }
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
- mergedIdenticalBlocks);
+ mergedIdenticalBlocks || droppedRedundantArguments);
}
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for looking into this issue! Can you add a test case to the PR?
Thank you so much for fixing this! |
Hi @Mogball , is it ok if I merge this? |
do it! |
Thank you so much! It's the first non-negligible PR I merge upstream, so I was not sure :P |
Seems like this broke the bots: https://lab.llvm.org/buildbot/#/builders/116/builds/778 ; I'll revert for now. |
Reverts #96871 Bots are broken.
Hi @joker-eph , I am very sorry about this. Thanks for reverting the patch. I have a question: to test my patch I ran |
The failing tests are enabled by |
With this PR I am trying to address: llvm#63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design.
Reverts llvm#96871 Bots are broken.
With this PR I am trying to address: llvm#63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design.
Reverts llvm#96871 Bots are broken.
With this PR I am trying to address: #63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: this a rework of #96871 . I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
Summary: With this PR I am trying to address: #63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: this a rework of #96871 . I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250916
With this PR I am trying to address: #63230.
What changed:
Value
. The operations operands in the block will point to the argument we already insertedBufferDeallocationSimplification
. The reason, I think, is that the two simplifications are clashing. I.e.,BufferDeallocationSimplification
contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass.Note: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design.