Skip to content

[mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ing #89552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to GreedyRewriteConfig that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Apr 21, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 21, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to GreedyRewriteConfig that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.


Full diff: https://github.com/llvm/llvm-project/pull/89552.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+4)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+2)
  • (modified) mlir/lib/Transforms/Canonicalizer.cpp (+2)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+2-2)
  • (modified) mlir/test/Transforms/test-canonicalize.mlir (+14)
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9c..880426c2411bcf 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -47,6 +47,10 @@ class GreedyRewriteConfig {
   /// Note: Only applicable when simplifying entire regions.
   bool enableRegionSimplification = true;
 
+  /// If set to "true", constants are CSE'd (even across multiple regions that
+  /// are in a parent-ancestor relationship).
+  bool cseConstants = true;
+
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27e..549161c96030d3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> {
     Option<"enableRegionSimplification", "region-simplify", "bool",
            /*default=*/"true",
            "Perform control flow optimizations to the region tree">,
+    Option<"cseConstants", "cse-constants", "bool", /*default=*/"true",
+           "CSE constant operations">,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
            "Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee55..2600df32b69c1d 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
       : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
+    this->cseConstants = config.cseConstants;
     this->maxIterations = config.maxIterations;
     this->maxNumRewrites = config.maxNumRewrites;
     this->disabledPatterns = disabledPatterns;
@@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     // Set the config from possible pass options set in the meantime.
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
+    config.cseConstants = cseConstants;
     config.maxIterations = maxIterations;
     config.maxNumRewrites = maxNumRewrites;
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff2..cf4a192a0281d7 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
-        if (!insertKnownConstant(op))
+        if (!config.cseConstants || !insertKnownConstant(op))
           addToWorklist(op);
       });
     } else {
       // Add all nested operations to the worklist in preorder.
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-        if (!insertKnownConstant(op)) {
+        if (!config.cseConstants || !insertKnownConstant(op)) {
           addToWorklist(op);
           return WalkResult::advance();
         }
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf4..98eae142d1870e 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,16 @@ func.func @test_region_simplify() {
 ^bb1:
   return
 }
+
+// CHECK-LABEL: do_not_cse_constant
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: return %[[c0]], %[[c0]]
+// NO-CSE-LABEL: do_not_cse_constant
+// NO-CSE: %[[c0:.*]] = arith.constant 0 : index
+// NO-CSE: %[[c1:.*]] = arith.constant 0 : index
+// NO-CSE: return %[[c0]], %[[c1]]
+func.func @do_not_cse_constant() -> (index, index) {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 0 : index
+  return %0, %1 : index, index
+}
\ No newline at end of file

…nstant CSE'ing

By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/greedy_rewrite_cse_constants branch from 29f45ee to 9bb74b5 Compare April 23, 2024 22:43
@jpienaar
Copy link
Member

I'm not too sure here. While I've seen folks "need" it, in most cases the issue turned out to be something adjacent that would have caused problems later. This also just feels like a base case for an expression, so that a more generic equivalence check is maybe what folks are reaching for.

jpienaar added a commit to jpienaar/llvm-project that referenced this pull request Dec 12, 2024
This partially incorporates llvm#89552. I haven't exposed it on
canonicalizer pass as that could be distinct discussion.
jpienaar added a commit to jpienaar/llvm-project that referenced this pull request Dec 20, 2024
This partially incorporates llvm#89552. I haven't exposed it on
canonicalizer pass as that could be distinct discussion.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants