Skip to content

[mlir][Transforms][NFC] Move ReconcileUnrealizedCasts implementation #104671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 23, 2024

Conversation

matthias-springer
Copy link
Member

Move the implementation of ReconcileUnrealizedCasts to DialectConversion.cpp, so that it can be called from there in a future commit.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as ReconcileUnrealizedCasts will perform these kind of foldings on fully materialized IR.

Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as `ReconcileUnrealizedCasts` will perform these kind of foldings on fully materialized IR.
@matthias-springer matthias-springer marked this pull request as ready for review August 17, 2024 10:08
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 17, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Move the implementation of ReconcileUnrealizedCasts to DialectConversion.cpp, so that it can be called from there in a future commit.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as ReconcileUnrealizedCasts will perform these kind of foldings on fully materialized IR.


Full diff: https://github.com/llvm/llvm-project/pull/104671.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+23)
  • (modified) mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp (+3-55)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+74)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0aeb..86f0337dd90dfe 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
   RewriterBase::Listener *listener = nullptr;
 };
 
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided).
+///
+/// This function processes cast ops in a worklist-driven fashion. For each
+/// cast op, if the chain of input casts eventually reaches a cast op where the
+/// input types match the output types of the matched op, replace the matched
+/// op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
+void reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 12e0029cebfd0d..d01e3dcbe8cc45 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
   ReconcileUnrealizedCasts() = default;
 
   void runOnOperation() override {
-    // Gather all unrealized_conversion_cast ops.
-    SetVector<UnrealizedConversionCastOp> worklist;
+    SmallVector<UnrealizedConversionCastOp> ops;
     getOperation()->walk(
         [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
-
-    // Helper function that adds all operands to the worklist that are an
-    // unrealized_conversion_cast op result.
-    auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
-      for (Value v : castOp.getInputs())
-        if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
-          worklist.insert(inputCastOp);
-    };
-
-    // Helper function that return the unrealized_conversion_cast op that
-    // defines all inputs of the given op (in the same order). Return "nullptr"
-    // if there is no such op.
-    auto getInputCast =
-        [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
-      if (castOp.getInputs().empty())
-        return {};
-      auto inputCastOp = castOp.getInputs()
-                             .front()
-                             .getDefiningOp<UnrealizedConversionCastOp>();
-      if (!inputCastOp)
-        return {};
-      if (inputCastOp.getOutputs() != castOp.getInputs())
-        return {};
-      return inputCastOp;
-    };
-
-    // Process ops in the worklist bottom-to-top.
-    while (!worklist.empty()) {
-      UnrealizedConversionCastOp castOp = worklist.pop_back_val();
-      if (castOp->use_empty()) {
-        // DCE: If the op has no users, erase it. Add the operands to the
-        // worklist to find additional DCE opportunities.
-        enqueueOperands(castOp);
-        castOp->erase();
-        continue;
-      }
-
-      // Traverse the chain of input cast ops to see if an op with the same
-      // input types can be found.
-      UnrealizedConversionCastOp nextCast = castOp;
-      while (nextCast) {
-        if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
-          // Found a cast where the input types match the output types of the
-          // matched op. We can directly use those inputs and the matched op can
-          // be removed.
-          enqueueOperands(castOp);
-          castOp.replaceAllUsesWith(nextCast.getInputs());
-          castOp->erase();
-          break;
-        }
-        nextCast = getInputCast(nextCast);
-      }
-    }
+    reconcileUnrealizedCasts(ops);
   }
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..0da8eabadb4ee1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVector<UnrealizedConversionCastOp> *remainingCastOps) {
+  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+                                                 castOps.end());
+  // This set is maintained only if `remainingCastOps` is provided.
+  DenseSet<Operation *> erasedOps;
+
+  // Helper function that adds all operands to the worklist that are an
+  // unrealized_conversion_cast op result.
+  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+    for (Value v : castOp.getInputs())
+      if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+        worklist.insert(inputCastOp);
+  };
+
+  // Helper function that return the unrealized_conversion_cast op that
+  // defines all inputs of the given op (in the same order). Return "nullptr"
+  // if there is no such op.
+  auto getInputCast =
+      [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+    if (castOp.getInputs().empty())
+      return {};
+    auto inputCastOp =
+        castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!inputCastOp)
+      return {};
+    if (inputCastOp.getOutputs() != castOp.getInputs())
+      return {};
+    return inputCastOp;
+  };
+
+  // Process ops in the worklist bottom-to-top.
+  while (!worklist.empty()) {
+    UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+    if (castOp->use_empty()) {
+      // DCE: If the op has no users, erase it. Add the operands to the
+      // worklist to find additional DCE opportunities.
+      enqueueOperands(castOp);
+      if (remainingCastOps)
+        erasedOps.insert(castOp.getOperation());
+      castOp->erase();
+      continue;
+    }
+
+    // Traverse the chain of input cast ops to see if an op with the same
+    // input types can be found.
+    UnrealizedConversionCastOp nextCast = castOp;
+    while (nextCast) {
+      if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+        // Found a cast where the input types match the output types of the
+        // matched op. We can directly use those inputs and the matched op can
+        // be removed.
+        enqueueOperands(castOp);
+        castOp.replaceAllUsesWith(nextCast.getInputs());
+        if (remainingCastOps)
+          erasedOps.insert(castOp.getOperation());
+        castOp->erase();
+        break;
+      }
+      nextCast = getInputCast(nextCast);
+    }
+  }
+
+  if (remainingCastOps)
+    for (UnrealizedConversionCastOp op : castOps)
+      if (!erasedOps.contains(op.getOperation()))
+        remainingCastOps->push_back(op);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//

@matthias-springer
Copy link
Member Author

For reference, #104668 is still in development, but it's going to look somewhat like that.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/move_reconcile_unre branch from db989d5 to ed5a2dc Compare August 17, 2024 10:39
@matthias-springer matthias-springer merged commit a9f6224 into main Aug 23, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/move_reconcile_unre branch August 23, 2024 15:46
cjdb pushed a commit to cjdb/llvm-project that referenced this pull request Aug 23, 2024
llvm#104671)

Move the implementation of `ReconcileUnrealizedCasts` to
`DialectConversion.cpp`, so that it can be called from there in a future
commit.

This commit is in preparation of decoupling argument/source/target
materializations from the dialect conversion framework. The existing
logic around unresolved materializations that predicts IR changes to
decide if a cast op can be folded/erased will become obsolete, as
`ReconcileUnrealizedCasts` will perform these kind of foldings on fully
materialized IR.

---------

Co-authored-by: Markus Böck <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants