Skip to content

Commit 98341df

Browse files
[mlir][Transform] Add a transform.match.operation_empty op to allow s… (#68319)
…pecifying negative conditions In the process, get_parent_op gains an attribute to allow it to return empty handles explicitly and still succeed.
1 parent e91a4be commit 98341df

File tree

6 files changed

+312
-127
lines changed

6 files changed

+312
-127
lines changed

mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h

+75-20
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,71 @@
1111

1212
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1313
#include "mlir/IR/OpDefinition.h"
14+
#include "llvm/ADT/STLExtras.h"
15+
#include <optional>
16+
#include <type_traits>
1417

1518
namespace mlir {
1619
namespace transform {
1720
class MatchOpInterface;
1821

22+
namespace detail {
23+
/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
24+
/// first operand.
1925
template <typename OpTy>
20-
class SingleOpMatcherOpTrait
21-
: public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
26+
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
27+
TransformResults &results,
28+
TransformState &state) {
29+
if constexpr (std::is_same_v<
30+
typename llvm::function_traits<
31+
decltype(&OpTy::matchOperation)>::template arg_t<0>,
32+
Operation *>) {
33+
return op.matchOperation(nullptr, results, state);
34+
} else {
35+
return op.matchOperation(std::nullopt, results, state);
36+
}
37+
}
38+
} // namespace detail
39+
40+
template <typename OpTy>
41+
class AtMostOneOpMatcherOpTrait
42+
: public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
2243
template <typename T>
2344
using has_get_operand_handle =
2445
decltype(std::declval<T &>().getOperandHandle());
2546
template <typename T>
26-
using has_match_operation = decltype(std::declval<T &>().matchOperation(
47+
using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
2748
std::declval<Operation *>(), std::declval<TransformResults &>(),
2849
std::declval<TransformState &>()));
50+
template <typename T>
51+
using has_match_operation_optional =
52+
decltype(std::declval<T &>().matchOperation(
53+
std::declval<std::optional<Operation *>>(),
54+
std::declval<TransformResults &>(),
55+
std::declval<TransformState &>()));
2956

3057
public:
3158
static LogicalResult verifyTrait(Operation *op) {
3259
static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
33-
"SingleOpMatcherOpTrait expects operation type to have the "
34-
"getOperandHandle() method");
35-
static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
36-
"SingleOpMatcherOpTrait expected operation type to have the "
37-
"matchOperation(Operation *, TransformResults &, "
38-
"TransformState &) method");
60+
"AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
61+
"operation type to have the getOperandHandle() method");
62+
static_assert(
63+
llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
64+
llvm::is_detected<has_match_operation_optional, OpTy>::value,
65+
"AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
66+
"type to have either the matchOperation(Operation *, TransformResults "
67+
"&, TransformState &) or the matchOperation(std::optional<Operation*>, "
68+
"TransformResults &, TransformState &) method");
3969

4070
// This must be a dynamic assert because interface registration is dynamic.
41-
assert(isa<MatchOpInterface>(op) &&
42-
"SingleOpMatchOpTrait is only available on operations with "
43-
"MatchOpInterface");
71+
assert(
72+
isa<MatchOpInterface>(op) &&
73+
"AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
74+
"operations with MatchOpInterface");
4475
Value operandHandle = cast<OpTy>(op).getOperandHandle();
4576
if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
46-
return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
77+
return op->emitError() << "AtMostOneOpMatcherOpTrait/"
78+
"SingleOpMatchOpTrait requires the op handle "
4779
"to be of TransformHandleTypeInterface";
4880
}
4981

@@ -55,12 +87,15 @@ class SingleOpMatcherOpTrait
5587
TransformState &state) {
5688
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
5789
auto payload = state.getPayloadOps(operandHandle);
58-
if (!llvm::hasSingleElement(payload)) {
90+
if (!llvm::hasNItemsOrLess(payload, 1)) {
5991
return emitDefiniteFailure(this->getOperation()->getLoc())
60-
<< "SingleOpMatchOpTrait requires the operand handle to point to "
61-
"a single payload op";
92+
<< "AtMostOneOpMatcherOpTrait requires the operand handle to "
93+
"point to at most one payload op";
94+
}
95+
if (payload.empty()) {
96+
return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
97+
results, state);
6298
}
63-
6499
return cast<OpTy>(this->getOperation())
65100
.matchOperation(*payload.begin(), results, state);
66101
}
@@ -72,12 +107,32 @@ class SingleOpMatcherOpTrait
72107
}
73108
};
74109

110+
template <typename OpTy>
111+
class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {
112+
113+
public:
114+
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
115+
TransformResults &results,
116+
TransformState &state) {
117+
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
118+
auto payload = state.getPayloadOps(operandHandle);
119+
if (!llvm::hasSingleElement(payload)) {
120+
return emitDefiniteFailure(this->getOperation()->getLoc())
121+
<< "SingleOpMatchOpTrait requires the operand handle to point to "
122+
"a single payload op";
123+
}
124+
return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
125+
rewriter, results, state);
126+
}
127+
};
128+
75129
template <typename OpTy>
76130
class SingleValueMatcherOpTrait
77131
: public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
78132
public:
79133
static LogicalResult verifyTrait(Operation *op) {
80-
// This must be a dynamic assert because interface registration is dynamic.
134+
// This must be a dynamic assert because interface registration is
135+
// dynamic.
81136
assert(isa<MatchOpInterface>(op) &&
82137
"SingleValueMatchOpTrait is only available on operations with "
83138
"MatchOpInterface");
@@ -98,8 +153,8 @@ class SingleValueMatcherOpTrait
98153
auto payload = state.getPayloadValues(operandHandle);
99154
if (!llvm::hasSingleElement(payload)) {
100155
return emitDefiniteFailure(this->getOperation()->getLoc())
101-
<< "SingleValueMatchOpTrait requires the value handle to point to "
102-
"a single payload value";
156+
<< "SingleValueMatchOpTrait requires the value handle to point "
157+
"to a single payload value";
103158
}
104159

105160
return cast<OpTy>(this->getOperation())

mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td

+19-2
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,28 @@ def MatchOpInterface
1414
let cppNamespace = "::mlir::transform";
1515
}
1616

17+
// Trait for "matcher" transform operations that apply to an operation handle
18+
// associated with at most one payload operation. Checks that it is indeed
19+
// the case and produces a definite failure when it is not. The matching logic
20+
// is implemented in the `matchOperation` function instead of `apply`. The op
21+
// with this trait must provide a `Value getOperandHandle()` function that
22+
// returns the handle to be used for matching.
23+
def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
24+
let cppNamespace = "::mlir::transform";
25+
26+
string extraDeclaration = [{
27+
::mlir::DiagnosedSilenceableFailure matchOperation(
28+
::std::optional<::mlir::Operation *> maybeCurrent,
29+
::mlir::transform::TransformResults &results,
30+
::mlir::transform::TransformState &state);
31+
}];
32+
}
33+
1734
// Trait for "matcher" transform operations that apply to an operation handle
1835
// associated with exactly one payload operation. Checks that it is indeed
1936
// the case and produces a definite failure when it is not. The matching logic
2037
// is implemented in the `matchOperation` function instead of `apply`. The op
21-
// with this trait must provide a `Value getOperandHandle()` function that
38+
// with this trait must provide a `Value getOperandHandle()` function that
2239
// returns the handle to be used for matching.
2340
def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
2441
let cppNamespace = "::mlir::transform";
@@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
3552
// associated with exactly one payload value. Checks that it is indeed
3653
// the case and produces a definite failure when it is not. The matching logic
3754
// is implemented in the `matchValue` function instead of `apply`. The op
38-
// with this trait must provide a `Value getOperandHandle()` function that
55+
// with this trait must provide a `Value getOperandHandle()` function that
3956
// returns the handle to be used for matching.
4057
def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
4158
let cppNamespace = "::mlir::transform";

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

+103-101
Original file line numberDiff line numberDiff line change
@@ -20,107 +20,109 @@ def Transform_Dialect : Dialect {
2020

2121
let hasOperationAttrVerify = 1;
2222
let extraClassDeclaration = [{
23-
/// Name of the attribute attachable to the symbol table operation
24-
/// containing named sequences. This is used to trigger verification.
25-
constexpr const static ::llvm::StringLiteral
26-
kWithNamedSequenceAttrName = "transform.with_named_sequence";
27-
28-
/// Name of the attribute attachable to an operation so it can be
29-
/// identified as root by the default interpreter pass.
30-
constexpr const static ::llvm::StringLiteral
31-
kTargetTagAttrName = "transform.target_tag";
32-
33-
/// Name of the attribute attachable to an operation, indicating that
34-
/// TrackingListener failures should be silenced.
35-
constexpr const static ::llvm::StringLiteral
36-
kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures";
37-
38-
/// Names of the attributes indicating whether an argument of an external
39-
/// transform dialect symbol is consumed or only read.
40-
constexpr const static ::llvm::StringLiteral
41-
kArgConsumedAttrName = "transform.consumed";
42-
constexpr const static ::llvm::StringLiteral
43-
kArgReadOnlyAttrName = "transform.readonly";
44-
45-
template <typename DataTy>
46-
const DataTy &getExtraData() const {
47-
return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
48-
}
49-
50-
/// Parses a type registered by this dialect or one of its extensions.
51-
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
52-
53-
/// Prints a type registered by this dialect or one of its extensions.
54-
void printType(::mlir::Type type,
55-
::mlir::DialectAsmPrinter &printer) const override;
56-
57-
/// Parser callback for an individual type registered by this dialect or
58-
/// its extensions.
59-
using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);
60-
61-
/// Printer callback for an individual type registered by this dialect or
62-
/// its extensions.
63-
using ExtensionTypePrintingHook =
64-
std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;
65-
66-
private:
67-
/// Registers operations specified as template parameters with this
68-
/// dialect. Checks that they implement the required interfaces.
69-
template <typename... OpTys>
70-
void addOperationsChecked() {
71-
(addOperationIfNotRegistered<OpTys>(), ...);
72-
}
73-
template <typename OpTy>
74-
void addOperationIfNotRegistered();
75-
76-
/// Reports a repeated registration error of an op with the given name.
77-
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);
78-
79-
/// Registers the types specified as template parameters with the
80-
/// Transform dialect. Checks that they meet the requirements for
81-
/// Transform IR types.
82-
template <typename... TypeTys>
83-
void addTypesChecked() {
84-
(addTypeIfNotRegistered<TypeTys>(), ...);
85-
}
86-
template <typename Type>
87-
void addTypeIfNotRegistered();
88-
89-
/// Reports a repeated registration error of a type with the given
90-
/// mnemonic.
91-
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
92-
93-
/// Registers dialect types with the context.
94-
void initializeTypes();
95-
96-
// Give extensions access to injection functions.
97-
template <typename, typename...>
98-
friend class TransformDialectExtension;
99-
100-
/// Gets a mutable reference to extra data of the kind specified as
101-
/// template argument. Allocates the data on the first call.
102-
template <typename DataTy>
103-
DataTy &getOrCreateExtraData();
104-
105-
//===----------------------------------------------------------------===//
106-
// Data fields
107-
//===----------------------------------------------------------------===//
108-
109-
/// Additional data associated with and owned by the dialect. Accessible
110-
/// to extensions.
111-
::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
112-
::mlir::transform::detail::TransformDialectDataBase>>
113-
extraData;
114-
115-
/// A map from type mnemonic to its parsing function for the remainder of
116-
/// the syntax. The parser has access to the mnemonic, so it is used for
117-
/// further dispatch.
118-
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;
119-
120-
/// A map from type TypeID to its printing function. No need to do string
121-
/// lookups when the type is fully constructed.
122-
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
123-
typePrintingHooks;
23+
/// Name of the attribute attachable to the symbol table operation
24+
/// containing named sequences. This is used to trigger verification.
25+
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
26+
"transform.with_named_sequence";
27+
28+
/// Name of the attribute attachable to an operation so it can be
29+
/// identified as root by the default interpreter pass.
30+
constexpr const static ::llvm::StringLiteral kTargetTagAttrName =
31+
"transform.target_tag";
32+
33+
/// Name of the attribute attachable to an operation, indicating that
34+
/// TrackingListener failures should be silenced.
35+
constexpr const static ::llvm::StringLiteral
36+
kSilenceTrackingFailuresAttrName =
37+
"transform.silence_tracking_failures";
38+
39+
/// Names of the attributes indicating whether an argument of an external
40+
/// transform dialect symbol is consumed or only read.
41+
constexpr const static ::llvm::StringLiteral kArgConsumedAttrName =
42+
"transform.consumed";
43+
constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
44+
"transform.readonly";
45+
46+
template <typename DataTy>
47+
const DataTy &getExtraData() const {
48+
return *static_cast<const DataTy *>(
49+
extraData.at(::mlir::TypeID::get<DataTy>()).get());
50+
}
51+
52+
/// Parses a type registered by this dialect or one of its extensions.
53+
::mlir::Type parseType(::mlir::DialectAsmParser & parser) const override;
54+
55+
/// Prints a type registered by this dialect or one of its extensions.
56+
void printType(::mlir::Type type, ::mlir::DialectAsmPrinter & printer)
57+
const override;
58+
59+
/// Parser callback for an individual type registered by this dialect or
60+
/// its extensions.
61+
using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);
62+
63+
/// Printer callback for an individual type registered by this dialect or
64+
/// its extensions.
65+
using ExtensionTypePrintingHook =
66+
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
67+
68+
private:
69+
/// Registers operations specified as template parameters with this
70+
/// dialect. Checks that they implement the required interfaces.
71+
template <typename... OpTys>
72+
void addOperationsChecked() {
73+
(addOperationIfNotRegistered<OpTys>(), ...);
74+
}
75+
template <typename OpTy>
76+
void addOperationIfNotRegistered();
77+
78+
/// Reports a repeated registration error of an op with the given name.
79+
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);
80+
81+
/// Registers types specified as template parameters with the Transform
82+
/// dialect. Checks that they meet the requirements for Transform IR types.
83+
template <typename... TypeTys>
84+
void addTypesChecked() {
85+
(addTypeIfNotRegistered<TypeTys>(), ...);
86+
}
87+
template <typename Type>
88+
void addTypeIfNotRegistered();
89+
90+
/// Reports a repeated registration error of a type with the given
91+
/// mnemonic.
92+
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
93+
94+
/// Registers dialect types with the context.
95+
void initializeTypes();
96+
97+
// Give extensions access to injection functions.
98+
template <typename, typename...>
99+
friend class TransformDialectExtension;
100+
101+
/// Gets a mutable reference to extra data of the kind specified as
102+
/// template argument. Allocates the data on the first call.
103+
template <typename DataTy>
104+
DataTy &getOrCreateExtraData();
105+
106+
//===----------------------------------------------------------------===//
107+
// Data fields
108+
//===----------------------------------------------------------------===//
109+
110+
/// Additional data associated with and owned by the dialect. Accessible
111+
/// to extensions.
112+
::llvm::DenseMap<
113+
::mlir::TypeID,
114+
std::unique_ptr<::mlir::transform::detail::TransformDialectDataBase>>
115+
extraData;
116+
117+
/// A map from type mnemonic to its parsing function for the remainder of
118+
/// the syntax. The parser has access to the mnemonic, so it is used for
119+
/// further dispatch.
120+
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;
121+
122+
/// A map from type TypeID to its printing function. No need to do string
123+
/// lookups when the type is fully constructed.
124+
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
125+
typePrintingHooks;
124126
}];
125127
}
126128

0 commit comments

Comments
 (0)