-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][IR] Move match
and rewrite
functions into separate class
#129861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Move match
and rewrite
functions into separate class
#129861
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-backend-amdgpu Author: Matthias Springer (matthias-springer) ChangesThe vast majority of rewrite / conversion patterns uses a combined This PR optimizes the code base for the most common case where users implement a combined Details:
Patch is 27.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129861.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+ ConvertOpToLLVMPattern<SourceOp>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
benefit) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}
private:
- using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
// RewritePattern
//===----------------------------------------------------------------------===//
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-/// * Multi-step RewritePattern with "match" and "rewrite"
-/// - By overloading the "match" and "rewrite" functions, the user can
-/// separate the concerns of matching and rewriting.
-/// * Single-step RewritePattern with "matchAndRewrite"
-/// - By overloading the "matchAndRewrite" function, the user can perform
-/// the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
- virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
- /// builder. If an unexpected error is encountered (an internal
- /// compiler error), it is emitted through the normal MLIR diagnostic
- /// hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const = 0;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
- virtual LogicalResult match(Operation *op) const;
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
- /// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind(). If successful, this
- /// function will automatically perform the rewrite.
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+ LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const final {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overloading the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+ using OperationT = Operation *;
+ using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+ virtual ~RewritePattern() = default;
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). If successful, this
+ /// function will automatically perform the rewrite.
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const = 0;
/// This method provides a convenient interface for creating and initializing
/// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
/// class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+ using OperationT = SourceOp;
using RewritePattern::RewritePattern;
- /// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, PatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
+ /// Wrapper around the RewritePattern method that passes the derived op type.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
- /// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
+ /// Method that operates on the SourceOp type. Must be overridden by the
+ /// derived pattern class.
virtual LogicalResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const {
- if (succeeded(match(op))) {
- rewrite(op, rewriter);
- return success();
- }
- return failure();
- }
+ PatternRewriter &rewriter) const = 0;
};
} // namespace detail
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
// Conversion Patterns
//===----------------------------------------------------------------------===//
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
+
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // One of the two `rewrite` functions must be implemented.
+ llvm_unreachable("rewrite is not implemented");
+ }
+
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ if constexpr (std::is_same<typename PatternT::OpAdaptor,
+ ArrayRef<Value>>::value) {
+ rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+ } else {
+ SmallVector<Value> oneToOneOperands =
+ PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+ rewriter);
+ }
+ }
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind().
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+ }
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ if (succeeded(match(op))) {
+ rewrite(op, adaptor, rewriter);
+ return success();
+ }
+ return failure();
+ }
+};
+} // namespace detail
+
/// Base class for the conversion patterns. This pattern class enables type
/// conversions, and other uses specific to the conversion framework. As such,
/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
- /// Hook for derived classes to implement rewriting. `op` is the (first)
- /// operation matched by the pattern, `operands` is a list of the rewritten
- /// operand values that are passed to `op`, `rewriter` can be used to emit the
- /// new operations. This function should not fail. If some specific cases of
- /// the operation are not supported, these cases should not be matched.
- virtual void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("unimplemented rewrite");
- }
- virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
+ using OperationT = Operation *;
+ using OpAdaptor = ArrayRef<Value>;
+ using OneToNOpAdaptor = ArrayRef<ValueRange>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
/// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
-
-private:
- using RewritePattern::rewrite;
};
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+ : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+ using SplitMatchAndRewrite::SplitMatchAndRewrite;
Chipset chipset;
ExtFOnFloat8...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe vast majority of rewrite / conversion patterns uses a combined This PR optimizes the code base for the most common case where users implement a combined Details:
Patch is 27.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129861.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+ ConvertOpToLLVMPattern<SourceOp>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
benefit) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}
private:
- using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
// RewritePattern
//===----------------------------------------------------------------------===//
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-/// * Multi-step RewritePattern with "match" and "rewrite"
-/// - By overloading the "match" and "rewrite" functions, the user can
-/// separate the concerns of matching and rewriting.
-/// * Single-step RewritePattern with "matchAndRewrite"
-/// - By overloading the "matchAndRewrite" function, the user can perform
-/// the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
- virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
- /// builder. If an unexpected error is encountered (an internal
- /// compiler error), it is emitted through the normal MLIR diagnostic
- /// hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const = 0;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
- virtual LogicalResult match(Operation *op) const;
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
- /// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind(). If successful, this
- /// function will automatically perform the rewrite.
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+ LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const final {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overloading the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+ using OperationT = Operation *;
+ using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+ virtual ~RewritePattern() = default;
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). If successful, this
+ /// function will automatically perform the rewrite.
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const = 0;
/// This method provides a convenient interface for creating and initializing
/// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
/// class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+ using OperationT = SourceOp;
using RewritePattern::RewritePattern;
- /// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, PatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
+ /// Wrapper around the RewritePattern method that passes the derived op type.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
- /// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
+ /// Method that operates on the SourceOp type. Must be overridden by the
+ /// derived pattern class.
virtual LogicalResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const {
- if (succeeded(match(op))) {
- rewrite(op, rewriter);
- return success();
- }
- return failure();
- }
+ PatternRewriter &rewriter) const = 0;
};
} // namespace detail
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
// Conversion Patterns
//===----------------------------------------------------------------------===//
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
+
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // One of the two `rewrite` functions must be implemented.
+ llvm_unreachable("rewrite is not implemented");
+ }
+
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ if constexpr (std::is_same<typename PatternT::OpAdaptor,
+ ArrayRef<Value>>::value) {
+ rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+ } else {
+ SmallVector<Value> oneToOneOperands =
+ PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+ rewriter);
+ }
+ }
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind().
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+ }
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ if (succeeded(match(op))) {
+ rewrite(op, adaptor, rewriter);
+ return success();
+ }
+ return failure();
+ }
+};
+} // namespace detail
+
/// Base class for the conversion patterns. This pattern class enables type
/// conversions, and other uses specific to the conversion framework. As such,
/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
- /// Hook for derived classes to implement rewriting. `op` is the (first)
- /// operation matched by the pattern, `operands` is a list of the rewritten
- /// operand values that are passed to `op`, `rewriter` can be used to emit the
- /// new operations. This function should not fail. If some specific cases of
- /// the operation are not supported, these cases should not be matched.
- virtual void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("unimplemented rewrite");
- }
- virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
+ using OperationT = Operation *;
+ using OpAdaptor = ArrayRef<Value>;
+ using OneToNOpAdaptor = ArrayRef<ValueRange>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
/// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
-
-private:
- using RewritePattern::rewrite;
};
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+ : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+ using SplitMatchAndRewrite::SplitMatchAndRewrite;
Chipset chipset;
ExtFOnFloat8...
[truncated]
|
@llvm/pr-subscribers-mlir-arith Author: Matthias Springer (matthias-springer) ChangesThe vast majority of rewrite / conversion patterns uses a combined This PR optimizes the code base for the most common case where users implement a combined Details:
Patch is 27.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129861.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+ ConvertOpToLLVMPattern<SourceOp>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
benefit) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}
private:
- using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
// RewritePattern
//===----------------------------------------------------------------------===//
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-/// * Multi-step RewritePattern with "match" and "rewrite"
-/// - By overloading the "match" and "rewrite" functions, the user can
-/// separate the concerns of matching and rewriting.
-/// * Single-step RewritePattern with "matchAndRewrite"
-/// - By overloading the "matchAndRewrite" function, the user can perform
-/// the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
- virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
- /// builder. If an unexpected error is encountered (an internal
- /// compiler error), it is emitted through the normal MLIR diagnostic
- /// hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const = 0;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
- virtual LogicalResult match(Operation *op) const;
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
- /// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind(). If successful, this
- /// function will automatically perform the rewrite.
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+ LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const final {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overloading the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+ using OperationT = Operation *;
+ using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+ virtual ~RewritePattern() = default;
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). If successful, this
+ /// function will automatically perform the rewrite.
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const = 0;
/// This method provides a convenient interface for creating and initializing
/// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
/// class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+ using OperationT = SourceOp;
using RewritePattern::RewritePattern;
- /// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, PatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
+ /// Wrapper around the RewritePattern method that passes the derived op type.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
- /// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
+ /// Method that operates on the SourceOp type. Must be overridden by the
+ /// derived pattern class.
virtual LogicalResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const {
- if (succeeded(match(op))) {
- rewrite(op, rewriter);
- return success();
- }
- return failure();
- }
+ PatternRewriter &rewriter) const = 0;
};
} // namespace detail
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
// Conversion Patterns
//===----------------------------------------------------------------------===//
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
+
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // One of the two `rewrite` functions must be implemented.
+ llvm_unreachable("rewrite is not implemented");
+ }
+
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ if constexpr (std::is_same<typename PatternT::OpAdaptor,
+ ArrayRef<Value>>::value) {
+ rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+ } else {
+ SmallVector<Value> oneToOneOperands =
+ PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+ rewriter);
+ }
+ }
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind().
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+ }
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ if (succeeded(match(op))) {
+ rewrite(op, adaptor, rewriter);
+ return success();
+ }
+ return failure();
+ }
+};
+} // namespace detail
+
/// Base class for the conversion patterns. This pattern class enables type
/// conversions, and other uses specific to the conversion framework. As such,
/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
- /// Hook for derived classes to implement rewriting. `op` is the (first)
- /// operation matched by the pattern, `operands` is a list of the rewritten
- /// operand values that are passed to `op`, `rewriter` can be used to emit the
- /// new operations. This function should not fail. If some specific cases of
- /// the operation are not supported, these cases should not be matched.
- virtual void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("unimplemented rewrite");
- }
- virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
+ using OperationT = Operation *;
+ using OpAdaptor = ArrayRef<Value>;
+ using OneToNOpAdaptor = ArrayRef<ValueRange>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
/// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
-
-private:
- using RewritePattern::rewrite;
};
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+ : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+ using SplitMatchAndRewrite::SplitMatchAndRewrite;
Chipset chipset;
ExtFOnFloat8...
[truncated]
|
I can confirm that the gcc (14) warning about hidden functions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the removal of those unreachable instances in the base, this looks like good cleanup.
So for folks with the existing split one, it's like a 3 line change to use the split one instead?
2fb9e11
to
5e025f8
Compare
That's right. Users should derive from |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems reasonable to me. This does raise a question though, what actually needs separate match and rewrite these days? Do we even need to carry around this code? Seems like we could just deprecate this and remove it (it's mostly legacy from the early days anyways).
Co-authored-by: River Riddle <[email protected]>
Deprecate the `match` and `rewrite` functions. They mainly exist for historic reasons. This PR also updates all remaining uses of in the MLIR codebase. This is addressing a [comment](#129861 (review)) on an earlier PR. Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon, update your patterns to use `matchAndRewrite` instead of separate `match` / `rewrite`. --------- Co-authored-by: Jakub Kuderski <[email protected]>
…031) Deprecate the `match` and `rewrite` functions. They mainly exist for historic reasons. This PR also updates all remaining uses of in the MLIR codebase. This is addressing a [comment](llvm/llvm-project#129861 (review)) on an earlier PR. Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon, update your patterns to use `matchAndRewrite` instead of separate `match` / `rewrite`. --------- Co-authored-by: Jakub Kuderski <[email protected]>
…lvm#129861) The vast majority of rewrite / conversion patterns uses a combined `matchAndRewrite` instead of separate `match` and `rewrite` functions. This PR optimizes the code base for the most common case where users implement a combined `matchAndRewrite`. There are no longer any `match` and `rewrite` functions in `RewritePattern`, `ConversionPattern` and their derived classes. Instead, there is a `SplitMatchAndRewriteImpl` class that implements `matchAndRewrite` in terms of `match` and `rewrite`. Details: * The `RewritePattern` and `ConversionPattern` classes are simpler (fewer functions). Especially the `ConversionPattern` class, which now has 5 fewer functions. (There were various `rewrite` overloads to account for 1:1 / 1:N patterns.) * There is a new class `SplitMatchAndRewriteImpl` that derives from `RewritePattern` / `OpRewritePatern` / ..., along with a type alias `RewritePattern::SplitMatchAndRewrite` for convenience. * Fewer `llvm_unreachable` are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement `rewrite` or `matchAndRewrite`, etc.) * This PR may also improve the number of [`-Woverload-virtual` warnings](https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933) that are produced by GCC. (To be confirmed...) Note for LLVM integration: Patterns with separate `match` / `rewrite` implementations, must derive from `X::SplitMatchAndRewrite` instead of `X`. --------- Co-authored-by: River Riddle <[email protected]>
Deprecate the `match` and `rewrite` functions. They mainly exist for historic reasons. This PR also updates all remaining uses of in the MLIR codebase. This is addressing a [comment](llvm#129861 (review)) on an earlier PR. Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon, update your patterns to use `matchAndRewrite` instead of separate `match` / `rewrite`. --------- Co-authored-by: Jakub Kuderski <[email protected]>
It just has one `matchAndRewrite`. (This commit just removes the `override` keyword on the individual `match` and `rewrite`) llvm/llvm-project#129861
It just has one `matchAndRewrite`. (This commit just removes the `override` keyword on the individual `match` and `rewrite`) llvm/llvm-project#129861
**Context:** We update the llvm version tagged by jax 0.6.0: ``` mhlo=617a9361d186199480c080c9e8c474a5e30c22d1 llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158 ``` We also update Enzyme to the latest version, which is 0.0.180, at commit `db0181320d6e425ee963bd496ed0d8dbb615be18` **Description of the Change:** Firstly, jax recently moved from the Google github organization to its own jax-ml organization. This means the urls, and the retrieval method for the underlying llvm and mhlo git commit tags, needs to be updated. (Thanks @mehrdad2m !) Now on to the actual changes. I will list the changes in increasing complexity. 1. The new enzyme cmake target is `EnzymeStatic-21` (from 20) 2. Enzyme works with a later llvm then our target, so it has some llvm intrinsics unknown to the one we are targeting. We patch them away. They do not concern us since they are all intrinsics for nvidia backends. 3. `applyPatternsAndFoldGreedily` is removed. Drop-in replacement is `applyPatternsGreedily`. llvm/llvm-project#104649, llvm/llvm-project#126701 4. ops with `CallOpInterface` must have two new optional attributes `arg_attrs` and `res_attrs` llvm/llvm-project#123176 5. `CallInterfaceCallable` objects now must be directly casted to the callee `SymbolRefAttr`, i.e. `callee.get<SymbolRefAttr>()` -> `cast<SymbolRefAttr>(callee)` llvm/llvm-project@35e8989 6. The `lookupOrCreateFn` family of functions now return `FailureOr<funcop>` instead of just `funcop`, so a `.value()` needs to be used to retrieve the underlying `funcop`. llvm/llvm-project@e84f6b6 7. The cpp api for `OneShotBufferizePassOptions` no longer needs complicated lambdas for the type converter options. They can be set with the `mlir::bufferization::LayoutMapOption::IdentityLayoutMap` options directly. 8. The individual `match` and `rewrite` methods in pattern rewrites are removed. Use the two-in-one `matchAndRewrite` instead. llvm/llvm-project#129861 9. For rewrite patterns with 1-to-N convertions, a new `macthAndRewrite` overload with `OneToNOpAdaptor` must be used. For us, this is only the `catalyst.list*` ops. llvm/llvm-project#116470 10. The lowering of `cf::AssertOp` to llvm was split from the overall`--covert-cf-to-llvm` pass. We need to manually call this separate pattern for cf.assert duriing quantum to llvm dialect lowering, where we also convert cf to llvm. https://github.com/llvm/llvm-project/pull/120431/files 11. The new mhlo depends on a [shardy](https://github.com/openxla/shardy) dialect. Shardy is built with bazel, not cmake. Building shardy ourselves would be very difficult (not having bazel in our build ecosystem is a hard constraint, cc @mlxd ), and also not necessary (we just use mhlo for their "standard" passes). We thus patch out all shardy components. 12. Three necessary passes were removed in mhlo: `mhlo-legalize-control-flow`, `mhlo-legalize-to-std`, `hlo-legalize-sort` tensorflow/mlir-hlo@4a640be#diff-ef0d7e30da19a396ba036405a9ef636f8b1be194618b0a90f4602671fc2ef34d tensorflow/mlir-hlo@2a5e267#diff-f8c7cb07b43593403e00e0dbf9983f0186b4eb70368cc99af3b924061f1ea46f - Alongside the removal of `mhlo-legalize-to-std`, the cmake target `MhloToStandard` was removed too. We simply patch them back for now. **For the above two points, note that there will be an overall migration to the stablehlo repo, as mhlo is sunseting. Therefore, spending too much time on this isn't necessary, so we just patch.** 13. The new pattern applicator (`applyPatternsGreedily`) is more aggressive in dead code elimination, and is eliminating dead `Value`s in the adjoint gradient method. The `nodealloc` function we generate for adjoint gradient lowering used to only return the qreg, not the expval result. This causes the expval op to be eliminated since it has no users. This further causes wrong gradient results, since the entire program, all ops included (regardless of dead or not), impacts the gradient through chain rule. To avoid this, we return the expval result as well. In doing this, we implicitly assume that differentiated qnodes can only return expval. Although this assumption is true and also restricted by frontend, ideally we should not have it hard coded. We leave this as a TODO for a future feature. 14. The old `--buffer-deallocation` pass is removed. Intended replacement is `--buffer-deallocation-pipeline`. This migration is very complicated. We simply add back the old buffer deallocation pass in the catalyst dialect as a util for now. We will revisit this in #1778 . mlir lit test updates: 1. `bufferization.to_tensor/memref` updated assembly format 2. gradient adjoint lowering test returns both qreg and expval 3. Some inverse unrealized conversion cast pairs are canceled by the new pattern rewriter. 4. `llvm.mlir.undef` is deprecated, use `llvm.mlir.poison` instead. llvm/llvm-project#125629 **Benefits:** Up to date with upstream versions. [sc-92017] --------- Co-authored-by: Tzung-Han Juang <[email protected]> Co-authored-by: Ritu Thombre <[email protected]> Co-authored-by: Mehrdad Malekmohammadi <[email protected]> Co-authored-by: Mehrdad Malek <[email protected]> Co-authored-by: David Ittah <[email protected]> Co-authored-by: Joey Carter <[email protected]>
The vast majority of rewrite / conversion patterns uses a combined
matchAndRewrite
instead of separatematch
andrewrite
functions.This PR optimizes the code base for the most common case where users implement a combined
matchAndRewrite
. There are no longer anymatch
andrewrite
functions inRewritePattern
,ConversionPattern
and their derived classes. Instead, there is aSplitMatchAndRewriteImpl
class that implementsmatchAndRewrite
in terms ofmatch
andrewrite
.Details:
RewritePattern
andConversionPattern
classes are simpler (fewer functions). Especially theConversionPattern
class, which now has 5 fewer functions. (There were variousrewrite
overloads to account for 1:1 / 1:N patterns.)SplitMatchAndRewriteImpl
that derives fromRewritePattern
/OpRewritePatern
/ ..., along with a type aliasRewritePattern::SplitMatchAndRewrite
for convenience.llvm_unreachable
are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implementrewrite
ormatchAndRewrite
, etc.)-Woverload-virtual
warnings that are produced by GCC. (To be confirmed...)Note for LLVM integration: Patterns with separate
match
/rewrite
implementations, must derive fromX::SplitMatchAndRewrite
instead ofX
.