-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Transforms][NFC] Move ReconcileUnrealizedCasts
implementation
#104671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Move ReconcileUnrealizedCasts
implementation
#104671
Conversation
Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as `ReconcileUnrealizedCasts` will perform these kind of foldings on fully materialized IR.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesMove the implementation of This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as Full diff: https://github.com/llvm/llvm-project/pull/104671.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0aeb..86f0337dd90dfe 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
RewriterBase::Listener *listener = nullptr;
};
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided).
+///
+/// This function processes cast ops in a worklist-driven fashion. For each
+/// cast op, if the chain of input casts eventually reaches a cast op where the
+/// input types match the output types of the matched op, replace the matched
+/// op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
+void reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 12e0029cebfd0d..d01e3dcbe8cc45 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
ReconcileUnrealizedCasts() = default;
void runOnOperation() override {
- // Gather all unrealized_conversion_cast ops.
- SetVector<UnrealizedConversionCastOp> worklist;
+ SmallVector<UnrealizedConversionCastOp> ops;
getOperation()->walk(
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
-
- // Helper function that adds all operands to the worklist that are an
- // unrealized_conversion_cast op result.
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
- for (Value v : castOp.getInputs())
- if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
- worklist.insert(inputCastOp);
- };
-
- // Helper function that return the unrealized_conversion_cast op that
- // defines all inputs of the given op (in the same order). Return "nullptr"
- // if there is no such op.
- auto getInputCast =
- [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
- if (castOp.getInputs().empty())
- return {};
- auto inputCastOp = castOp.getInputs()
- .front()
- .getDefiningOp<UnrealizedConversionCastOp>();
- if (!inputCastOp)
- return {};
- if (inputCastOp.getOutputs() != castOp.getInputs())
- return {};
- return inputCastOp;
- };
-
- // Process ops in the worklist bottom-to-top.
- while (!worklist.empty()) {
- UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- castOp->erase();
- continue;
- }
-
- // Traverse the chain of input cast ops to see if an op with the same
- // input types can be found.
- UnrealizedConversionCastOp nextCast = castOp;
- while (nextCast) {
- if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
- // Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
- enqueueOperands(castOp);
- castOp.replaceAllUsesWith(nextCast.getInputs());
- castOp->erase();
- break;
- }
- nextCast = getInputCast(nextCast);
- }
- }
+ reconcileUnrealizedCasts(ops);
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..0da8eabadb4ee1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVector<UnrealizedConversionCastOp> *remainingCastOps) {
+ SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+ castOps.end());
+ // This set is maintained only if `remainingCastOps` is provided.
+ DenseSet<Operation *> erasedOps;
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that return the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp =
+ castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
+ // matched op. We can directly use those inputs and the matched op can
+ // be removed.
+ enqueueOperands(castOp);
+ castOp.replaceAllUsesWith(nextCast.getInputs());
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ break;
+ }
+ nextCast = getInputCast(nextCast);
+ }
+ }
+
+ if (remainingCastOps)
+ for (UnrealizedConversionCastOp op : castOps)
+ if (!erasedOps.contains(op.getOperation()))
+ remainingCastOps->push_back(op);
+}
+
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
|
For reference, #104668 is still in development, but it's going to look somewhat like that. |
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
Outdated
Show resolved
Hide resolved
…edCasts.cpp Co-authored-by: Markus Böck <[email protected]>
Co-authored-by: Markus Böck <[email protected]>
db989d5
to
ed5a2dc
Compare
llvm#104671) Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as `ReconcileUnrealizedCasts` will perform these kind of foldings on fully materialized IR. --------- Co-authored-by: Markus Böck <[email protected]>
Move the implementation of
ReconcileUnrealizedCasts
toDialectConversion.cpp
, so that it can be called from there in a future commit.This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as
ReconcileUnrealizedCasts
will perform these kind of foldings on fully materialized IR.