Skip to content

Commit e3150e2

Browse files
committed
Move registry to tools
1 parent 30a563f commit e3150e2

File tree

7 files changed

+107
-100
lines changed

7 files changed

+107
-100
lines changed

mlir/include/mlir/Query/Matcher/Parser.h

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class Parser {
8484
// process tokens.
8585
class RegistrySema : public Parser::Sema {
8686
public:
87+
RegistrySema(const RegistryMaps &registryData)
88+
: registryData(registryData) {}
8789
~RegistrySema() override;
8890

8991
std::optional<MatcherCtor>
@@ -99,60 +101,55 @@ class Parser {
99101

100102
std::vector<MatcherCompletion>
101103
getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) override;
104+
105+
private:
106+
const RegistryMaps &registryData;
102107
};
103108

104109
using NamedValueMap = llvm::StringMap<VariantValue>;
105110

106111
// Methods to parse a matcher expression and return a DynMatcher object,
107112
// transferring ownership to the caller.
108113
static std::optional<DynMatcher>
109-
parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema,
114+
parseMatcherExpression(llvm::StringRef &matcherCode,
115+
const RegistryMaps &registryData,
110116
const NamedValueMap *namedValues, Diagnostics *error);
111117
static std::optional<DynMatcher>
112-
parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema,
113-
Diagnostics *error) {
114-
return parseMatcherExpression(matcherCode, sema, nullptr, error);
115-
}
116-
static std::optional<DynMatcher>
117-
parseMatcherExpression(llvm::StringRef &matcherCode, Diagnostics *error) {
118-
return parseMatcherExpression(matcherCode, nullptr, error);
118+
parseMatcherExpression(llvm::StringRef &matcherCode,
119+
const RegistryMaps &registryData, Diagnostics *error) {
120+
return parseMatcherExpression(matcherCode, registryData, nullptr, error);
119121
}
120122

121123
// Methods to parse any expression supported by this parser.
122-
static bool parseExpression(llvm::StringRef &code, Sema *sema,
124+
static bool parseExpression(llvm::StringRef &code,
125+
const RegistryMaps &registryData,
123126
const NamedValueMap *namedValues,
124127
VariantValue *value, Diagnostics *error);
125128

126-
static bool parseExpression(llvm::StringRef &code, Sema *sema,
129+
static bool parseExpression(llvm::StringRef &code,
130+
const RegistryMaps &registryData,
127131
VariantValue *value, Diagnostics *error) {
128-
return parseExpression(code, sema, nullptr, value, error);
129-
}
130-
static bool parseExpression(llvm::StringRef &code, VariantValue *value,
131-
Diagnostics *error) {
132-
return parseExpression(code, nullptr, value, error);
132+
return parseExpression(code, registryData, nullptr, value, error);
133133
}
134134

135135
// Methods to complete an expression at a given offset.
136136
static std::vector<MatcherCompletion>
137137
completeExpression(llvm::StringRef &code, unsigned completionOffset,
138-
Sema *sema, const NamedValueMap *namedValues);
138+
const RegistryMaps &registryData,
139+
const NamedValueMap *namedValues);
139140
static std::vector<MatcherCompletion>
140141
completeExpression(llvm::StringRef &code, unsigned completionOffset,
141-
Sema *sema) {
142-
return completeExpression(code, completionOffset, sema, nullptr);
143-
}
144-
static std::vector<MatcherCompletion>
145-
completeExpression(llvm::StringRef &code, unsigned completionOffset) {
146-
return completeExpression(code, completionOffset, nullptr);
142+
const RegistryMaps &registryData) {
143+
return completeExpression(code, completionOffset, registryData, nullptr);
147144
}
148145

149146
private:
150147
class CodeTokenizer;
151148
struct ScopedContextEntry;
152149
struct TokenInfo;
153150

154-
Parser(CodeTokenizer *tokenizer, Sema *sema, const NamedValueMap *namedValues,
155-
Diagnostics *error);
151+
Parser(CodeTokenizer *tokenizer, const RegistryMaps &registryData,
152+
const NamedValueMap *namedValues, Diagnostics *error);
156153

157154
bool parseExpressionImpl(VariantValue *value);
158155

@@ -174,7 +171,7 @@ class Parser {
174171
getNamedValueCompletions(ArrayRef<ArgKind> acceptedTypes);
175172

176173
CodeTokenizer *const tokenizer;
177-
Sema *const sema;
174+
std::unique_ptr<RegistrySema> sema;
178175
const NamedValueMap *const namedValues;
179176
Diagnostics *const error;
180177

mlir/include/mlir/Query/Matcher/Registry.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,36 @@
1919
#include "Marshallers.h"
2020
#include "VariantValue.h"
2121
#include "llvm/ADT/ArrayRef.h"
22+
#include "llvm/ADT/StringMap.h"
2223
#include "llvm/ADT/StringRef.h"
2324
#include <string>
2425

2526
namespace mlir::query::matcher {
2627

2728
using MatcherCtor = const internal::MatcherDescriptor *;
29+
using ConstructorMap =
30+
llvm::StringMap<std::unique_ptr<const internal::MatcherDescriptor>>;
31+
32+
class RegistryMaps {
33+
public:
34+
RegistryMaps() = default;
35+
~RegistryMaps() = default;
36+
37+
const ConstructorMap &constructors() const { return constructorMap; }
38+
39+
template <typename MatcherType>
40+
void registerMatcher(const std::string &name, MatcherType matcher) {
41+
registerMatcherDescriptor(name,
42+
internal::makeMatcherAutoMarshall(matcher, name));
43+
}
44+
45+
private:
46+
void registerMatcherDescriptor(
47+
llvm::StringRef matcherName,
48+
std::unique_ptr<internal::MatcherDescriptor> callback);
49+
50+
ConstructorMap constructorMap;
51+
};
2852

2953
struct MatcherCompletion {
3054
MatcherCompletion() = default;
@@ -47,13 +71,15 @@ class Registry {
4771
Registry() = delete;
4872

4973
static std::optional<MatcherCtor>
50-
lookupMatcherCtor(llvm::StringRef matcherName);
74+
lookupMatcherCtor(llvm::StringRef matcherName,
75+
const RegistryMaps &registryData);
5176

5277
static std::vector<ArgKind> getAcceptedCompletionTypes(
5378
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);
5479

5580
static std::vector<MatcherCompletion>
56-
getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes);
81+
getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes,
82+
const RegistryMaps &registryData);
5783

5884
static VariantMatcher constructMatcher(MatcherCtor ctor,
5985
internal::SourceRange nameRange,

mlir/include/mlir/Query/QuerySession.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
1111

1212
#include "Query.h"
13+
#include "mlir/Query/Matcher/Registry.h"
1314
#include "mlir/Tools/ParseUtilities.h"
1415
#include "llvm/ADT/StringMap.h"
1516

@@ -20,9 +21,10 @@ class QuerySession {
2021
public:
2122
QuerySession(Operation *rootOp,
2223
const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
23-
unsigned bufferId)
24+
unsigned bufferId,
25+
const mlir::query::matcher::RegistryMaps &registryData)
2426
: rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId),
25-
terminate(false) {}
27+
registryData(registryData), terminate(false) {}
2628

2729
const std::shared_ptr<llvm::SourceMgr> &getSourceManager() {
2830
return sourceMgr;
@@ -31,6 +33,7 @@ class QuerySession {
3133
Operation *rootOp;
3234
const std::shared_ptr<llvm::SourceMgr> sourceMgr;
3335
unsigned bufferId;
36+
const mlir::query::matcher::RegistryMaps &registryData;
3437
bool terminate;
3538
llvm::StringMap<matcher::VariantValue> namedValues;
3639
};

mlir/lib/Query/Matcher/Parser.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -463,18 +463,16 @@ bool Parser::parseExpressionImpl(VariantValue *value) {
463463
llvm_unreachable("Unknown token kind.");
464464
}
465465

466-
static llvm::ManagedStatic<Parser::RegistrySema> defaultRegistrySema;
467-
468-
Parser::Parser(CodeTokenizer *tokenizer, Sema *sema,
466+
Parser::Parser(CodeTokenizer *tokenizer, const RegistryMaps &registryData,
469467
const NamedValueMap *namedValues, Diagnostics *error)
470-
: tokenizer(tokenizer), sema(sema ? sema : &*defaultRegistrySema),
468+
: tokenizer(tokenizer), sema(std::make_unique<RegistrySema>(registryData)),
471469
namedValues(namedValues), error(error) {}
472470

473471
Parser::RegistrySema::~RegistrySema() = default;
474472

475473
std::optional<MatcherCtor>
476474
Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
477-
return Registry::lookupMatcherCtor(matcherName);
475+
return Registry::lookupMatcherCtor(matcherName, registryData);
478476
}
479477

480478
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
@@ -490,14 +488,15 @@ std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
490488

491489
std::vector<MatcherCompletion>
492490
Parser::RegistrySema::getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes) {
493-
return Registry::getMatcherCompletions(acceptedTypes);
491+
return Registry::getMatcherCompletions(acceptedTypes, registryData);
494492
}
495493

496-
bool Parser::parseExpression(llvm::StringRef &code, Sema *sema,
494+
bool Parser::parseExpression(llvm::StringRef &code,
495+
const RegistryMaps &registryData,
497496
const NamedValueMap *namedValues,
498497
VariantValue *value, Diagnostics *error) {
499498
CodeTokenizer tokenizer(code, error);
500-
Parser parser(&tokenizer, sema, namedValues, error);
499+
Parser parser(&tokenizer, registryData, namedValues, error);
501500
if (!parser.parseExpressionImpl(value))
502501
return false;
503502
auto nextToken = tokenizer.peekNextToken();
@@ -512,22 +511,22 @@ bool Parser::parseExpression(llvm::StringRef &code, Sema *sema,
512511

513512
std::vector<MatcherCompletion>
514513
Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
515-
Sema *sema, const NamedValueMap *namedValues) {
514+
const RegistryMaps &registryData,
515+
const NamedValueMap *namedValues) {
516516
Diagnostics error;
517517
CodeTokenizer tokenizer(code, &error, completionOffset);
518-
Parser parser(&tokenizer, sema, namedValues, &error);
518+
Parser parser(&tokenizer, registryData, namedValues, &error);
519519
VariantValue dummy;
520520
parser.parseExpressionImpl(&dummy);
521521

522522
return parser.completions;
523523
}
524524

525-
std::optional<DynMatcher>
526-
Parser::parseMatcherExpression(llvm::StringRef &code, Sema *sema,
527-
const NamedValueMap *namedValues,
528-
Diagnostics *error) {
525+
std::optional<DynMatcher> Parser::parseMatcherExpression(
526+
llvm::StringRef &code, const RegistryMaps &registryData,
527+
const NamedValueMap *namedValues, Diagnostics *error) {
529528
VariantValue value;
530-
if (!parseExpression(code, sema, namedValues, &value, error))
529+
if (!parseExpression(code, registryData, namedValues, &value, error))
531530
return std::nullopt;
532531
if (!value.isMatcher()) {
533532
error->addError(SourceRange(), Diagnostics::ErrorType::ParserNotAMatcher);

mlir/lib/Query/Matcher/Registry.cpp

Lines changed: 9 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,12 @@
1212

1313
#include "mlir/Query/Matcher/Registry.h"
1414

15-
#include "mlir/IR/Matchers.h"
16-
#include "llvm/ADT/StringMap.h"
17-
#include "llvm/Support/ManagedStatic.h"
1815
#include <set>
1916
#include <utility>
2017

2118
namespace mlir::query::matcher {
2219
namespace {
2320

24-
using ConstructorMap =
25-
llvm::StringMap<std::unique_ptr<const internal::MatcherDescriptor>>;
26-
2721
// This is needed because these matchers are defined as overloaded functions.
2822
using IsConstantOp = detail::constant_op_matcher();
2923
using HasOpAttrName = detail::AttrOpMatcher(StringRef);
@@ -40,60 +34,21 @@ static std::string asArgString(ArgKind kind) {
4034
llvm_unreachable("Unhandled ArgKind");
4135
}
4236

43-
class RegistryMaps {
44-
public:
45-
RegistryMaps();
46-
~RegistryMaps();
47-
48-
const ConstructorMap &constructors() const { return constructorMap; }
49-
50-
private:
51-
void registerMatcher(llvm::StringRef matcherName,
52-
std::unique_ptr<internal::MatcherDescriptor> callback);
53-
54-
ConstructorMap constructorMap;
55-
};
56-
5737
} // namespace
5838

59-
void RegistryMaps::registerMatcher(
39+
void RegistryMaps::registerMatcherDescriptor(
6040
llvm::StringRef matcherName,
6141
std::unique_ptr<internal::MatcherDescriptor> callback) {
6242
assert(!constructorMap.contains(matcherName));
6343
constructorMap[matcherName] = std::move(callback);
6444
}
6545

66-
// Generate a registry map with all the known matchers.
67-
RegistryMaps::RegistryMaps() {
68-
auto registerOpMatcher = [&](const std::string &name, auto matcher) {
69-
registerMatcher(name, internal::makeMatcherAutoMarshall(matcher, name));
70-
};
71-
72-
// Register matchers using the template function (added in alphabetical order
73-
// for consistency)
74-
registerOpMatcher("hasOpAttrName", static_cast<HasOpAttrName *>(m_Attr));
75-
registerOpMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
76-
registerOpMatcher("isConstantOp", static_cast<IsConstantOp *>(m_Constant));
77-
registerOpMatcher("isNegInfFloat", m_NegInfFloat);
78-
registerOpMatcher("isNegZeroFloat", m_NegZeroFloat);
79-
registerOpMatcher("isNonZero", m_NonZero);
80-
registerOpMatcher("isOne", m_One);
81-
registerOpMatcher("isOneFloat", m_OneFloat);
82-
registerOpMatcher("isPosInfFloat", m_PosInfFloat);
83-
registerOpMatcher("isPosZeroFloat", m_PosZeroFloat);
84-
registerOpMatcher("isZero", m_Zero);
85-
registerOpMatcher("isZeroFloat", m_AnyZeroFloat);
86-
}
87-
88-
RegistryMaps::~RegistryMaps() = default;
89-
90-
static llvm::ManagedStatic<RegistryMaps> registryData;
91-
9246
std::optional<MatcherCtor>
93-
Registry::lookupMatcherCtor(llvm::StringRef matcherName) {
94-
auto it = registryData->constructors().find(matcherName);
95-
return it == registryData->constructors().end() ? std::optional<MatcherCtor>()
96-
: it->second.get();
47+
Registry::lookupMatcherCtor(llvm::StringRef matcherName,
48+
const RegistryMaps &registryData) {
49+
auto it = registryData.constructors().find(matcherName);
50+
return it == registryData.constructors().end() ? std::optional<MatcherCtor>()
51+
: it->second.get();
9752
}
9853

9954
std::vector<ArgKind> Registry::getAcceptedCompletionTypes(
@@ -118,11 +73,12 @@ std::vector<ArgKind> Registry::getAcceptedCompletionTypes(
11873
}
11974

12075
std::vector<MatcherCompletion>
121-
Registry::getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes) {
76+
Registry::getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes,
77+
const RegistryMaps &registryData) {
12278
std::vector<MatcherCompletion> completions;
12379

12480
// Search the registry for acceptable matchers.
125-
for (const auto &m : registryData->constructors()) {
81+
for (const auto &m : registryData.constructors()) {
12682
const internal::MatcherDescriptor &matcher = *m.getValue();
12783
StringRef name = m.getKey();
12884

mlir/lib/Query/QueryParser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ makeInvalidQueryFromDiagnostics(const matcher::internal::Diagnostics &diag) {
134134
QueryRef QueryParser::completeMatcherExpression() {
135135
std::vector<matcher::MatcherCompletion> comps =
136136
matcher::internal::Parser::completeExpression(
137-
line, completionPos - line.begin(), nullptr, &QS.namedValues);
137+
line, completionPos - line.begin(), QS.registryData, &QS.namedValues);
138138
for (const auto &comp : comps) {
139139
completions.emplace_back(comp.typedText, comp.matcherDecl);
140140
}
@@ -175,7 +175,7 @@ QueryRef QueryParser::doParse() {
175175
auto origMatcherSource = matcherSource;
176176
std::optional<matcher::DynMatcher> matcher =
177177
matcher::internal::Parser::parseMatcherExpression(
178-
matcherSource, nullptr, &QS.namedValues, &diag);
178+
matcherSource, QS.registryData, &QS.namedValues, &diag);
179179
if (!matcher) {
180180
return makeInvalidQueryFromDiagnostics(diag);
181181
}

0 commit comments

Comments
 (0)