diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index a51b00271f0ae..60113bdef16a2 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1126,6 +1126,29 @@ struct ConversionConfig { RewriterBase::Listener *listener = nullptr; }; +//===----------------------------------------------------------------------===// +// Reconcile Unrealized Casts +//===----------------------------------------------------------------------===// + +/// Try to reconcile all given UnrealizedConversionCastOps and store the +/// left-over ops in `remainingCastOps` (if provided). +/// +/// This function processes cast ops in a worklist-driven fashion. For each +/// cast op, if the chain of input casts eventually reaches a cast op where the +/// input types match the output types of the matched op, replace the matched +/// op with the inputs. +/// +/// Example: +/// %1 = unrealized_conversion_cast %0 : !A to !B +/// %2 = unrealized_conversion_cast %1 : !B to !C +/// %3 = unrealized_conversion_cast %2 : !C to !A +/// +/// In the above example, %0 can be used instead of %3 and all cast ops are +/// folded away. +void reconcileUnrealizedCasts( + ArrayRef castOps, + SmallVectorImpl *remainingCastOps = nullptr); + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp index 12e0029cebfd0..2ce6dcbb49014 100644 --- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp +++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS @@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts ReconcileUnrealizedCasts() = default; void runOnOperation() override { - // Gather all unrealized_conversion_cast ops. - SetVector worklist; + SmallVector ops; getOperation()->walk( - [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); }); - - // Helper function that adds all operands to the worklist that are an - // unrealized_conversion_cast op result. - auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { - for (Value v : castOp.getInputs()) - if (auto inputCastOp = v.getDefiningOp()) - worklist.insert(inputCastOp); - }; - - // Helper function that return the unrealized_conversion_cast op that - // defines all inputs of the given op (in the same order). Return "nullptr" - // if there is no such op. - auto getInputCast = - [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp { - if (castOp.getInputs().empty()) - return {}; - auto inputCastOp = castOp.getInputs() - .front() - .getDefiningOp(); - if (!inputCastOp) - return {}; - if (inputCastOp.getOutputs() != castOp.getInputs()) - return {}; - return inputCastOp; - }; - - // Process ops in the worklist bottom-to-top. - while (!worklist.empty()) { - UnrealizedConversionCastOp castOp = worklist.pop_back_val(); - if (castOp->use_empty()) { - // DCE: If the op has no users, erase it. Add the operands to the - // worklist to find additional DCE opportunities. - enqueueOperands(castOp); - castOp->erase(); - continue; - } - - // Traverse the chain of input cast ops to see if an op with the same - // input types can be found. - UnrealizedConversionCastOp nextCast = castOp; - while (nextCast) { - if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { - // Found a cast where the input types match the output types of the - // matched op. We can directly use those inputs and the matched op can - // be removed. - enqueueOperands(castOp); - castOp.replaceAllUsesWith(nextCast.getInputs()); - castOp->erase(); - break; - } - nextCast = getInputCast(nextCast); - } - } + [&](UnrealizedConversionCastOp castOp) { ops.push_back(castOp); }); + reconcileUnrealizedCasts(ops); } }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 8a4c7463a69a9..6238a257b2ffd 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult( return success(); } +//===----------------------------------------------------------------------===// +// Reconcile Unrealized Casts +//===----------------------------------------------------------------------===// + +void mlir::reconcileUnrealizedCasts( + ArrayRef castOps, + SmallVectorImpl *remainingCastOps) { + SetVector worklist(castOps.begin(), + castOps.end()); + // This set is maintained only if `remainingCastOps` is provided. + DenseSet erasedOps; + + // Helper function that adds all operands to the worklist that are an + // unrealized_conversion_cast op result. + auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { + for (Value v : castOp.getInputs()) + if (auto inputCastOp = v.getDefiningOp()) + worklist.insert(inputCastOp); + }; + + // Helper function that return the unrealized_conversion_cast op that + // defines all inputs of the given op (in the same order). Return "nullptr" + // if there is no such op. + auto getInputCast = + [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp { + if (castOp.getInputs().empty()) + return {}; + auto inputCastOp = + castOp.getInputs().front().getDefiningOp(); + if (!inputCastOp) + return {}; + if (inputCastOp.getOutputs() != castOp.getInputs()) + return {}; + return inputCastOp; + }; + + // Process ops in the worklist bottom-to-top. + while (!worklist.empty()) { + UnrealizedConversionCastOp castOp = worklist.pop_back_val(); + if (castOp->use_empty()) { + // DCE: If the op has no users, erase it. Add the operands to the + // worklist to find additional DCE opportunities. + enqueueOperands(castOp); + if (remainingCastOps) + erasedOps.insert(castOp.getOperation()); + castOp->erase(); + continue; + } + + // Traverse the chain of input cast ops to see if an op with the same + // input types can be found. + UnrealizedConversionCastOp nextCast = castOp; + while (nextCast) { + if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { + // Found a cast where the input types match the output types of the + // matched op. We can directly use those inputs and the matched op can + // be removed. + enqueueOperands(castOp); + castOp.replaceAllUsesWith(nextCast.getInputs()); + if (remainingCastOps) + erasedOps.insert(castOp.getOperation()); + castOp->erase(); + break; + } + nextCast = getInputCast(nextCast); + } + } + + if (remainingCastOps) + for (UnrealizedConversionCastOp op : castOps) + if (!erasedOps.contains(op.getOperation())) + remainingCastOps->push_back(op); +} + //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===//