Skip to content

Fix block merging #96871

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 2, 2024
Merged

Fix block merging #96871

merged 4 commits into from
Jul 2, 2024

Conversation

giuseros
Copy link
Contributor

With this PR I am trying to address: #63230.

What changed:

  • While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same Value. The operations operands in the block will point to the argument we already inserted
  • After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument.
  • This last simplification clashed with BufferDeallocationSimplification. The reason, I think, is that the two simplifications are clashing. I.e., BufferDeallocationSimplification contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass.

Note: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design.

@giuseros giuseros requested review from joker-eph and Mogball June 27, 2024 08:55
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir mlir:bufferization Bufferization infrastructure labels Jun 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 27, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Giuseppe Rossini (giuseros)

Changes

With this PR I am trying to address: #63230.

What changed:

  • While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same Value. The operations operands in the block will point to the argument we already inserted
  • After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument.
  • This last simplification clashed with BufferDeallocationSimplification. The reason, I think, is that the two simplifications are clashing. I.e., BufferDeallocationSimplification contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass.

Note: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design.


Full diff: https://github.com/llvm/llvm-project/pull/96871.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+7-2)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+127-12)
  • (modified) mlir/test/Dialect/Linalg/detensorize_entry_block.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // We don't want that the block structure changes invalidating the
+    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+    // region simplification
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..31c4bfe1f6b2f 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(
-        1 + blocksToMerge.size(),
-        SmallVector<Value, 8>(operandsToMerge.size()));
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+                                                       SmallVector<Value, 8>());
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        newArguments[i][it.index()] = operand.get();
-
-        // Update the operand and insert an argument if this is the leader.
-        if (i == 0) {
-          Value operandVal = operand.get();
-          operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                               operandVal.getLoc()));
+        Value operandVal = operand.get();
+        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+                              operandVal);
+        if (it == newArguments[i].end()) {
+          newArguments[i].push_back(operandVal);
+          // Update the operand and insert an argument if this is the leader.
+          if (i == 0) {
+            operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                                 operandVal.getLoc()));
+          }
+        } else if (i == 0) {
+          // If this is the leader, update the operand but do not insert a new
+          // argument. Instead, the opearand should point to one of the
+          // arguments we already passed (and that contained `operandVal`)
+          operand.set(leaderBlock->getArgument(
+              std::distance(newArguments[i].begin(), it)));
         }
       }
     }
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block
+  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block predecessor and flag if they pass to the block
+    // different values for the same argument
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      auto operands = succOperands.getForwardedOperands();
+      if (!commonValue) {
+        commonValue = operands[argIdx];
+      } else {
+        if (operands[argIdx] != commonValue) {
+          sameArg = false;
+          break;
+        }
+      }
+    }
+
+    // If they are passing the same value, drop the argument
+    if (commonValue && sameArg) {
+      argsToErase.push_back(argIdx);
+
+      // Remove the argument from the block
+      Value argVal = block.getArgument(argIdx);
+      rewriter.replaceAllUsesWith(argVal, commonValue);
+    }
+  }
+
+  // Remove the arguments
+  for (auto argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the argument from the branch ops
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(2 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
+///    llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val = llvm.mlir.constant(1 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+/// ^bb1(%arg0 : i64):
+///    llvm.call @foo(%val0, %arg1)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (auto &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+
+    // Add any nested regions to the worklist.
+    for (Block &block : *region) {
+      anyChanged = succeeded(dropRedundantArguments(rewriter, block));
+
+      for (auto &op : block)
+        for (auto &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>

Copy link
Contributor

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for looking into this issue! Can you add a test case to the PR?

@giuseros
Copy link
Contributor Author

Hi @Mogball , thanks for the comment! I added a test case and fixed mlir/test/Transforms/canonicalize-block-merge.mlir .

The test case I added is a simplification of the one you proposed here. The point is that in Triton we meet exactly this case, and indeed we end up hanging the compiler.

@Mogball
Copy link
Contributor

Mogball commented Jun 28, 2024

Thank you so much for fixing this!

@giuseros
Copy link
Contributor Author

giuseros commented Jul 2, 2024

Hi @Mogball , is it ok if I merge this?

@Mogball
Copy link
Contributor

Mogball commented Jul 2, 2024

do it!

@giuseros giuseros merged commit 6c3897d into llvm:main Jul 2, 2024
7 checks passed
@giuseros
Copy link
Contributor Author

giuseros commented Jul 2, 2024

Thank you so much! It's the first non-negligible PR I merge upstream, so I was not sure :P

@joker-eph
Copy link
Collaborator

Seems like this broke the bots: https://lab.llvm.org/buildbot/#/builders/116/builds/778 ; I'll revert for now.

joker-eph added a commit that referenced this pull request Jul 2, 2024
joker-eph added a commit that referenced this pull request Jul 2, 2024
@giuseros
Copy link
Contributor Author

giuseros commented Jul 3, 2024

Hi @joker-eph , I am very sorry about this. Thanks for reverting the patch. I have a question: to test my patch I ran ninja check-mlir. Is there a more extensive set of tests I should run?

@joker-eph
Copy link
Collaborator

The failing tests are enabled by -DMLIR_INCLUDE_INTEGRATION_TESTS=ON I believe

lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
With this PR I am trying to address:
llvm#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note**: many tests are still not passing. But I wanted to submit the
code before changing all the tests (and probably adding a couple), so
that we can agree in principle on the algorithm/design.
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
With this PR I am trying to address:
llvm#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note**: many tests are still not passing. But I wanted to submit the
code before changing all the tests (and probably adding a couple), so
that we can agree in principle on the algorithm/design.
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
giuseros added a commit that referenced this pull request Jul 17, 2024
With this PR I am trying to address:
#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted. This needs to happen to all
the arguments we pass to the different successors of the parent block
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note**: this a rework of #96871 . I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
With this PR I am trying to address:
#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted. This needs to happen to all
the arguments we pass to the different successors of the parent block
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note**: this a rework of #96871 . I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250916
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants