[mlir] Fix block merging #102038

Merged 9 commits on Aug 7, 2024
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
// We don't want the block structure to change, which would invalidate the
// `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level of
// region simplification.
GreedyRewriteConfig config;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config)))
signalPassFailure();
}
};
212 changes: 210 additions & 2 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,18 +9,23 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"

#include <deque>
#include <iterator>

using namespace mlir;

@@ -674,6 +679,94 @@ static bool ableToUpdatePredOperands(Block *block) {
return true;
}

/// Prunes the redundant list of new arguments. E.g., if we are passing an
/// argument list like [x, y, z, x] this would return [x, y, z] and it would
/// update the `block` (to which the arguments are passed) accordingly. The new
/// arguments are appended at the back of the block's argument list, hence we
/// need to know how many old arguments (`numOldArguments`) came before them in
/// order to correctly replace the new arguments in the block.
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
RewriterBase &rewriter, unsigned numOldArguments, Block *block) {

SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
newArguments.size(), SmallVector<Value, 8>());

if (newArguments.empty())
return newArguments;

// `newArguments` is a 2D array of size `numLists` x `numArgs`
unsigned numLists = newArguments.size();
unsigned numArgs = newArguments[0].size();

// Map that, for each argument index, contains the index that we can use in
// place of the original index. E.g., if we have newArgs = [x, y, z, x], we
// will have idxToReplacement[3] = 0.
llvm::DenseMap<unsigned, unsigned> idxToReplacement;

// This is a useful data structure to track the first appearance of a Value
// in a given list of arguments.
DenseMap<Value, unsigned> firstValueToIdx;
for (unsigned j = 0; j < numArgs; ++j) {
Value newArg = newArguments[0][j];
if (!firstValueToIdx.contains(newArg))
firstValueToIdx[newArg] = j;
}

// Go through the first list of arguments (list 0).
for (unsigned j = 0; j < numArgs; ++j) {
// Look back to see if there are possible redundancies in list 0. Please
// note that we are using a map to annotate when an argument was seen first
// to avoid an O(N^2) algorithm. This has the drawback that if we have two
// lists like:
// list0: [%a, %a, %a]
// list1: [%c, %b, %b]
// We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
// point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
// the number of arguments can potentially be unbounded we cannot afford an
// O(N^2) algorithm (to search all the possible pairs) and we need to
// accept the trade-off.
unsigned k = firstValueToIdx[newArguments[0][j]];
if (k == j)
continue;

bool shouldReplaceJ = true;
unsigned replacement = k;
// If a possible redundancy is found, then scan the other lists: we
// can prune the arguments if and only if they are redundant in every
// list.
for (unsigned i = 1; i < numLists; ++i)
shouldReplaceJ =
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
// Save the replacement.
if (shouldReplaceJ)
idxToReplacement[j] = replacement;
}

// Populate the pruned argument list.
for (unsigned i = 0; i < numLists; ++i)
for (unsigned j = 0; j < numArgs; ++j)
if (!idxToReplacement.contains(j))
newArgumentsPruned[i].push_back(newArguments[i][j]);

// Replace the block's redundant arguments.
SmallVector<unsigned> toErase;
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
if (idxToReplacement.contains(idx)) {
Value oldArg = block->getArgument(numOldArguments + idx);
Value newArg =
block->getArgument(numOldArguments + idxToReplacement[idx]);
rewriter.replaceAllUsesWith(oldArg, newArg);
toErase.push_back(numOldArguments + idx);
}
}

// Erase the block's redundant arguments.
for (unsigned idxToErase : llvm::reverse(toErase))
block->eraseArgument(idxToErase);
return newArgumentsPruned;
}
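
To make the pruning rule above concrete, here is a small self-contained sketch in plain C++, with ints standing in for the forwarded Values (pruneColumns is a hypothetical helper written only for illustration, not part of this patch). A column j is dropped only when every list forwards, at position j, the same value it forwards at the first occurrence of list0[j]:

#include <cstdio>
#include <map>
#include <vector>

// Prune duplicated columns from a set of equally sized argument lists.
// Mirrors the idea of pruneRedundantArguments above: the first list picks
// the candidate duplicates, and the other lists must agree for a column to
// be dropped.
static std::vector<std::vector<int>>
pruneColumns(const std::vector<std::vector<int>> &lists) {
  if (lists.empty() || lists[0].empty())
    return lists;
  size_t numLists = lists.size(), numArgs = lists[0].size();

  // First occurrence of each value in list 0 (emplace keeps the first index).
  std::map<int, size_t> firstIdx;
  for (size_t j = 0; j < numArgs; ++j)
    firstIdx.emplace(lists[0][j], j);

  // Decide, per column, whether it is redundant in every list.
  std::vector<bool> drop(numArgs, false);
  for (size_t j = 0; j < numArgs; ++j) {
    size_t k = firstIdx[lists[0][j]];
    if (k == j)
      continue;
    bool redundantEverywhere = true;
    for (size_t i = 1; i < numLists; ++i)
      redundantEverywhere =
          redundantEverywhere && (lists[i][k] == lists[i][j]);
    drop[j] = redundantEverywhere;
  }

  // Rebuild the lists without the dropped columns.
  std::vector<std::vector<int>> pruned(numLists);
  for (size_t i = 0; i < numLists; ++i)
    for (size_t j = 0; j < numArgs; ++j)
      if (!drop[j])
        pruned[i].push_back(lists[i][j]);
  return pruned;
}

int main() {
  // Two predecessors forward [1, 2, 1] and [7, 8, 7]; the third column
  // repeats the first one in both lists, so it is dropped.
  auto pruned = pruneColumns({{1, 2, 1}, {7, 8, 7}});
  for (const auto &list : pruned) {
    for (int v : list)
      std::printf("%d ", v);
    std::printf("\n");
  }
  return 0;
}

This prints "1 2" and "7 8": the duplicated third column disappears from both lists, which is what pruneRedundantArguments does with the leader block's new arguments and the operands forwarded by each predecessor.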

LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
// Don't consider clusters that don't have blocks to merge.
if (blocksToMerge.empty())
@@ -703,6 +796,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
1 + blocksToMerge.size(),
SmallVector<Value, 8>(operandsToMerge.size()));
unsigned curOpIndex = 0;
unsigned numOldArguments = leaderBlock->getNumArguments();
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
curOpIndex = it.value().first;
@@ -722,6 +816,11 @@
}
}
}

// Prune redundant arguments and update the leader block argument list
newArguments = pruneRedundantArguments(newArguments, rewriter,
numOldArguments, leaderBlock);

// Update the predecessors for each of the blocks.
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +917,111 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}

/// If a block's argument always receives the same value from every
/// predecessor, then drop the argument and use the value directly inside the
/// block.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
Block &block) {
SmallVector<size_t> argsToErase;

// Go through the arguments of the block.
for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
bool sameArg = true;
Value commonValue;

// Go through the block's predecessors and flag if they pass different
// values to the block for the same argument.
for (Block::pred_iterator predIt = block.pred_begin(),
predE = block.pred_end();
predIt != predE; ++predIt) {
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
if (!branch) {
sameArg = false;
break;
}
unsigned succIndex = predIt.getSuccessorIndex();
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
auto branchOperands = succOperands.getForwardedOperands();
if (!commonValue) {
commonValue = branchOperands[argIdx];
continue;
}
if (branchOperands[argIdx] != commonValue) {
sameArg = false;
break;
}
}

// If they are passing the same value, drop the argument.
if (commonValue && sameArg) {
argsToErase.push_back(argIdx);

// Remove the argument from the block.
rewriter.replaceAllUsesWith(blockOperand, commonValue);
}
}

// Remove the arguments.
for (size_t argIdx : llvm::reverse(argsToErase)) {
block.eraseArgument(argIdx);

// Remove the argument from the branch ops.
for (auto predIt = block.pred_begin(), predE = block.pred_end();
predIt != predE; ++predIt) {
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
unsigned succIndex = predIt.getSuccessorIndex();
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
succOperands.erase(argIdx);
}
}
return success(!argsToErase.empty());
}

/// This optimization drops redundant arguments to blocks. I.e., if a given
/// argument to a block receives the same value from each of the block's
/// predecessors, we can remove the argument from the block and use the
/// original value directly. Here is a simple example:
///
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
/// : i64)
///
/// ^bb1(%arg0 : i64, %arg1 : i64):
/// llvm.call @foo(%arg0, %arg1)
///
/// The previous IR can be rewritten as:
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
///
/// ^bb1(%arg0 : i64):
/// llvm.call @foo(%val0, %arg0)
///
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
llvm::SmallSetVector<Region *, 1> worklist;
for (Region &region : regions)
worklist.insert(&region);
bool anyChanged = false;
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();

// Add any nested regions to the worklist.
for (Block &block : *region) {
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
Review comment from @Hardcode84 (Contributor), Aug 11, 2024:

anyChanged is completely overwritten on each iteration.

It should be anyChanged = succeeded(dropRedundantArguments(rewriter, block)) || anyChanged, I believe.
for (Operation &op : block)
for (Region &nestedRegion : op.getRegions())
worklist.insert(&nestedRegion);
}
}
return success(anyChanged);
}
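
For reference, a minimal sketch of the inner loop with the accumulation the reviewer suggests (an illustration of the proposed follow-up, not the code merged in this PR):

for (Block &block : *region) {
  // Accumulate instead of overwriting, so a change made in an earlier block
  // is not lost when later blocks report no change.
  anyChanged =
      succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;

  // Add any nested regions to the worklist.
  for (Operation &op : block)
    for (Region &nestedRegion : op.getRegions())
      worklist.insert(&nestedRegion);
}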

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +1036,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
if (mergeBlocks)
bool droppedRedundantArguments = false;
if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
droppedRedundantArguments =
succeeded(dropRedundantArguments(rewriter, regions));
}
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
mergedIdenticalBlocks || droppedRedundantArguments);
}
@@ -178,28 +178,32 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: ^bb1
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
// CHECK-NEXT: ^bb4:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: test.copy
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
// CHECK: cf.br ^{{.*}}
// CHECK: ^{{.*}}:
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>