Skip to content

Commit 8d5cb11

Browse files
committed
[ConstraintSystem] Narrowly disable tryOptimizeGenericDisjunction when some of the arguments are number literals
Don't attempt this optimization if call has number literals. This is intended to narrowly fix situations like: ```swift func test<T: FloatingPoint>(_: T) { ... } func test<T: Numeric>(_: T) { ... } test(42) ``` The call should use `<T: Numeric>` overload even though the `<T: FloatingPoint>` is a more specialized version because selecting `<T: Numeric>` doesn't introduce non-default literal types.
1 parent f2a6677 commit 8d5cb11

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

include/swift/Sema/ConstraintSystem.h

+4
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,10 @@ class TypeVariableType::Implementation {
497497
/// literal (represented by `ArrayExpr` and `DictionaryExpr` in AST).
498498
bool isCollectionLiteralType() const;
499499

500+
/// Determine whether this type variable represents a literal such
501+
/// as an integer value, a floating-point value with and without a sign.
502+
bool isNumberLiteralType() const;
503+
500504
/// Determine whether this type variable represents a result type of a
501505
/// function call.
502506
bool isFunctionResult() const;

lib/Sema/CSSolver.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,27 @@ tryOptimizeGenericDisjunction(ConstraintSystem &cs, Constraint *disjunction,
14021402
return nullptr;
14031403
}
14041404

1405+
// Don't attempt this optimization if call has number literals.
1406+
// This is intended to narrowly fix situations like:
1407+
//
1408+
// func test<T: FloatingPoint>(_: T) { ... }
1409+
// func test<T: Numeric>(_: T) { ... }
1410+
//
1411+
// test(42)
1412+
//
1413+
// The call should use `<T: Numeric>` overload even though the
1414+
// `<T: FloatingPoint>` is a more specialized version because
1415+
// selecting `<T: Numeric>` doesn't introduce non-default literal
1416+
// types.
1417+
if (auto *argFnType = cs.getAppliedDisjunctionArgumentFunction(disjunction)) {
1418+
if (llvm::any_of(
1419+
argFnType->getParams(), [](const AnyFunctionType::Param &param) {
1420+
auto *typeVar = param.getPlainType()->getAs<TypeVariableType>();
1421+
return typeVar && typeVar->getImpl().isNumberLiteralType();
1422+
}))
1423+
return nullptr;
1424+
}
1425+
14051426
llvm::SmallVector<Constraint *, 4> choices;
14061427
for (auto *choice : constraints) {
14071428
if (choices.size() > 2)

lib/Sema/TypeCheckConstraints.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ bool TypeVariableType::Implementation::isCollectionLiteralType() const {
204204
locator->directlyAt<DictionaryExpr>());
205205
}
206206

207+
bool TypeVariableType::Implementation::isNumberLiteralType() const {
208+
return locator && locator->directlyAt<NumberLiteralExpr>();
209+
}
210+
207211
bool TypeVariableType::Implementation::isFunctionResult() const {
208212
return locator && locator->isLastElement<LocatorPathElt::FunctionResult>();
209213
}

test/Constraints/old_hack_related_ambiguities.swift

+9
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,12 @@ do {
178178
}
179179
}
180180

181+
// `tryOptimizeGenericDisjunction` is too aggressive sometimes, make sure that `<T: FloatingPoint>`
182+
// overload is _not_ selected in this case.
183+
do {
184+
func test<T: FloatingPoint>(_ expression1: @autoclosure () throws -> T, accuracy: T) -> T {}
185+
func test<T: Numeric>(_ expression1: @autoclosure () throws -> T, accuracy: T) -> T {}
186+
187+
let result = test(10, accuracy: 1)
188+
let _: Int = result
189+
}

0 commit comments

Comments
 (0)