Skip to content

Commit adbf8da

Browse files
authored
Merge pull request #23489 from AnthonyLatsis/where-clause-nongeneric-decl
[SE] Allow where clauses on non-generic declarations in generic contexts
2 parents 0bbd8de + f762644 commit adbf8da

24 files changed

+692
-228
lines changed

include/swift/AST/Decl.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,9 @@ class RequirementRepr {
12281228
void print(raw_ostream &OS) const;
12291229
void print(ASTPrinter &Printer) const;
12301230
};
1231-
1231+
1232+
using GenericParamSource = PointerUnion<GenericContext *, GenericParamList *>;
1233+
12321234
/// GenericParamList - A list of generic parameters that is part of a generic
12331235
/// function or type, along with extra requirements placed on those generic
12341236
/// parameters and types derived from them.

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,10 +1650,9 @@ ERROR(redundant_class_requirement,none,
16501650
"redundant 'class' requirement", ())
16511651
ERROR(late_class_requirement,none,
16521652
"'class' must come first in the requirement list", ())
1653-
ERROR(where_without_generic_params,none,
1654-
"'where' clause cannot be attached to "
1655-
"%select{a non-generic|a protocol|an associated type}0 "
1656-
"declaration", (unsigned))
1653+
ERROR(where_toplevel_nongeneric,none,
1654+
"'where' clause cannot be attached to non-generic "
1655+
"top-level declaration", ())
16571656
ERROR(where_inside_brackets,none,
16581657
"'where' clause next to generic parameters is obsolete, "
16591658
"must be written following the declaration's type", ())

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2751,6 +2751,10 @@ ERROR(dynamic_self_stored_property_init,none,
27512751
ERROR(dynamic_self_default_arg,none,
27522752
"covariant 'Self' type cannot be referenced from a default argument expression", ())
27532753

2754+
ERROR(where_nongeneric_ctx,none,
2755+
"'where' clause on non-generic member declaration requires a "
2756+
"generic context", ())
2757+
27542758
//------------------------------------------------------------------------------
27552759
// MARK: Type Check Attributes
27562760
//------------------------------------------------------------------------------

include/swift/AST/TypeCheckRequests.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ class InferredGenericSignatureRequest :
11081108
public SimpleRequest<InferredGenericSignatureRequest,
11091109
GenericSignature (ModuleDecl *,
11101110
GenericSignatureImpl *,
1111-
GenericParamList *,
1111+
GenericParamSource,
11121112
SmallVector<Requirement, 2>,
11131113
SmallVector<TypeLoc, 2>,
11141114
bool),
@@ -1124,7 +1124,7 @@ class InferredGenericSignatureRequest :
11241124
evaluate(Evaluator &evaluator,
11251125
ModuleDecl *module,
11261126
GenericSignatureImpl *baseSignature,
1127-
GenericParamList *gpl,
1127+
GenericParamSource paramSource,
11281128
SmallVector<Requirement, 2> addedRequirements,
11291129
SmallVector<TypeLoc, 2> inferenceSources,
11301130
bool allowConcreteGenericParams) const;

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ SWIFT_REQUEST(TypeChecker, HasDynamicMemberLookupAttributeRequest,
8383
bool(CanType), Cached, NoLocationInfo)
8484
SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
8585
GenericSignature (ModuleDecl *, GenericSignatureImpl *,
86-
GenericParamList *,
86+
GenericParamSource,
8787
SmallVector<Requirement, 2>,
8888
SmallVector<TypeLoc, 2>, bool),
8989
Cached, NoLocationInfo)

include/swift/Basic/SimpleDisplay.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ namespace swift {
135135
}
136136
out << "}";
137137
}
138+
139+
template<typename T, typename U>
140+
void simple_display(llvm::raw_ostream &out,
141+
const llvm::PointerUnion<T, U> &ptrUnion) {
142+
if (const auto t = ptrUnion.template dyn_cast<T>())
143+
simple_display(out, t);
144+
else
145+
simple_display(out, ptrUnion.template get<U>());
146+
}
138147
}
139148

140149
#endif // SWIFT_BASIC_SIMPLE_DISPLAY_H

include/swift/Parse/Parser.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,7 @@ class Parser {
10691069
bool allowClassRequirement,
10701070
bool allowAnyObject);
10711071
ParserStatus parseDeclItem(bool &PreviousHadSemi,
1072-
Parser::ParseDeclOptions Options,
1072+
ParseDeclOptions Options,
10731073
llvm::function_ref<void(Decl*)> handler);
10741074
std::pair<std::vector<Decl *>, Optional<std::string>>
10751075
parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
@@ -1637,14 +1637,10 @@ class Parser {
16371637
void
16381638
diagnoseWhereClauseInGenericParamList(const GenericParamList *GenericParams);
16391639

1640-
enum class WhereClauseKind : unsigned {
1641-
Declaration,
1642-
Protocol,
1643-
AssociatedType
1644-
};
16451640
ParserStatus
1646-
parseFreestandingGenericWhereClause(GenericParamList *GPList,
1647-
WhereClauseKind kind=WhereClauseKind::Declaration);
1641+
parseFreestandingGenericWhereClause(GenericContext *genCtx,
1642+
GenericParamList *&GPList,
1643+
ParseDeclOptions flags);
16481644

16491645
ParserStatus parseGenericWhereClause(
16501646
SourceLoc &WhereLoc, SmallVectorImpl<RequirementRepr> &Requirements,

lib/AST/ASTContext.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4470,19 +4470,22 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
44704470
assert(isa<AbstractFunctionDecl>(base) || isa<SubscriptDecl>(base));
44714471
assert(isa<AbstractFunctionDecl>(derived) || isa<SubscriptDecl>(derived));
44724472

4473-
auto baseClass = base->getDeclContext()->getSelfClassDecl();
4474-
auto derivedClass = derived->getDeclContext()->getSelfClassDecl();
4473+
const auto baseClass = base->getDeclContext()->getSelfClassDecl();
4474+
const auto derivedClass = derived->getDeclContext()->getSelfClassDecl();
44754475

44764476
assert(baseClass != nullptr);
44774477
assert(derivedClass != nullptr);
44784478

4479-
auto baseGenericSig = base->getAsGenericContext()->getGenericSignature();
4480-
auto derivedGenericSig = derived->getAsGenericContext()->getGenericSignature();
4479+
const auto baseGenericSig =
4480+
base->getAsGenericContext()->getGenericSignature();
4481+
const auto derivedGenericSig =
4482+
derived->getAsGenericContext()->getGenericSignature();
44814483

44824484
if (base == derived)
44834485
return derivedGenericSig;
44844486

4485-
if (derivedClass->getSuperclass().isNull())
4487+
const auto derivedSuperclass = derivedClass->getSuperclass();
4488+
if (derivedSuperclass.isNull())
44864489
return nullptr;
44874490

44884491
if (derivedGenericSig.isNull())
@@ -4491,12 +4494,6 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
44914494
if (baseGenericSig.isNull())
44924495
return derivedGenericSig;
44934496

4494-
auto baseClassSig = baseClass->getGenericSignature();
4495-
auto subMap = derivedClass->getSuperclass()->getContextSubstitutionMap(
4496-
derivedClass->getModuleContext(), baseClass);
4497-
4498-
unsigned derivedDepth = 0;
4499-
45004497
auto key = OverrideSignatureKey(baseGenericSig,
45014498
derivedGenericSig,
45024499
derivedClass);
@@ -4506,22 +4503,25 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
45064503
return getImpl().overrideSigCache.lookup(key);
45074504
}
45084505

4509-
if (auto derivedSig = derivedClass->getGenericSignature())
4510-
derivedDepth = derivedSig->getGenericParams().back()->getDepth() + 1;
4506+
const auto derivedClassSig = derivedClass->getGenericSignature();
4507+
4508+
unsigned derivedDepth = 0;
4509+
unsigned baseDepth = 0;
4510+
if (derivedClassSig)
4511+
derivedDepth = derivedClassSig->getGenericParams().back()->getDepth() + 1;
4512+
if (const auto baseClassSig = baseClass->getGenericSignature())
4513+
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;
45114514

45124515
SmallVector<GenericTypeParamType *, 2> addedGenericParams;
4513-
if (auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
4516+
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
45144517
for (auto gp : *gpList) {
45154518
addedGenericParams.push_back(
45164519
gp->getDeclaredInterfaceType()->castTo<GenericTypeParamType>());
45174520
}
45184521
}
45194522

4520-
unsigned baseDepth = 0;
4521-
4522-
if (baseClassSig) {
4523-
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;
4524-
}
4523+
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
4524+
derivedClass->getModuleContext(), baseClass);
45254525

45264526
auto substFn = [&](SubstitutableType *type) -> Type {
45274527
auto *gp = cast<GenericTypeParamType>(type);
@@ -4553,7 +4553,7 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
45534553
auto genericSig = evaluateOrDefault(
45544554
evaluator,
45554555
AbstractGenericSignatureRequest{
4556-
derivedClass->getGenericSignature().getPointer(),
4556+
derivedClassSig.getPointer(),
45574557
std::move(addedGenericParams),
45584558
std::move(addedRequirements)},
45594559
GenericSignature());

lib/AST/ASTScopeLookup.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ bool ASTScopeImpl::doesContextMatchStartingContext(
194194
// For a SubscriptDecl with generic parameters, the call tries to do lookups
195195
// with startingContext equal to either the get or set subscript
196196
// AbstractFunctionDecls. Since the generic parameters are in the
197-
// SubScriptDeclScope, and not the AbstractFunctionDecl scopes (after all how
198-
// could one parameter be in two scopes?), GenericParamScoped intercepts the
197+
// SubscriptDeclScope, and not the AbstractFunctionDecl scopes (after all how
198+
// could one parameter be in two scopes?), GenericParamScope intercepts the
199199
// match query here and tests against the accessor DeclContexts.
200200
bool GenericParamScope::doesContextMatchStartingContext(
201201
const DeclContext *context) const {

lib/AST/ASTWalker.cpp

Lines changed: 40 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,31 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
126126
// Decls
127127
//===--------------------------------------------------------------------===//
128128

129+
bool visitGenericParamListIfNeeded(GenericContext *GC) {
130+
// Must check this first in case extensions have not been bound yet
131+
if (Walker.shouldWalkIntoGenericParams()) {
132+
if (auto *params = GC->getGenericParams()) {
133+
visitGenericParamList(params);
134+
}
135+
return true;
136+
}
137+
return false;
138+
}
139+
140+
bool visitTrailingRequirements(GenericContext *GC) {
141+
if (const auto Where = GC->getTrailingWhereClause()) {
142+
for (auto &Req: Where->getRequirements())
143+
if (doIt(Req))
144+
return true;
145+
} else if (!isa<ExtensionDecl>(GC)) {
146+
if (const auto GP = GC->getGenericParams())
147+
for (auto Req: GP->getTrailingRequirements())
148+
if (doIt(Req))
149+
return true;
150+
}
151+
return false;
152+
}
153+
129154
bool visitImportDecl(ImportDecl *ID) {
130155
return false;
131156
}
@@ -138,12 +163,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
138163
if (doIt(Inherit))
139164
return true;
140165
}
141-
if (auto *Where = ED->getTrailingWhereClause()) {
142-
for(auto &Req: Where->getRequirements()) {
143-
if (doIt(Req))
144-
return true;
145-
}
146-
}
166+
if (visitTrailingRequirements(ED))
167+
return true;
168+
147169
for (Decl *M : ED->getMembers()) {
148170
if (doIt(M))
149171
return true;
@@ -223,15 +245,13 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
223245
}
224246

225247
bool visitTypeAliasDecl(TypeAliasDecl *TAD) {
226-
if (Walker.shouldWalkIntoGenericParams() && TAD->getGenericParams()) {
227-
if (visitGenericParamList(TAD->getGenericParams()))
228-
return true;
229-
}
248+
bool WalkGenerics = visitGenericParamListIfNeeded(TAD);
230249

231250
if (auto typerepr = TAD->getUnderlyingTypeRepr())
232251
if (doIt(typerepr))
233252
return true;
234-
return false;
253+
254+
return WalkGenerics && visitTrailingRequirements(TAD);
235255
}
236256

237257
bool visitOpaqueTypeDecl(OpaqueTypeDecl *OTD) {
@@ -269,20 +289,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
269289
}
270290

271291
// Visit requirements
272-
if (WalkGenerics) {
273-
ArrayRef<swift::RequirementRepr> Reqs = None;
274-
if (auto *Protocol = dyn_cast<ProtocolDecl>(NTD)) {
275-
if (auto *WhereClause = Protocol->getTrailingWhereClause())
276-
Reqs = WhereClause->getRequirements();
277-
} else {
278-
Reqs = NTD->getGenericParams()->getTrailingRequirements();
279-
}
280-
for (auto Req: Reqs) {
281-
if (doIt(Req))
282-
return true;
283-
}
284-
}
285-
292+
if (WalkGenerics && visitTrailingRequirements(NTD))
293+
return true;
294+
286295
for (Decl *Member : NTD->getMembers()) {
287296
if (doIt(Member))
288297
return true;
@@ -325,13 +334,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
325334
if (doIt(SD->getElementTypeLoc()))
326335
return true;
327336

328-
if (WalkGenerics) {
329-
// Visit generic requirements
330-
for (auto Req : SD->getGenericParams()->getTrailingRequirements()) {
331-
if (doIt(Req))
332-
return true;
333-
}
334-
}
337+
// Visit trailing requirements
338+
if (WalkGenerics && visitTrailingRequirements(SD))
339+
return true;
335340

336341
if (!Walker.shouldWalkAccessorsTheOldWay()) {
337342
for (auto *AD : SD->getAllAccessors())
@@ -364,13 +369,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
364369
if (doIt(FD->getBodyResultTypeLoc()))
365370
return true;
366371

367-
if (WalkGenerics) {
368-
// Visit trailing requirments
369-
for (auto Req : AFD->getGenericParams()->getTrailingRequirements()) {
370-
if (doIt(Req))
371-
return true;
372-
}
373-
}
372+
// Visit trailing requirements
373+
if (WalkGenerics && visitTrailingRequirements(AFD))
374+
return true;
374375

375376
if (AFD->getBody(/*canSynthesize=*/false)) {
376377
AbstractFunctionDecl::BodyKind PreservedKind = AFD->getBodyKind();
@@ -1323,17 +1324,6 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13231324
}
13241325
return false;
13251326
}
1326-
1327-
private:
1328-
bool visitGenericParamListIfNeeded(GenericContext *gc) {
1329-
if (Walker.shouldWalkIntoGenericParams()) {
1330-
if (auto *params = gc->getGenericParams()) {
1331-
visitGenericParamList(params);
1332-
return true;
1333-
}
1334-
}
1335-
return false;
1336-
}
13371327
};
13381328

13391329
} // end anonymous namespace

lib/AST/Decl.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,9 +1077,12 @@ void GenericContext::setGenericSignature(GenericSignature genericSig) {
10771077
}
10781078

10791079
SourceRange GenericContext::getGenericTrailingWhereClauseSourceRange() const {
1080-
if (!isGeneric())
1081-
return SourceRange();
1082-
return getGenericParams()->getTrailingWhereClauseSourceRange();
1080+
if (isGeneric())
1081+
return getGenericParams()->getTrailingWhereClauseSourceRange();
1082+
else if (const auto *where = getTrailingWhereClause())
1083+
return where->getSourceRange();
1084+
1085+
return SourceRange();
10831086
}
10841087

10851088
ImportDecl *ImportDecl::create(ASTContext &Ctx, DeclContext *DC,

0 commit comments

Comments
 (0)