Skip to content

Commit a3a3ec4

Browse files
committed
[CSOptimizer] Restore old hack behavior which used to favor overloads based on arity matches
This maintains an "old hack" behavior where overloads of some `OverloadedDeclRef` calls were favored purely based on number of argument and (non-defaulted) parameters matching. This is important to maintain source compatibility.
1 parent 802f5cd commit a3a3ec4

File tree

3 files changed

+129
-13
lines changed

3 files changed

+129
-13
lines changed

lib/Sema/CSOptimizer.cpp

+79-5
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,67 @@ void forEachDisjunctionChoice(
172172
}
173173
}
174174

175-
static bool isOverloadedDeclRef(Constraint *disjunction) {
175+
static OverloadedDeclRefExpr *isOverloadedDeclRef(Constraint *disjunction) {
176176
assert(disjunction->getKind() == ConstraintKind::Disjunction);
177-
return disjunction->getLocator()->directlyAt<OverloadedDeclRefExpr>();
177+
178+
auto *locator = disjunction->getLocator();
179+
if (locator->getPath().empty())
180+
return getAsExpr<OverloadedDeclRefExpr>(locator->getAnchor());
181+
return nullptr;
182+
}
183+
184+
/// This maintains an "old hack" behavior where overloads of some
185+
/// `OverloadedDeclRef` calls were favored purely based on number of
186+
/// argument and (non-defaulted) parameters matching.
187+
static void findFavoredChoicesBasedOnArity(
188+
ConstraintSystem &cs, Constraint *disjunction, ArgumentList *argumentList,
189+
llvm::function_ref<void(Constraint *)> favoredChoice) {
190+
auto *ODRE = isOverloadedDeclRef(disjunction);
191+
if (!ODRE)
192+
return;
193+
194+
if (llvm::count_if(ODRE->getDecls(), [&argumentList](auto *choice) {
195+
if (auto *paramList = getParameterList(choice))
196+
return argumentList->size() == paramList->size();
197+
return false;
198+
}) > 1)
199+
return;
200+
201+
auto isVariadicGenericOverload = [&](ValueDecl *choice) {
202+
auto genericContext = choice->getAsGenericContext();
203+
if (!genericContext)
204+
return false;
205+
206+
auto *GPL = genericContext->getGenericParams();
207+
if (!GPL)
208+
return false;
209+
210+
return llvm::any_of(GPL->getParams(), [&](const GenericTypeParamDecl *GP) {
211+
return GP->isParameterPack();
212+
});
213+
};
214+
215+
bool hasVariadicGenerics = false;
216+
SmallVector<Constraint *> favored;
217+
218+
forEachDisjunctionChoice(
219+
cs, disjunction,
220+
[&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
221+
if (isVariadicGenericOverload(decl))
222+
hasVariadicGenerics = true;
223+
224+
if (overloadType->getNumParams() == argumentList->size() ||
225+
llvm::count_if(*getParameterList(decl), [](auto *param) {
226+
return !param->isDefaultArgument();
227+
}) == argumentList->size())
228+
favored.push_back(choice);
229+
});
230+
231+
if (hasVariadicGenerics)
232+
return;
233+
234+
for (auto *choice : favored)
235+
favoredChoice(choice);
178236
}
179237

180238
} // end anonymous namespace
@@ -193,9 +251,6 @@ static Constraint *determineBestChoicesInContext(
193251
favoredChoicesPerDisjunction;
194252

195253
for (auto *disjunction : disjunctions) {
196-
if (!isSupportedDisjunction(disjunction))
197-
continue;
198-
199254
auto applicableFn =
200255
getApplicableFnConstraint(cs.getConstraintGraph(), disjunction);
201256

@@ -218,6 +273,25 @@ static Constraint *determineBestChoicesInContext(
218273
}
219274
}
220275

276+
// This maintains an "old hack" behavior where overloads
277+
// of `OverloadedDeclRef` calls were favored purely
278+
// based on arity of arguments and parameters matching.
279+
{
280+
findFavoredChoicesBasedOnArity(
281+
cs, disjunction, argumentList, [&](Constraint *choice) {
282+
favoredChoicesPerDisjunction[disjunction].push_back(choice);
283+
});
284+
285+
if (!favoredChoicesPerDisjunction[disjunction].empty()) {
286+
disjunctionScores[disjunction] = 0.01;
287+
bestOverallScore = std::max(bestOverallScore, 0.01);
288+
continue;
289+
}
290+
}
291+
292+
if (!isSupportedDisjunction(disjunction))
293+
continue;
294+
221295
SmallVector<FunctionType::Param, 8> argsWithLabels;
222296
{
223297
argsWithLabels.append(argFuncType->getParams().begin(),

test/Constraints/old_hack_related_ambiguities.swift

+50
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,53 @@ do {
187187
let result = test(10, accuracy: 1)
188188
let _: Int = result
189189
}
190+
191+
// swift-distributed-tracing snippet that relies on old hack behavior.
192+
protocol TracerInstant {
193+
}
194+
195+
extension Int: TracerInstant {}
196+
197+
do {
198+
enum SpanKind {
199+
case `internal`
200+
}
201+
202+
func withSpan<Instant: TracerInstant>(
203+
_ operationName: String,
204+
at instant: @autoclosure () -> Instant,
205+
context: @autoclosure () -> Int = 0,
206+
ofKind kind: SpanKind = .internal
207+
) {}
208+
209+
func withSpan(
210+
_ operationName: String,
211+
context: @autoclosure () -> Int = 0,
212+
ofKind kind: SpanKind = .internal,
213+
at instant: @autoclosure () -> some TracerInstant = 42
214+
) {}
215+
216+
withSpan("", at: 0) // Ok
217+
}
218+
219+
protocol ForAssert {
220+
var isEmpty: Bool { get }
221+
}
222+
223+
extension ForAssert {
224+
var isEmpty: Bool { false }
225+
}
226+
227+
do {
228+
func assert(_ condition: @autoclosure () -> Bool, _ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) {}
229+
func assert(_ condition: Bool, _ message: @autoclosure () -> String, file: StaticString = #file, line: UInt = #line) {}
230+
func assert(_ condition: Bool, file: StaticString = #fileID, line: UInt = #line) {}
231+
232+
struct S : ForAssert {
233+
var isEmpty: Bool { false }
234+
}
235+
236+
func test(s: S) {
237+
assert(s.isEmpty, "") // Ok
238+
}
239+
}

test/Constraints/ranking.swift

-8
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,3 @@ struct HasIntInit {
450450
func compare_solutions_with_bindings(x: UInt8, y: UInt8) -> HasIntInit {
451451
return .init(Int(x / numericCast(y)))
452452
}
453-
454-
// Test to make sure that previous favoring behavior is maintained and @autoclosure makes a difference.
455-
func test_no_ambiguity_with_autoclosure(x: Int) {
456-
func test(_ condition: Bool, file: StaticString = #file, line: UInt = #line) {}
457-
func test(_ condition: @autoclosure () -> Bool, file: StaticString = #file, line: UInt = #line) {}
458-
459-
test(x >= 0) // Ok
460-
}

0 commit comments

Comments
 (0)