diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp index 769e14b1316d6..b68fe6ee0c747 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp @@ -125,7 +125,7 @@ class InlineElementalsPass mlir::RewritePatternSet patterns(context); patterns.insert(context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR elemental inlining"); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index 36fae90c83fd6..091ed7ed999df 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -520,8 +520,8 @@ class LowerHLFIRIntrinsics config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( - module, std::move(patterns), config))) { + if (mlir::failed( + mlir::applyPatternsGreedily(module, std::move(patterns), config))) { mlir::emitError(mlir::UnknownLoc::get(context), "failure in HLFIR intrinsic lowering"); signalPassFailure(); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index c152c27c0a05b..bf3cf861e46f4 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -1372,7 +1372,7 @@ class OptimizedBufferizationPass // patterns.insert>(context); // patterns.insert>(context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR optimized bufferization"); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 28325bc8e5489..bf3d261e7e883 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -491,7 +491,7 @@ class SimplifyHLFIRIntrinsics patterns.insert(context); patterns.insert(context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR intrinsic simplification"); diff --git a/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp b/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp index fd58375da618a..fab1f0299ede9 100644 --- a/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp +++ b/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp @@ -39,8 +39,7 @@ struct AlgebraicSimplification void AlgebraicSimplification::runOnOperation() { RewritePatternSet patterns(&getContext()); populateMathAlgebraicSimplificationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); } std::unique_ptr fir::createAlgebraicSimplificationPass() { diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp index 2c9c73e8a5394..eb59045a5fde7 100644 --- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp @@ -154,7 +154,7 @@ class AssumedRankOpConversion mlir::GreedyRewriteConfig config; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - (void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config); + (void)applyPatternsGreedily(mod, std::move(patterns), config); } }; } // namespace diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp index eef6f047fc1bf..562f3058f20f3 100644 --- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp +++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp @@ -173,8 +173,8 @@ class ConstantArgumentGlobalisationOpt config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; patterns.insert(context, *di); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( - mod, std::move(patterns), config))) { + if (mlir::failed( + mlir::applyPatternsGreedily(mod, std::move(patterns), config))) { mlir::emitError(mod.getLoc(), "error in constant globalisation optimization\n"); signalPassFailure(); diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 0c474f463f09c..f9281000d21f0 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -793,8 +793,8 @@ void StackArraysPass::runOnOperation() { config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; patterns.insert(&context, *candidateOps); - if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, - std::move(patterns), config))) { + if (mlir::failed(mlir::applyOpPatternsGreedily( + opsToConvert, std::move(patterns), config))) { mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); signalPassFailure(); } diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index da392b8289332..d15e7e5a80678 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -358,7 +358,7 @@ which point the driver finishes. This driver comes in two fashions: -* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to +* `applyPatternsGreedily` ("region-based driver") applies patterns to all ops in a given region or a given container op (but not the container op itself). I.e., the worklist is initialized with all containing ops. * `applyOpPatternsAndFold` ("op-based driver") applies patterns to the diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp index 8166aa238bf2c..8c79a07537933 100644 --- a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp +++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp @@ -39,7 +39,7 @@ class StandaloneSwitchBarFoo RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + if (failed(applyPatternsGreedily(getOperation(), patternSet))) signalPassFailure(); } }; diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index eaff85804f6b3..110b4f64856eb 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -91,6 +91,13 @@ class GreedyRewriteConfig { /// An optional listener that should be notified about IR modifications. RewriterBase::Listener *listener = nullptr; + + /// Whether this should fold while greedily rewriting. + bool fold = true; + + /// If set to "true", constants are CSE'd (even across multiple regions that + /// are in a parent-ancestor relationship). + bool cseConstants = true; }; //===----------------------------------------------------------------------===// @@ -104,8 +111,8 @@ class GreedyRewriteConfig { /// The greedy rewrite may prematurely stop after a maximum number of /// iterations, which can be configured in the configuration parameter. /// -/// Also performs folding and simple dead-code elimination before attempting to -/// match any of the provided patterns. +/// Also performs simple dead-code elimination before attempting to match any of +/// the provided patterns. /// /// A region scope can be set in the configuration parameter. By default, the /// scope is set to the specified region. Only in-scope ops are added to the @@ -117,10 +124,20 @@ class GreedyRewriteConfig { /// /// Note: This method does not apply patterns to the region's parent operation. LogicalResult +applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr); +/// Same as `applyPatternsAndGreedily` above with folding. +/// FIXME: Remove this once transition to above is complieted. +LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily") +inline LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig(), - bool *changed = nullptr); + bool *changed = nullptr) { + config.fold = true; + return applyPatternsGreedily(region, patterns, config, changed); +} /// Rewrite ops nested under the given operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy @@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region ®ion, /// The greedy rewrite may prematurely stop after a maximum number of /// iterations, which can be configured in the configuration parameter. /// -/// Also performs folding and simple dead-code elimination before attempting to -/// match any of the provided patterns. +/// Also performs simple dead-code elimination before attempting to match any of +/// the provided patterns. /// /// This overload runs a separate greedy rewrite for each region of the /// specified op. A region scope can be set in the configuration parameter. By @@ -147,23 +164,32 @@ applyPatternsAndFoldGreedily(Region ®ion, /// /// Note: This method does not apply patterns to the given operation itself. inline LogicalResult -applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config = GreedyRewriteConfig(), - bool *changed = nullptr) { +applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr) { bool anyRegionChanged = false; bool failed = false; for (Region ®ion : op->getRegions()) { bool regionChanged; - failed |= - applyPatternsAndFoldGreedily(region, patterns, config, ®ionChanged) - .failed(); + failed |= applyPatternsGreedily(region, patterns, config, ®ionChanged) + .failed(); anyRegionChanged |= regionChanged; } if (changed) *changed = anyRegionChanged; return failure(failed); } +/// Same as `applyPatternsGreedily` above with folding. +/// FIXME: Remove this once transition to above is complieted. +LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily") +inline LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr) { + config.fold = true; + return applyPatternsGreedily(op, patterns, config, changed); +} /// Rewrite the specified ops by repeatedly applying the highest benefit /// patterns in a greedy worklist driven manner until a fixpoint is reached. @@ -171,8 +197,8 @@ applyPatternsAndFoldGreedily(Operation *op, /// The greedy rewrite may prematurely stop after a maximum number of /// iterations, which can be configured in the configuration parameter. /// -/// Also performs folding and simple dead-code elimination before attempting to -/// match any of the provided patterns. +/// Also performs simple dead-code elimination before attempting to match any of +/// the provided patterns. /// /// Newly created ops and other pre-existing ops that use results of rewritten /// ops or supply operands to such ops are also processed, unless such ops are @@ -180,24 +206,36 @@ applyPatternsAndFoldGreedily(Operation *op, /// regardless of `strictMode`). /// /// In addition to strictness, a region scope can be specified. Only ops within -/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`, +/// the scope are simplified. This is similar to `applyPatternsGreedily`, /// where only ops within the given region/op are simplified by default. If no /// scope is specified, it is assumed to be the first common enclosing region of /// the given ops. /// /// Note that ops in `ops` could be erased as result of folding, becoming dead, /// or via pattern rewrites. If more far reaching simplification is desired, -/// `applyPatternsAndFoldGreedily` should be used. +/// `applyPatternsGreedily` should be used. /// /// Returns "success" if the iterative process converged (i.e., fixpoint was /// reached) and no more patterns can be matched. `changed` is set to "true" if /// the IR was modified at all. `allOpsErased` is set to "true" if all ops in /// `ops` were erased. LogicalResult +applyOpPatternsGreedily(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr, bool *allErased = nullptr); +/// Same as `applyOpPatternsGreedily` with folding. +/// FIXME: Remove this once transition to above is complieted. +LLVM_DEPRECATED("Use applyOpPatternsGreedily() instead", + "applyOpPatternsGreedily") +inline LogicalResult applyOpPatternsAndFold(ArrayRef ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig(), - bool *changed = nullptr, bool *allErased = nullptr); + bool *changed = nullptr, bool *allErased = nullptr) { + config.fold = true; + return applyOpPatternsGreedily(ops, patterns, config, changed, allErased); +} } // namespace mlir diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 379f09cf5cc26..c4717ca613319 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -289,8 +289,7 @@ MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig) { - return wrap( - mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); + return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 6b9cbaf57676c..a8283023afc53 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -385,6 +385,6 @@ void ArithToAMDGPUConversionPass::runOnOperation() { arith::populateArithToAMDGPUConversionPatterns( patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, *maybeChipset); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp index 5aa2a098b1762..cbe0b3fda3410 100644 --- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp +++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp @@ -117,8 +117,7 @@ struct ArithToArmSMEConversionPass final void runOnOperation() override { RewritePatternSet patterns(&getContext()); arith::populateArithToArmSMEConversionPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index bdbf276d79b22..de8bfd6a17103 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -59,8 +59,7 @@ class ConvertArmNeon2dToIntr RewritePatternSet patterns(context); populateConvertArmNeon2dToIntrPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index b343cf71e3a2e..e022d3ce6f636 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -271,7 +271,7 @@ struct LowerGpuOpsToNVVMOpsPass { RewritePatternSet patterns(m.getContext()); populateGpuRewritePatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) + if (failed(applyPatternsGreedily(m, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index aa4d3b70329fb..d52a86987b1ce 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -271,7 +271,7 @@ struct LowerGpuOpsToROCDLOpsPass RewritePatternSet patterns(ctx); populateGpuRewritePatterns(patterns); arith::populateExpandBFloat16Patterns(patterns); - (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + (void)applyPatternsGreedily(m, std::move(patterns)); } LLVMTypeConverter converter(ctx, options); diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 6dd89ecf4d5c2..e1de125ccaede 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -427,8 +427,7 @@ struct ConvertMeshToMPIPass ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>( ctx); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)); + (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp index 7df1407da6f97..d92027a5e3d46 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -62,7 +62,7 @@ class ConvertShapeConstraints RewritePatternSet patterns(context); populateConvertShapeConstraintsConversionPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp index cc00bf4ca190a..7419276651ae2 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp @@ -33,7 +33,7 @@ void ConvertVectorToArmSMEPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateVectorToArmSMEPatterns(patterns, getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr mlir::createConvertVectorToArmSMEPass() { diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 034f3e2d16e94..5b4414d67fdac 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -1326,8 +1326,7 @@ struct ConvertVectorToGPUPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); IRRewriter rewriter(&getContext()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 2d94c2f2e85a0..2c4c5ada9815d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -82,7 +82,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorInsertExtractStridedSliceTransforms(patterns); populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } // Convert to the LLVM IR dialect. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 3a4dc806efe97..01bc65c841e94 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1730,12 +1730,12 @@ struct ConvertVectorToSCFPass RewritePatternSet lowerTransferPatterns(&getContext()); mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( lowerTransferPatterns); - (void)applyPatternsAndFoldGreedily(getOperation(), - std::move(lowerTransferPatterns)); + (void)applyPatternsGreedily(getOperation(), + std::move(lowerTransferPatterns)); RewritePatternSet patterns(&getContext()); populateVectorToSCFConversionPatterns(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 1232d8795d4dc..8041bdf7da19b 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -318,8 +318,7 @@ struct ConvertVectorToXeGPUPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorToXeGPUConversionPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index eb52297940722..9f7df7823d997 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -132,7 +132,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. - if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { + if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) { auto diag = emitDefiniteFailure() << "affine.min/max simplification did not converge"; return diag; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index 331b0f1b2c2b1..9ffe54f61ebbd 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -239,5 +239,5 @@ void AffineDataCopyGeneration::runOnOperation() { FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config); + (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index d7b218225bc9a..7e335ea929c4f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -198,8 +198,7 @@ class ExpandAffineIndexOpsPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); populateAffineExpandIndexOpsPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp index bfcc1ddf91653..16ba16d5c798f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp @@ -79,8 +79,7 @@ class ExpandAffineIndexOpsAsAffinePass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); populateAffineExpandIndexOpsAsAffinePatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index 49618074ec224..31711ade3153b 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -111,5 +111,5 @@ void SimplifyAffineStructures::runOnOperation() { }); GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config); + (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index c5cc8bfeb0a64..0f2c889d4f390 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -318,8 +318,8 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp, GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; bool erased; - (void)applyOpPatternsAndFold(res.getOperation(), std::move(patterns), - config, /*changed=*/nullptr, &erased); + (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns), + config, /*changed=*/nullptr, &erased); if (!erased && !prologue) prologue = res; if (!erased) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 07d399adae0cd..4d3ead20fb5cd 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -425,8 +425,8 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; bool erased; - (void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns, config, - /*changed=*/nullptr, &erased); + (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config, + /*changed=*/nullptr, &erased); if (erased) { if (folded) *folded = true; @@ -454,7 +454,7 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up // a sequence of affine.fors that are all perfectly nested). - (void)applyPatternsAndFoldGreedily( + (void)applyPatternsGreedily( hoistedIfOp->getParentWithTrait(), frozenPatterns); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index b54a53f5ef70e..5982f5f55549e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -489,7 +489,7 @@ struct IntRangeOptimizationsPass final GreedyRewriteConfig config; config.listener = &listener; - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) signalPassFailure(); } }; @@ -518,7 +518,7 @@ struct IntRangeNarrowingPass final config.useTopDownTraversal = false; config.listener = &listener; - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index ee1e374b25b04..23f2c2bf65e47 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -523,8 +523,7 @@ struct OuterProductFusionPass RewritePatternSet patterns(&getContext()); populateOuterProductFusionPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index 8b4bacd722712..d2ac850a5f70b 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -317,8 +317,7 @@ struct LegalizeVectorStorage void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateLegalizeVectorStoragePatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } ConversionTarget target(getContext()); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 273101ce5f3e7..1320523aa989d 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -931,7 +931,7 @@ void AsyncParallelForPass::runOnOperation() { [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { return builder.create(minTaskSize); }); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 5227b22653eef..de3ae82f87086 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -470,8 +470,8 @@ struct BufferDeallocationSimplificationPass config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal; populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp index 7670220dce776..d20c6966d4eb9 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp @@ -60,7 +60,7 @@ void EmptyTensorToAllocTensor::runOnOperation() { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); populateEmptyTensorToAllocTensorPattern(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp index 82bd031430d36..3385514375804 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp @@ -47,7 +47,7 @@ struct FormExpressionsPass RewritePatternSet patterns(context); populateExpressionPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns)))) + if (failed(applyPatternsGreedily(rootOp, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp index 004d73a77e535..a504101fb3f2f 100644 --- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp @@ -227,8 +227,7 @@ struct GpuDecomposeMemrefsPass populateGpuDecomposeMemrefsPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index 0ffd8131b8934..2178555cb62f7 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -630,7 +630,7 @@ class GpuEliminateBarriersPass auto funcOp = getOperation(); RewritePatternSet patterns(&getContext()); mlir::populateGpuEliminateBarriersPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp index 8c33148d1d2d7..c1ec1df48e5b9 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp @@ -96,7 +96,7 @@ void NVVMOptimizeForTarget::runOnOperation() { MLIRContext *ctx = getOperation()->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 18fd24da395b7..221ca27b80fdd 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3511,7 +3511,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( TrackingListener listener(state, *this); GreedyRewriteConfig config; config.listener = &listener; - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config))) + if (failed(applyPatternsGreedily(target, std::move(patterns), config))) return emitDefaultDefiniteFailure(target); results.push_back(target); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 91d4efa3372b7..57344f986480d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -301,7 +301,7 @@ struct LinalgBlockPackMatmul }; linalg::populateBlockPackMatmulPatterns(patterns, controlFn); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index af38485291182..0e651f4cee4c3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -563,8 +563,7 @@ struct LinalgDetensorize RewritePatternSet canonPatterns(context); tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(canonPatterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns)))) signalPassFailure(); // Get rid of the dummy entry block we created in the beginning to work diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bb50347596910..9b97865990bfd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -831,7 +831,7 @@ struct LinalgFoldUnitExtentDimsPass } linalg::populateFoldUnitExtentDimsPatterns(patterns, options); populateMoveInitOperandsToInputPattern(patterns); - (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); + (void)applyPatternsGreedily(op, std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index efc7934bc7d8a..3a57f368d4425 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2206,7 +2206,7 @@ struct LinalgElementwiseOpFusionPass // Use TopDownTraversal for compile time reasons GreedyRewriteConfig grc; grc.useTopDownTraversal = true; - (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc); + (void)applyPatternsGreedily(op, std::move(patterns), grc); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 7ab3fef5dd039..78cee47c497ed 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -89,7 +89,7 @@ struct LinalgGeneralizeNamedOpsPass void LinalgGeneralizeNamedOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateLinalgNamedOpsGeneralizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp index 2a1445fb92fdc..1f3336d2bfbb9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -113,7 +113,7 @@ struct LinalgInlineScalarOperandsPass MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); populateInlineConstantOperandsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); + (void)applyPatternsGreedily(op, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 20a99491b6644..984f3f5a34ab1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -321,7 +321,7 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) { affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.add(context); // Just apply the patterns greedily. - (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns)); + (void)applyPatternsGreedily(enclosingOp, std::move(patterns)); } struct LowerToAffineLoops diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp index 84bde1bc0b846..bb1e974391878 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -152,7 +152,7 @@ struct LinalgNamedOpConversionPass Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); populateLinalgNamedOpConversionPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 748e2a1377930..512fb7555a6b7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -349,7 +349,7 @@ void LinalgSpecializeGenericOpsPass::runOnOperation() { populateLinalgGenericOpsSpecializationPatterns(patterns); populateDecomposeProjectedPermutationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp index 6b0d0f5e7466f..de950bac819c7 100644 --- a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp +++ b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp @@ -66,8 +66,7 @@ struct MathUpliftToFMA final void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateUpliftToFMAPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 92592d2345d75..aa008f8407b5d 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -1223,7 +1223,7 @@ struct ExpandStridedMetadataPass final void ExpandStridedMetadataPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateExpandStridedMetadataPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr memref::createExpandStridedMetadataPass() { diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 96daf4c5972a4..8e927a60087fc 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -857,7 +857,7 @@ struct FoldMemRefAliasOpsPass final void FoldMemRefAliasOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateFoldMemRefAliasOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr memref::createFoldMemRefAliasOpsPass() { diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 792e722918306..dfcbaeb15ae5f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -195,7 +195,7 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } @@ -203,7 +203,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index 9f8189ae15e6d..3e93dc80b18ec 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -112,7 +112,7 @@ struct ForToWhileLoop : public impl::SCFForToWhileLoopBase { MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); + (void)applyPatternsGreedily(parentOp, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index c6d024c462e83..4ebd90dbcc1d5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -167,7 +167,7 @@ struct SCFForLoopCanonicalization MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); scf::populateSCFForLoopCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) + if (failed(applyPatternsGreedily(parentOp, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 5104ad4b3a303..b71ec985fa6a1 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -331,7 +331,7 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase { MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx, peelFront, skipPartial); - (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); + (void)applyPatternsGreedily(parentOp, std::move(patterns)); // Drop the markers. parentOp->walk([](Operation *op) { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index ef5d4370e7810..90db42d479a19 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1430,7 +1430,7 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { GreedyRewriteConfig config; config.listener = this; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - return applyOpPatternsAndFold(ops, patterns.value(), config); + return applyOpPatternsGreedily(ops, patterns.value(), config); } void SliceTrackingListener::notifyOperationInserted( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp index 374c205897c8a..cc59c2116ed37 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp @@ -29,8 +29,7 @@ class CanonicalizeGLPass final void runOnOperation() override { RewritePatternSet patterns(&getContext()); spirv::populateSPIRVGLCanonicalizationPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 877ac87fb0fe5..29f7e8afe0773 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1354,7 +1354,7 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { // looking for newly created func ops. GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; - return applyPatternsAndFoldGreedily(op, std::move(patterns), config); + return applyPatternsGreedily(op, std::move(patterns), config); } LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { @@ -1366,7 +1366,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { auto options = vector::UnrollVectorOptions().setNativeShapeFn( [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); populateVectorUnrollPatterns(patterns, options); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); } @@ -1378,7 +1378,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { vector::VectorTransposeLowering::EltWise); vector::populateVectorTransposeLoweringPatterns(patterns, options); vector::populateVectorShapeCastLoweringPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); } @@ -1403,7 +1403,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); } return success(); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index d75c8552c9ad0..af1cf2a1373e3 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -236,8 +236,7 @@ struct WebGPUPreparePass final populateSPIRVExpandExtendedMultiplicationPatterns(patterns); populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp index 655555f883544..e56742d52e131 100644 --- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -207,7 +207,7 @@ void OutlineShapeComputationPass::runOnOperation() { MLIRContext *context = funcOp.getContext(); RewritePatternSet prevPatterns(context); prevPatterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns)))) + if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns)))) return signalPassFailure(); // initialize class member `onlyUsedByWithShapes` @@ -254,7 +254,7 @@ void OutlineShapeComputationPass::runOnOperation() { } // Apply patterns, note this also performs DCE. - if (failed(applyPatternsAndFoldGreedily(funcOp, {}))) + if (failed(applyPatternsGreedily(funcOp, {}))) return signalPassFailure(); }); } diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp index e1cccd8fd5d65..d2b245f832e57 100644 --- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp @@ -55,7 +55,7 @@ class RemoveShapeConstraintsPass RewritePatternSet patterns(&ctx); populateRemoveShapeConstraintsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 8004bdb904b8a..1cac949b68c79 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -57,7 +57,7 @@ struct SparseAssembler : public impl::SparseAssemblerBase { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateSparseAssembler(patterns, directOut); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -73,7 +73,7 @@ struct SparseReinterpretMap auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateSparseReinterpretMap(patterns, scope); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -87,7 +87,7 @@ struct PreSparsificationRewritePass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populatePreSparsificationRewriting(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -110,7 +110,7 @@ struct SparsificationPass RewritePatternSet patterns(ctx); populateSparsificationPatterns(patterns, options); scf::ForOp::getCanonicalizationPatterns(patterns, ctx); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -122,7 +122,7 @@ struct StageSparseOperationsPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateStageSparseOperationsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -141,7 +141,7 @@ struct LowerSparseOpsToForeachPass RewritePatternSet patterns(ctx); populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary, enableConvert); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -154,7 +154,7 @@ struct LowerForeachToSCFPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateLowerForeachToSCFPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -329,7 +329,7 @@ struct SparseBufferRewritePass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateSparseBufferRewriting(patterns, enableBufferInitialization); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -351,7 +351,7 @@ struct SparseVectorizationPass populateSparseVectorizationPatterns( patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32); vector::populateVectorToVectorCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -371,7 +371,7 @@ struct SparseGPUCodegenPass populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary); else populateSparseGPUCodegenPatterns(patterns, numThreads); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp index 0f5fa61879b71..998b0fb6eb4b7 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -277,7 +277,7 @@ struct FoldTensorSubsetOpsPass final void FoldTensorSubsetOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); tensor::populateFoldTensorSubsetOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr tensor::createFoldTensorSubsetOpsPass() { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index e1400f0c907b2..9299db7e51a01 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -60,7 +60,7 @@ struct TosaLayerwiseConstantFoldPass aggressiveReduceConstant); populateTosaOpsCanonicalizationPatterns(ctx, patterns); - if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) + if (applyPatternsGreedily(func, std::move(patterns)).failed()) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index 9c6ee4c62eee5..2a990eed3f681 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -246,7 +246,7 @@ struct TosaMakeBroadcastable patterns.add>(ctx); patterns.add>(ctx); patterns.add>(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp index cef903a39e45b..603185e48aa94 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp @@ -42,7 +42,7 @@ struct TosaOptionalDecompositions mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns); mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns); - if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) + if (applyPatternsGreedily(func, std::move(patterns)).failed()) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 1f0f183e29f9a..106a794735090 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -417,7 +417,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( if (target->hasTrait()) { // Op is isolated from above. Apply patterns and also perform region // simplification. - result = applyPatternsAndFoldGreedily(target, frozenPatterns, config); + result = applyPatternsGreedily(target, frozenPatterns, config); } else { // Manually gather list of ops because the other // GreedyPatternRewriteDriver overloads only accepts ops that are isolated @@ -429,7 +429,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( if (target != nestedOp) ops.push_back(nestedOp); }); - result = applyOpPatternsAndFold(ops, frozenPatterns, config); + result = applyOpPatternsGreedily(ops, frozenPatterns, config); } // A failure typically indicates that the pattern application did not diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index bfc05c71f5340..1f6cac2aa6f96 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -286,7 +286,7 @@ struct LowerVectorMaskPass populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); MaskOp::getCanonicalizationPatterns(loweringPatterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 72bf329daaa76..0cafc9cd35517 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -486,7 +486,7 @@ struct LowerVectorMultiReductionPass populateVectorMultiReductionLoweringPatterns(loweringPatterns, this->loweringStrategy); - if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp index 9307e8eb784b5..e3082c55427fe 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp @@ -78,5 +78,5 @@ struct XeGPUFoldAliasOpsPass final void XeGPUFoldAliasOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); xegpu::populateXeGPUFoldAliasOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp index b00045a3a41b7..2d2744bfc2732 100644 --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -65,7 +65,7 @@ static void applyPatterns(Region ®ion, // because we don't have expectation this reduction will be success or not. GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; - (void)applyOpPatternsAndFold(op, patterns, config); + (void)applyOpPatternsGreedily(op, patterns, config); } if (eraseOpNotInRange) diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index d50019bd6aee5..5f46960507036 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -60,7 +60,7 @@ struct Canonicalizer : public impl::CanonicalizerBase { } void runOnOperation() override { LogicalResult converged = - applyPatternsAndFoldGreedily(getOperation(), *patterns, config); + applyPatternsGreedily(getOperation(), *patterns, config); // Canonicalization is best-effort. Non-convergence is not a pass failure. if (testConvergence && failed(converged)) signalPassFailure(); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index e0d0acd122e26..99f3569b767b1 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This file implements mlir::applyPatternsAndFoldGreedily. +// This file implements mlir::applyPatternsGreedily. // //===----------------------------------------------------------------------===// @@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { // infinite folding loop, as every constant op would be folded to an // Attribute and then immediately be rematerialized as a constant op, which // is then put on the worklist. - if (!op->hasTrait()) { + if (config.fold && !op->hasTrait()) { SmallVector foldResults; if (succeeded(op->fold(foldResults))) { LLVM_DEBUG(logResultWithLine("success", "operation was folded")); @@ -852,13 +852,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) + if (!config.cseConstants || !insertKnownConstant(op)) addToWorklist(op); }); } else { // Add all nested operations to the worklist in preorder. region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) { + if (!config.cseConstants || !insertKnownConstant(op)) { addToWorklist(op); return WalkResult::advance(); } @@ -894,9 +894,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { } LogicalResult -mlir::applyPatternsAndFoldGreedily(Region ®ion, - const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config, bool *changed) { +mlir::applyPatternsGreedily(Region ®ion, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config, bool *changed) { // The top-level operation must be known to be isolated from above to // prevent performing canonicalizations on operations defined at or above // the region containing 'op'. @@ -1012,7 +1012,7 @@ static Region *findCommonAncestor(ArrayRef ops) { return region; } -LogicalResult mlir::applyOpPatternsAndFold( +LogicalResult mlir::applyOpPatternsGreedily( ArrayRef ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config, bool *changed, bool *allErased) { if (ops.empty()) { diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp index c208716891ef1..6474c59595eb4 100644 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -296,7 +296,7 @@ OneToNConversionPattern::matchAndRewrite(Operation *op, namespace mlir { // This function applies the provided patterns using -// `applyPatternsAndFoldGreedily` and then replaces all newly inserted +// `applyPatternsGreedily` and then replaces all newly inserted // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts // from target to source types inserted by a `OneToNConversionPattern` normally // fold away with the "forward" casts from source to target types inserted by @@ -317,7 +317,7 @@ applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, #endif // NDEBUG // Apply provided conversion patterns. - if (failed(applyPatternsAndFoldGreedily(op, patterns))) { + if (failed(applyPatternsGreedily(op, patterns))) { emitError(op->getLoc()) << "failed to apply conversion patterns"; return failure(); } diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir index 3c0cd15dc6c51..86ed6c25a227a 100644 --- a/mlir/test/Transforms/test-operation-folder.mlir +++ b/mlir/test/Transforms/test-operation-folder.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s // RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s +// RUN: mlir-opt -test-greedy-patterns='cse-constants=false' %s | FileCheck %s --check-prefix=NOCSE +// RUN: mlir-opt -test-greedy-patterns='fold=false' %s | FileCheck %s --check-prefix=NOFOLD func.func @foo() -> i32 { %c42 = arith.constant 42 : i32 @@ -25,7 +27,8 @@ func.func @test_fold_before_previously_folded_op() -> (i32, i32) { } func.func @test_dont_reorder_constants() -> (i32, i32, i32) { - // Test that we don't reorder existing constants during folding if it isn't necessary. + // Test that we don't reorder existing constants during folding if it isn't + // necessary. // CHECK: %[[CST:.+]] = arith.constant 1 // CHECK-NEXT: %[[CST:.+]] = arith.constant 2 // CHECK-NEXT: %[[CST:.+]] = arith.constant 3 @@ -34,3 +37,46 @@ func.func @test_dont_reorder_constants() -> (i32, i32, i32) { %2 = arith.constant 3 : i32 return %0, %1, %2 : i32, i32, i32 } + +// CHECK-LABEL: test_fold_nofold_nocse +// NOCSE-LABEL: test_fold_nofold_nocse +// NOFOLD-LABEL: test_fold_nofold_nocse +func.func @test_fold_nofold_nocse() -> (i32, i32, i32, i32, i32, i32) { + // Test either not folding or deduping constants. + + // Testing folding. There should be only 4 constants here. + // CHECK-NOT: arith.constant + // CHECK-DAG: %[[CST:.+]] = arith.constant 0 + // CHECK-DAG: %[[CST:.+]] = arith.constant 1 + // CHECK-DAG: %[[CST:.+]] = arith.constant 2 + // CHECK-DAG: %[[CST:.+]] = arith.constant 3 + // CHECK-NOT: arith.constant + // CHECK-NEXT: return + + // Testing not-CSE'ing. In this case we have the 3 original constants and 3 + // produced by folding. + // NOCSE-DAG: arith.constant 0 : i32 + // NOCSE-DAG: arith.constant 1 : i32 + // NOCSE-DAG: arith.constant 2 : i32 + // NOCSE-DAG: arith.constant 1 : i32 + // NOCSE-DAG: arith.constant 2 : i32 + // NOCSE-DAG: arith.constant 3 : i32 + // NOCSE-NEXT: return + + // Testing not folding. In this case we just have the original constants. + // NOFOLD-DAG: %[[CST:.+]] = arith.constant 0 + // NOFOLD-DAG: %[[CST:.+]] = arith.constant 1 + // NOFOLD-DAG: %[[CST:.+]] = arith.constant 2 + // NOFOLD: arith.addi + // NOFOLD: arith.addi + // NOFOLD: arith.addi + + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %0 = arith.addi %c0, %c1 : i32 + %1 = arith.addi %0, %c1 : i32 + %2 = arith.addi %c2, %c1 : i32 + return %0, %1, %2, %c0, %c1, %c2 : i32, i32, i32, i32, i32, i32 +} + diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp index e17fe12b9088b..1e45ab57ebcc7 100644 --- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp @@ -248,7 +248,7 @@ struct TestMathToVCIX RewritePatternSet patterns(ctx); patterns.add( ctx); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp index 1864d2f7f5036..d49b4e391a68f 100644 --- a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp +++ b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp @@ -41,7 +41,7 @@ struct TestVectorReductionToSPIRVDotProd void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorReductionToSPIRVDotProductPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index b418a457473a8..404f34ebee17a 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -136,7 +136,7 @@ void TestAffineDataCopy::runOnOperation() { } GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config); + (void)applyOpPatternsGreedily(copyOps, std::move(patterns), config); } namespace mlir { diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp index f6bfd9f858284..03c80b601a347 100644 --- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp +++ b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp @@ -47,7 +47,7 @@ void TestLowerToArmNeon::runOnOperation() { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); populateLowerContractionToSMMLAPatternPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp index 74d057c0b7b6c..a49d304baf5c6 100644 --- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp +++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp @@ -38,7 +38,7 @@ struct TestGpuRewritePass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateGpuRewritePatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -85,7 +85,7 @@ struct TestGpuSubgroupReduceLoweringPass patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32); } - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp index 4cf2460150d14..d0700f9a4f1a4 100644 --- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp @@ -34,8 +34,7 @@ struct TestDataLayoutPropagationPass RewritePatternSet patterns(context); linalg::populateDataLayoutPropagationPatterns( patterns, [](OpOperand *opOperand) { return true; }); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp index 311244aeffb90..0143a27bfe843 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp @@ -43,8 +43,8 @@ struct TestLinalgDecomposeOps RewritePatternSet decompositionPatterns(context); linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns, removeDeadArgsAndResults); - if (failed(applyPatternsAndFoldGreedily( - getOperation(), std::move(decompositionPatterns)))) { + if (failed(applyPatternsGreedily(getOperation(), + std::move(decompositionPatterns)))) { return signalPassFailure(); } } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 7f68f4aec3a10..e4883e47f2063 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -155,8 +155,8 @@ struct TestLinalgElementwiseFusion RewritePatternSet fusionPatterns(context); auto controlFn = [](OpOperand *operand) { return true; }; linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -166,8 +166,8 @@ struct TestLinalgElementwiseFusion linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, setFusedOpOperandLimit<4>); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -176,8 +176,8 @@ struct TestLinalgElementwiseFusion RewritePatternSet fusionPatterns(context); linalg::populateFoldReshapeOpsByExpansionPatterns( fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; }); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -212,8 +212,8 @@ struct TestLinalgElementwiseFusion linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, controlReshapeFusionFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -222,8 +222,7 @@ struct TestLinalgElementwiseFusion RewritePatternSet patterns(context); linalg::populateFoldReshapeOpsByCollapsingPatterns( patterns, [](OpOperand * /*fusedOperand */) { return true; }); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } @@ -239,8 +238,7 @@ struct TestLinalgElementwiseFusion return true; }; linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } @@ -248,8 +246,7 @@ struct TestLinalgElementwiseFusion if (fuseMultiUseProducer) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } @@ -265,8 +262,7 @@ struct TestLinalgElementwiseFusion }; RewritePatternSet patterns(context); linalg::populateCollapseDimensions(patterns, collapseFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index 2d8ee2f9bb6e3..81e7eedabd5d1 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -84,7 +84,7 @@ struct TestLinalgGreedyFusion pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); do { - (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); + (void)applyPatternsGreedily(getOperation(), frozenPatterns); if (failed(runPipeline(pm, getOperation()))) this->signalPassFailure(); } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp index 8b455d7d68c30..750ba6b5d9872 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp @@ -49,8 +49,7 @@ struct TestLinalgRankReduceContractionOps RewritePatternSet patterns(context); linalg::populateContractionOpRankReducingPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 25aec75c3c14a..fa2a27dcfa991 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -147,14 +147,14 @@ static void applyPatterns(func::FuncOp funcOp) { //===--------------------------------------------------------------------===// patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { RewritePatternSet forwardPattern(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); + (void)applyPatternsGreedily(funcOp, std::move(forwardPattern)); } static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { @@ -163,68 +163,68 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { patterns.add(ctx); populatePadOpVectorizationPatterns(patterns); populateConvolutionVectorizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposePadPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateBubbleUpExtractSliceOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateSwapExtractSliceWithFillPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateEraseUnusedOperandsAndResultsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateEraseUnnecessaryInputsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyWinogradConv2D(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposeWinogradOps(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateDecomposeWinogradOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } /// Apply transformations specified as patterns. diff --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp index 073e0d8d4e143..b927767038a9e 100644 --- a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp @@ -36,8 +36,7 @@ struct TestPadFusionPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp index 084a592215241..42491d4c716c9 100644 --- a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp +++ b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp @@ -40,7 +40,7 @@ struct TestMathAlgebraicSimplificationPass void TestMathAlgebraicSimplificationPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateMathAlgebraicSimplificationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp index 69af2a08b97bd..0139eabba373f 100644 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -53,7 +53,7 @@ void TestExpandMathPass::runOnOperation() { populateExpandRoundFPattern(patterns); populateExpandRoundEvenPattern(patterns); populateExpandRsqrtPattern(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp index 8a01ac509c30e..9fdd200e2b2c9 100644 --- a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp @@ -59,7 +59,7 @@ void TestMathPolynomialApproximationPass::runOnOperation() { MathPolynomialApproximationOptions approxOptions; approxOptions.enableAvx2 = enableAvx2; populateMathPolynomialApproximationPatterns(patterns, approxOptions); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp index 02a9dbbe263f8..08d22ab59f94b 100644 --- a/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp @@ -38,7 +38,7 @@ void TestComposeSubViewPass::getDependentDialects( void TestComposeSubViewPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateComposeSubViewPatterns(patterns, &getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } } // namespace diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp index 1f836be1ae7ac..dbae93b380f2b 100644 --- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp @@ -26,9 +26,9 @@ struct TestAllSliceOpLoweringPass SymbolTableCollection symbolTableCollection; mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; - assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); + assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { mesh::registerAllSliceOpLoweringDialects(registry); @@ -51,9 +51,9 @@ struct TestMultiIndexOpLoweringPass mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection); LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; - assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); + assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { mesh::registerProcessMultiIndexOpLoweringDialects(registry); diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp index 98992c4cc11f9..102e64de4bd1f 100644 --- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp @@ -97,8 +97,8 @@ struct TestMeshReshardingPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation().getOperation(), + std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp index 512b16af64c94..01e196d29f7a5 100644 --- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp @@ -34,7 +34,7 @@ void TestMeshSimplificationsPass::runOnOperation() { SymbolTableCollection symbolTableCollection; mesh::populateSimplificationPatterns(patterns, symbolTableCollection); [[maybe_unused]] LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyPatternsGreedily(getOperation(), std::move(patterns)); assert(succeeded(status) && "Rewrite patters application did not converge."); } diff --git a/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp index 8ca29257b8120..0099dc8caf427 100644 --- a/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp @@ -60,7 +60,7 @@ struct TestMmaSyncF32ToTF32Patterns RewritePatternSet patterns(&getContext()); populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index a3be1f94fa28a..b4f3fa30f8ab5 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -226,7 +226,7 @@ struct TestSCFPipeliningPass options.peelEpilogue = false; } scf::populateSCFLoopPipeliningPatterns(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); getOperation().walk([](Operation *op) { // Clean up the markers. op->removeAttr(kTestPipeliningStageMarker); diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp index 7e51d67702b05..856cde19edd52 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp @@ -59,7 +59,7 @@ struct TestWrapWhileLoopInZeroTripCheckPass } else { RewritePatternSet patterns(context); scf::populateSCFRotateWhileLoopPatterns(patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } diff --git a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp index 468bc0ca78489..cf123fe280242 100644 --- a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp +++ b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp @@ -34,7 +34,7 @@ struct TestSCFUpliftWhileToFor MLIRContext *ctx = op->getContext(); RewritePatternSet patterns(ctx); scf::populateUpliftWhileToForPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 34de600132f5d..173bfd8955f2b 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -104,19 +104,19 @@ struct TestTensorTransforms static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateReassociativeReshapeFoldingPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyBubbleUpExpandShapePatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateBubbleUpExpandShapePatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateFoldIntoPackAndUnpackPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) { @@ -132,26 +132,26 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) { }; tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applySimplifyPackUnpackPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateSimplifyPackAndUnpackPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } namespace { @@ -293,7 +293,7 @@ applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp, else patterns.add( rootOp->getContext()); - return applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + return applyPatternsGreedily(rootOp, std::move(patterns)); } namespace { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 8a0bc597c56be..ce2820b80a945 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -388,8 +388,9 @@ struct TestGreedyPatternDriver GreedyRewriteConfig config; config.useTopDownTraversal = this->useTopDownTraversal; config.maxIterations = this->maxIterations; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + config.fold = this->fold; + config.cseConstants = this->cseConstants; + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); } Option useTopDownTraversal{ @@ -400,6 +401,11 @@ struct TestGreedyPatternDriver *this, "max-iterations", llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), llvm::cl::init(GreedyRewriteConfig().maxIterations)}; + Option fold{*this, "fold", llvm::cl::desc("Whether to fold"), + llvm::cl::init(GreedyRewriteConfig().fold)}; + Option cseConstants{*this, "cse-constants", + llvm::cl::desc("Whether to CSE constants"), + llvm::cl::init(GreedyRewriteConfig().cseConstants)}; }; struct DumpNotifications : public RewriterBase::Listener { @@ -511,8 +517,8 @@ struct TestStrictPatternDriver // operation will trigger the assertion while processing. bool changed = false; bool allErased = false; - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, - &changed, &allErased); + (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, + &changed, &allErased); Builder b(ctx); getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); getOperation()->setAttr("pattern_driver_all_erased", @@ -2101,7 +2107,7 @@ struct TestSelectiveReplacementPatternDriver MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.add(context); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp index 031e1062dac76..d8763f562cbef 100644 --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -38,8 +38,8 @@ struct TestTraitFolder StringRef getArgument() const final { return "test-trait-folder"; } StringRef getDescription() const final { return "Run trait folding"; } void runOnOperation() override { - (void)applyPatternsAndFoldGreedily(getOperation(), - RewritePatternSet(&getContext())); + (void)applyPatternsGreedily(getOperation(), + RewritePatternSet(&getContext())); } }; } // namespace diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp index e5a3e2b6fccaa..ac904c3e01c93 100644 --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -197,7 +197,7 @@ void TosaTestQuantUtilAPI::runOnOperation() { patterns.add(ctx); patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f67a24755ac09..74838bc0ca2fb 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -73,7 +73,7 @@ struct TestVectorToVectorLowering populateVectorToVectorCanonicalizationPatterns(patterns); populateBubbleVectorBitCastOpPatterns(patterns); populateCastAwayVectorLeadingOneDimPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } private: @@ -137,7 +137,7 @@ struct TestVectorContractionPrepareForMMTLowering MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -223,7 +223,7 @@ struct TestVectorUnrollingPatterns })); } populateVectorToVectorCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } ListOption unrollOrder{*this, "unroll-order", @@ -283,7 +283,7 @@ struct TestVectorTransferUnrollingPatterns } populateVectorUnrollPatterns(patterns, opts); populateVectorToVectorCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } Option reverseUnrollOrder{ @@ -326,7 +326,7 @@ struct TestScalarVectorTransferLoweringPatterns RewritePatternSet patterns(ctx); vector::populateScalarVectorTransferLoweringPatterns( patterns, /*benefit=*/1, allowMultipleUses.getValue()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -370,7 +370,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -395,7 +395,7 @@ struct TestVectorSinkPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateSinkVectorOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -415,7 +415,7 @@ struct TestVectorReduceToContractPatternsPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorReductionToContractPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -434,7 +434,7 @@ struct TestVectorChainedReductionFoldingPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateChainedVectorReductionFoldingPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -455,7 +455,7 @@ struct TestVectorBreakDownReductionPatterns RewritePatternSet patterns(&getContext()); populateBreakDownVectorReductionPatterns(patterns, /*maxNumElementsToExtract=*/2); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -496,7 +496,7 @@ struct TestFlattenVectorTransferPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -512,7 +512,7 @@ struct TestVectorScanLowering void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorScanLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -662,18 +662,18 @@ struct TestVectorDistribution /*readBenefit=*/0); vector::populateDistributeReduction(patterns, warpReduction, 1); populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } else if (distributeTransferWriteOps) { RewritePatternSet patterns(ctx); populateDistributeTransferWriteOpPatterns(patterns, distributionFn, maxTransferWriteElements); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } else if (propagateDistribution) { RewritePatternSet patterns(ctx); vector::populatePropagateWarpVectorDistributionPatterns( patterns, distributionFn, shuffleFn); vector::populateDistributeReduction(patterns, warpReduction); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } WarpExecuteOnLane0LoweringOptions options; options.warpAllocationFn = allocateGlobalSharedMemory; @@ -684,7 +684,7 @@ struct TestVectorDistribution // Test on one pattern in isolation. if (warpOpToSCF) { populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); return; } } @@ -706,7 +706,7 @@ struct TestVectorExtractStridedSliceLowering void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -726,7 +726,7 @@ struct TestVectorBreakDownBitCast populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) { return op.getSourceVectorType().getShape().back() > 4; }); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -782,7 +782,7 @@ struct TestVectorGatherLowering void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorGatherLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -809,7 +809,7 @@ struct TestFoldArithExtensionIntoVectorContractPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateFoldArithExtensionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -834,7 +834,7 @@ struct TestVectorEmulateMaskedLoadStore final void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorMaskedLoadStoreEmulationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index 77aa30f847dcd..7b96bf5e28d32 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -161,8 +161,8 @@ struct TestPDLByteCodePass patternList.add(std::move(pdlPattern)); // Invoke the pattern driver with the provided patterns. - (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), - std::move(patternList)); + (void)applyPatternsGreedily(irModule.getBodyRegion(), + std::move(patternList)); } }; } // namespace diff --git a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp index db45d0eadf818..f6b2b2b1c683f 100644 --- a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp +++ b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp @@ -39,7 +39,7 @@ struct TestPDLLPass : public PassWrapper> { void runOnOperation() final { // Invoke the pattern driver with the provided patterns. - (void)applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)applyPatternsGreedily(getOperation(), patterns); } FrozenRewritePatternSet patterns; diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp index 2ec0334ae0d05..5ea35759bb729 100644 --- a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp +++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -36,7 +36,7 @@ struct CommutativityUtils RewritePatternSet patterns(context); populateCommutativityUtilsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp index 82fa6cdb68d23..4e0213c0e4cfd 100644 --- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp +++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp @@ -123,7 +123,7 @@ void TestMakeIsolatedFromAbovePass::runOnOperation() { if (simple) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; @@ -132,7 +132,7 @@ void TestMakeIsolatedFromAbovePass::runOnOperation() { if (cloneOpsWithNoOperands) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; @@ -141,7 +141,7 @@ void TestMakeIsolatedFromAbovePass::runOnOperation() { if (cloneOpsWithOperands) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return;