Skip to content

Commit bfa4271

Browse files
committed
Add anyOf and allOf matchers - Need fixing - Add polymorphic matcher
1 parent 928de27 commit bfa4271

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

mlir/lib/Tools/mlir-query/ExtraMatchers.h

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
// goal is to move this to include/mlir/IR/Matchers.h after the initial testing
1111
// phase. The matchers in this file are:
1212
//
13-
// - operation(args...): Matches an operation that matches all of the matchers
14-
// in the vector `matchers`.
13+
// - operation(args...): Matches all of the matchers in the vector `matchers`.
1514
//
1615
// - argument(innerMatcher, index): Matches an operation argument that matches
1716
// `innerMatcher` at the given `index`.
@@ -39,11 +38,9 @@ namespace extramatcher {
3938

4039
namespace detail {
4140

42-
// TODO: Rename to AllOf
43-
// OperationMatcher takes a vector of DynMatchers and returns true if all
44-
// DynMatchers match the given operation.
45-
struct OperationMatcher {
46-
OperationMatcher(std::vector<matcher::DynMatcher> matchers)
41+
// AllOf takes a vector of DynMatchers and returns true if all the DynMatchers match the given operation.
42+
struct AllOfMatcher {
43+
AllOfMatcher(std::vector<matcher::DynMatcher> matchers)
4744
: matchers(matchers) {}
4845
bool match(Operation *op) {
4946
matcher::DynTypedNode node = matcher::DynTypedNode::create(op);
@@ -54,6 +51,19 @@ struct OperationMatcher {
5451
std::vector<matcher::DynMatcher> matchers;
5552
};
5653

54+
// AnyOf takes a vector of DynMatchers and returns true if any of the DynMatchers match the given operation.
55+
struct AnyOfMatcher {
56+
AnyOfMatcher(std::vector<matcher::DynMatcher> matchers)
57+
: matchers(matchers) {}
58+
bool match(Operation *op) {
59+
matcher::DynTypedNode node = matcher::DynTypedNode::create(op);
60+
return llvm::any_of(matchers, [&](const matcher::DynMatcher &matcher) {
61+
return matcher.matches(node);
62+
});
63+
}
64+
std::vector<matcher::DynMatcher> matchers;
65+
};
66+
5767
// ArgumentMatcher matches the operand of an operation at a specific index.
5868
struct ArgumentMatcher {
5969
ArgumentMatcher(matcher::DynMatcher innerMatcher, unsigned index)
@@ -145,10 +155,14 @@ struct DefinedByMatcher {
145155

146156
} // namespace detail
147157

148-
// TODO: Rename to allOf()
149-
inline detail::OperationMatcher operation(matcher::DynMatcher args...) {
158+
inline detail::AllOfMatcher allOf(matcher::DynMatcher args...) {
159+
std::vector<matcher::DynMatcher> matchers({args});
160+
return detail::AllOfMatcher(matchers);
161+
}
162+
163+
inline detail::AnyOfMatcher anyOf(matcher::DynMatcher args...) {
150164
std::vector<matcher::DynMatcher> matchers({args});
151-
return detail::OperationMatcher(matchers);
165+
return detail::AnyOfMatcher(matchers);
152166
}
153167

154168
inline detail::ArgumentMatcher hasArgument(matcher::DynMatcher innerMatcher,

mlir/lib/Tools/mlir-query/MatchersInternal.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ class SingleMatcher : public MatcherInterface<T> {
122122
MatcherFn matcherFn;
123123
};
124124

125-
// VariadicMatcher takes a vector of Matchers and returns true if all Matchers
125+
// TODO: Use a polymorphic matcher instead for this usecase
126+
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
126127
// match the given operation.
127128
template <typename T>
128129
class VariadicMatcher : public MatcherInterface<T> {
@@ -131,7 +132,7 @@ class VariadicMatcher : public MatcherInterface<T> {
131132

132133
bool matches(T Node) override {
133134
DynTypedNode DynNode = DynTypedNode::create(Node);
134-
return llvm::all_of(matchers, [&](const DynMatcher &matcher) {
135+
return llvm::any_of(matchers, [&](const DynMatcher &matcher) {
135136
return matcher.matches(DynNode);
136137
});
137138
}

mlir/lib/Tools/mlir-query/Registry.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ RegistryMaps::RegistryMaps() {
6868
};
6969

7070
// Register matchers using the template function
71-
registerOpMatcher("operation", extramatcher::operation);
71+
registerOpMatcher("allOf", extramatcher::allOf);
72+
registerOpMatcher("anyOf", extramatcher::anyOf);
7273
registerOpMatcher("hasArgument", extramatcher::hasArgument);
7374
registerOpMatcher("definedBy", extramatcher::definedBy);
7475
registerOpMatcher("getDefinitions", extramatcher::getDefinitions);

0 commit comments

Comments
 (0)