Skip to content

Revert "Fix block merging" #97460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2024
Merged

Conversation

joker-eph
Copy link
Collaborator

Reverts #96871

Bots are broken.

@joker-eph joker-eph added the skip-precommit-approval PR for CI feedback, not intended for review label Jul 2, 2024
@joker-eph joker-eph merged commit 28a11cc into main Jul 2, 2024
5 of 6 checks passed
@joker-eph joker-eph deleted the revert-96871-improve_block_merging branch July 2, 2024 18:57
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir mlir:bufferization Bufferization infrastructure labels Jul 2, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

Reverts llvm/llvm-project#96871

Bots are broken.


Patch is 28.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97460.diff

12 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+2-7)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+12-132)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir (+8-12)
  • (modified) mlir/test/Dialect/Linalg/detensorize_entry_block.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/detensorize_if.mlir (+38-29)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir (+2-2)
  • (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+3-3)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+4-4)
  • (modified) mlir/test/Transforms/make-isolated-from-above.mlir (+9-9)
  • (removed) mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir (-76)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 5227b22653eef..954485cfede3d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,10 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
-    // We don't want that the block structure changes invalidating the
-    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
-    // region simplification
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                            config)))
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 412e2456295ad..4c0f15bafbaba 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -17,15 +16,11 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
-#include <iterator>
 
 using namespace mlir;
 
@@ -704,8 +699,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
-                                                       SmallVector<Value, 8>());
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(
+        1 + blocksToMerge.size(),
+        SmallVector<Value, 8>(operandsToMerge.size()));
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -716,22 +712,13 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        Value operandVal = operand.get();
-        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
-                              operandVal);
-        if (it == newArguments[i].end()) {
-          newArguments[i].push_back(operandVal);
-          // Update the operand and insert an argument if this is the leader.
-          if (i == 0) {
-            operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                                 operandVal.getLoc()));
-          }
-        } else if (i == 0) {
-          // If this is the leader, update the operand but do not insert a new
-          // argument. Instead, the opearand should point to one of the
-          // arguments we already passed (and that contained `operandVal`)
-          operand.set(leaderBlock->getArgument(
-              std::distance(newArguments[i].begin(), it)));
+        newArguments[i][it.index()] = operand.get();
+
+        // Update the operand and insert an argument if this is the leader.
+        if (i == 0) {
+          Value operandVal = operand.get();
+          operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                               operandVal.getLoc()));
         }
       }
     }
@@ -831,109 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
-static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
-                                            Block &block) {
-  SmallVector<size_t> argsToErase;
-
-  // Go through the arguments of the block
-  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
-    bool sameArg = true;
-    Value commonValue;
-
-    // Go through the block predecessor and flag if they pass to the block
-    // different values for the same argument
-    for (auto predIt = block.pred_begin(), predE = block.pred_end();
-         predIt != predE; ++predIt) {
-      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
-      if (!branch) {
-        sameArg = false;
-        break;
-      }
-      unsigned succIndex = predIt.getSuccessorIndex();
-      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      auto operands = succOperands.getForwardedOperands();
-      if (!commonValue) {
-        commonValue = operands[argIdx];
-      } else {
-        if (operands[argIdx] != commonValue) {
-          sameArg = false;
-          break;
-        }
-      }
-    }
-
-    // If they are passing the same value, drop the argument
-    if (commonValue && sameArg) {
-      argsToErase.push_back(argIdx);
-
-      // Remove the argument from the block
-      Value argVal = block.getArgument(argIdx);
-      rewriter.replaceAllUsesWith(argVal, commonValue);
-    }
-  }
-
-  // Remove the arguments
-  for (auto argIdx : llvm::reverse(argsToErase)) {
-    block.eraseArgument(argIdx);
-
-    // Remove the argument from the branch ops
-    for (auto predIt = block.pred_begin(), predE = block.pred_end();
-         predIt != predE; ++predIt) {
-      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
-      unsigned succIndex = predIt.getSuccessorIndex();
-      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      succOperands.erase(argIdx);
-    }
-  }
-  return success(!argsToErase.empty());
-}
-
-/// This optimization drops redundant argument to blocks. I.e., if a given
-/// argument to a block receives the same value from each of the block
-/// predecessors, we can remove the argument from the block and use directly the
-/// original value. This is a simple example:
-///
-/// %cond = llvm.call @rand() : () -> i1
-/// %val0 = llvm.mlir.constant(1 : i64) : i64
-/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
-/// : i64)
-///
-/// ^bb1(%arg0 : i64, %arg1 : i64):
-///    llvm.call @foo(%arg0, %arg1)
-///
-/// The previous IR can be rewritten as:
-/// %cond = llvm.call @rand() : () -> i1
-/// %val0 = llvm.mlir.constant(1 : i64) : i64
-/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
-///
-/// ^bb1(%arg0 : i64):
-///    llvm.call @foo(%val0, %arg0)
-///
-static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
-                                            MutableArrayRef<Region> regions) {
-  llvm::SmallSetVector<Region *, 1> worklist;
-  for (auto &region : regions)
-    worklist.insert(&region);
-  bool anyChanged = false;
-  while (!worklist.empty()) {
-    Region *region = worklist.pop_back_val();
-
-    // Add any nested regions to the worklist.
-    for (Block &block : *region) {
-      anyChanged = succeeded(dropRedundantArguments(rewriter, block));
-
-      for (auto &op : block)
-        for (auto &nestedRegion : op.getRegions())
-          worklist.insert(&nestedRegion);
-    }
-  }
-  return success(anyChanged);
-}
-
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -948,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  bool droppedRedundantArguments = false;
-  if (mergeBlocks) {
+  if (mergeBlocks)
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
-    droppedRedundantArguments =
-        succeeded(dropRedundantArguments(rewriter, regions));
-  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks || droppedRedundantArguments);
+                 mergedIdenticalBlocks);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 8e14990502143..5e8104f83cc4d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,24 +186,20 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4:
+//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
-//   CHECK-NOT: bufferization.dealloc
-//   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index 50a2d6bf532aa..d1a89226fdb58 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}
-// CHECK: ^{{.*}}:
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index c728ad21d2209..8d17763c04b6c 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,15 +42,18 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br
-// CHECK-NEXT:   ^[[bb1:.*]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
-// CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
 // -----
@@ -103,17 +106,20 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br ^[[bb1:.*]]
-// CHECK-NEXT:   ^[[bb1:.*]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
-// CHECK-NEXT:   ^[[bb2]]:
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]:
-// CHECK-NEXT:     cf.br ^[[bb4:.*]]
-// CHECK-NEXT:   ^[[bb4]]:
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
 // -----
@@ -165,13 +171,16 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br ^[[bb1:.*]]
-// CHECK-NEXT:   ^[[bb1]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
-// CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 580a97d3a851b..aa30900f76a33 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-ALL:       ^[[bb2]]
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]]:
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-CF:       ^[[bb2]]:
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-CF:       ^[[bb2]](%{{.*}}: i32)
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]]:
+// DET-CF:       ^[[bb3]](%{{.*}}: i32)
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 414d9b94cbf53..955c7be5ef4c8 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:         } -> tensor<i32>
 // DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-ALL:       ^[[bb2]]:
+// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         tensor.empty() : tensor<10xi32>
 // DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
@@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:           linalg.yield %{{.*}} : i32
 // DET-ALL:         } -> tensor<10xi32>
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
-// DET-ALL:       ^[[bb3]]
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
@@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
 // DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-CF:         cf.cond_br %{{.*}}, ^bb2, ^bb3
-// DET-CF:       ^bb2:
+// DET-CF:         cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-CF:         cf.br ^bb1(%{{.*}} : tensor<10xi32>)
-// DET-CF:       ^bb3:
+// DET-CF:       ^bb3(%{{.*}}: tensor<i32>)
 // DET-CF:         return %{{.*}} : tensor<i32>
 // DET-CF:       }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 913e78272db79..6d8d5fe71fca5 100644
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Mehdi Amini (joker-eph)

Changes

Reverts llvm/llvm-project#96871

Bots are broken.


Patch is 28.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97460.diff

12 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+2-7)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+12-132)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir (+8-12)
  • (modified) mlir/test/Dialect/Linalg/detensorize_entry_block.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/detensorize_if.mlir (+38-29)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir (+2-2)
  • (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+3-3)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+4-4)
  • (modified) mlir/test/Transforms/make-isolated-from-above.mlir (+9-9)
  • (removed) mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir (-76)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 5227b22653eef..954485cfede3d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,10 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
-    // We don't want that the block structure changes invalidating the
-    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
-    // region simplification
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                            config)))
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 412e2456295ad..4c0f15bafbaba 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -17,15 +16,11 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
-#include <iterator>
 
 using namespace mlir;
 
@@ -704,8 +699,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
-                                                       SmallVector<Value, 8>());
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(
+        1 + blocksToMerge.size(),
+        SmallVector<Value, 8>(operandsToMerge.size()));
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -716,22 +712,13 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        Value operandVal = operand.get();
-        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
-                              operandVal);
-        if (it == newArguments[i].end()) {
-          newArguments[i].push_back(operandVal);
-          // Update the operand and insert an argument if this is the leader.
-          if (i == 0) {
-            operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                                 operandVal.getLoc()));
-          }
-        } else if (i == 0) {
-          // If this is the leader, update the operand but do not insert a new
-          // argument. Instead, the opearand should point to one of the
-          // arguments we already passed (and that contained `operandVal`)
-          operand.set(leaderBlock->getArgument(
-              std::distance(newArguments[i].begin(), it)));
+        newArguments[i][it.index()] = operand.get();
+
+        // Update the operand and insert an argument if this is the leader.
+        if (i == 0) {
+          Value operandVal = operand.get();
+          operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                               operandVal.getLoc()));
         }
       }
     }
@@ -831,109 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
-static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
-                                            Block &block) {
-  SmallVector<size_t> argsToErase;
-
-  // Go through the arguments of the block
-  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
-    bool sameArg = true;
-    Value commonValue;
-
-    // Go through the block predecessor and flag if they pass to the block
-    // different values for the same argument
-    for (auto predIt = block.pred_begin(), predE = block.pred_end();
-         predIt != predE; ++predIt) {
-      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
-      if (!branch) {
-        sameArg = false;
-        break;
-      }
-      unsigned succIndex = predIt.getSuccessorIndex();
-      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      auto operands = succOperands.getForwardedOperands();
-      if (!commonValue) {
-        commonValue = operands[argIdx];
-      } else {
-        if (operands[argIdx] != commonValue) {
-          sameArg = false;
-          break;
-        }
-      }
-    }
-
-    // If they are passing the same value, drop the argument
-    if (commonValue && sameArg) {
-      argsToErase.push_back(argIdx);
-
-      // Remove the argument from the block
-      Value argVal = block.getArgument(argIdx);
-      rewriter.replaceAllUsesWith(argVal, commonValue);
-    }
-  }
-
-  // Remove the arguments
-  for (auto argIdx : llvm::reverse(argsToErase)) {
-    block.eraseArgument(argIdx);
-
-    // Remove the argument from the branch ops
-    for (auto predIt = block.pred_begin(), predE = block.pred_end();
-         predIt != predE; ++predIt) {
-      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
-      unsigned succIndex = predIt.getSuccessorIndex();
-      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      succOperands.erase(argIdx);
-    }
-  }
-  return success(!argsToErase.empty());
-}
-
-/// This optimization drops redundant argument to blocks. I.e., if a given
-/// argument to a block receives the same value from each of the block
-/// predecessors, we can remove the argument from the block and use directly the
-/// original value. This is a simple example:
-///
-/// %cond = llvm.call @rand() : () -> i1
-/// %val0 = llvm.mlir.constant(1 : i64) : i64
-/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
-/// : i64)
-///
-/// ^bb1(%arg0 : i64, %arg1 : i64):
-///    llvm.call @foo(%arg0, %arg1)
-///
-/// The previous IR can be rewritten as:
-/// %cond = llvm.call @rand() : () -> i1
-/// %val0 = llvm.mlir.constant(1 : i64) : i64
-/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
-///
-/// ^bb1(%arg0 : i64):
-///    llvm.call @foo(%val0, %arg0)
-///
-static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
-                                            MutableArrayRef<Region> regions) {
-  llvm::SmallSetVector<Region *, 1> worklist;
-  for (auto &region : regions)
-    worklist.insert(&region);
-  bool anyChanged = false;
-  while (!worklist.empty()) {
-    Region *region = worklist.pop_back_val();
-
-    // Add any nested regions to the worklist.
-    for (Block &block : *region) {
-      anyChanged = succeeded(dropRedundantArguments(rewriter, block));
-
-      for (auto &op : block)
-        for (auto &nestedRegion : op.getRegions())
-          worklist.insert(&nestedRegion);
-    }
-  }
-  return success(anyChanged);
-}
-
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -948,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  bool droppedRedundantArguments = false;
-  if (mergeBlocks) {
+  if (mergeBlocks)
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
-    droppedRedundantArguments =
-        succeeded(dropRedundantArguments(rewriter, regions));
-  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks || droppedRedundantArguments);
+                 mergedIdenticalBlocks);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 8e14990502143..5e8104f83cc4d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,24 +186,20 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4:
+//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
-//   CHECK-NOT: bufferization.dealloc
-//   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index 50a2d6bf532aa..d1a89226fdb58 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}
-// CHECK: ^{{.*}}:
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index c728ad21d2209..8d17763c04b6c 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,15 +42,18 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br
-// CHECK-NEXT:   ^[[bb1:.*]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
-// CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
 // -----
@@ -103,17 +106,20 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br ^[[bb1:.*]]
-// CHECK-NEXT:   ^[[bb1:.*]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
-// CHECK-NEXT:   ^[[bb2]]:
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]:
-// CHECK-NEXT:     cf.br ^[[bb4:.*]]
-// CHECK-NEXT:   ^[[bb4]]:
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
 // -----
@@ -165,13 +171,16 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br ^[[bb1:.*]]
-// CHECK-NEXT:   ^[[bb1]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
-// CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 580a97d3a851b..aa30900f76a33 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-ALL:       ^[[bb2]]
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]]:
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-CF:       ^[[bb2]]:
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-CF:       ^[[bb2]](%{{.*}}: i32)
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]]:
+// DET-CF:       ^[[bb3]](%{{.*}}: i32)
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 414d9b94cbf53..955c7be5ef4c8 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:         } -> tensor<i32>
 // DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-ALL:       ^[[bb2]]:
+// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         tensor.empty() : tensor<10xi32>
 // DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
@@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:           linalg.yield %{{.*}} : i32
 // DET-ALL:         } -> tensor<10xi32>
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
-// DET-ALL:       ^[[bb3]]
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
@@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
 // DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-CF:         cf.cond_br %{{.*}}, ^bb2, ^bb3
-// DET-CF:       ^bb2:
+// DET-CF:         cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-CF:         cf.br ^bb1(%{{.*}} : tensor<10xi32>)
-// DET-CF:       ^bb3:
+// DET-CF:       ^bb3(%{{.*}}: tensor<i32>)
 // DET-CF:         return %{{.*}} : tensor<i32>
 // DET-CF:       }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 913e78272db79..6d8d5fe71fca5 100644
...
[truncated]

lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:linalg mlir skip-precommit-approval PR for CI feedback, not intended for review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants