Skip to content

Commit be0844e

Browse files
authored
Fix #16405 ctd - wildcards prematurely resolving to Nothing (#16764)
Fixes #16405, which was a problem because it could it get in the way of some metaprogramming techniques. The main issue was the fact that when typing functions, the type inference would first look at the types from the source method (in decomposeProtoFunction resolving found type wildcards to Nothing) and only after that, it could look at the target method. This is a continuation and simplification of #16625
2 parents 3d38e27 + 1978b8b commit be0844e

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,8 +1217,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12171217
)
12181218
end typedIf
12191219

1220-
/** Decompose function prototype into a list of parameter prototypes and a result prototype
1221-
* tree, using WildcardTypes where a type is not known.
1220+
/** Decompose function prototype into a list of parameter prototypes and a result
1221+
* prototype tree, using WildcardTypes where a type is not known.
1222+
* Note: parameter prototypes may be TypeBounds.
12221223
* For the result type we do this even if the expected type is not fully
12231224
* defined, which is a bit of a hack. But it's needed to make the following work
12241225
* (see typers.scala and printers/PlainPrinter.scala for examples).
@@ -1254,7 +1255,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12541255
// if expected parameter type(s) are wildcards, approximate from below.
12551256
// if expected result type is a wildcard, approximate from above.
12561257
// this can type the greatest set of admissible closures.
1257-
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
1258+
1259+
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
12581260
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
12591261
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
12601262
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
@@ -1300,7 +1302,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13001302
else NoType
13011303
case _ => NoType
13021304
if target.exists then formal <:< target
1303-
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1305+
if !formal.isExactlyNothing && isFullyDefined(formal, ForceDegree.flipBottom) then formal
13041306
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
13051307
else NoType
13061308

@@ -1493,11 +1495,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
14931495
}
14941496

14951497
var desugared: untpd.Tree = EmptyTree
1496-
if protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head) then
1497-
val isGenericTuple =
1498-
protoFormals.head.derivesFrom(defn.TupleClass)
1499-
&& !defn.isTupleClass(protoFormals.head.typeSymbol)
1500-
desugared = desugar.makeTupledFunction(params, fnBody, isGenericTuple)
1498+
if protoFormals.length == 1 && params.length != 1 then
1499+
val firstFormal = protoFormals.head.loBound
1500+
if ptIsCorrectProduct(firstFormal) then
1501+
val isGenericTuple =
1502+
firstFormal.derivesFrom(defn.TupleClass)
1503+
&& !defn.isTupleClass(firstFormal.typeSymbol)
1504+
desugared = desugar.makeTupledFunction(params, fnBody, isGenericTuple)
15011505
else if protoFormals.length > 1 && params.length == 1 then
15021506
def isParamRef(scrut: untpd.Tree): Boolean = scrut match
15031507
case untpd.Annotated(scrut1, _) => isParamRef(scrut1)
@@ -1519,12 +1523,20 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15191523
for ((param, i) <- params.zipWithIndex) yield
15201524
if (!param.tpt.isEmpty) param
15211525
else
1522-
val formal = protoFormal(i)
1526+
val formalBounds = protoFormal(i)
1527+
val formal = formalBounds.loBound
1528+
val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing
15231529
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
1530+
// If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound,
1531+
// try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala)
15241532
val paramType =
1525-
if knownFormal then formal
1526-
else inferredFromTarget(param, formal, calleeType, paramIndex)
1527-
.orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
1533+
if knownFormal && !isBottomFromWildcard then
1534+
formal
1535+
else
1536+
inferredFromTarget(param, formal, calleeType, paramIndex).orElse(
1537+
if knownFormal then formal
1538+
else errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos)
1539+
)
15281540
val paramTpt = untpd.TypedSplice(
15291541
(if knownFormal then InferredTypeTree() else untpd.TypeTree())
15301542
.withType(paramType.translateFromRepeated(toArray = false))

tests/run/16405.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import scala.compiletime.summonInline
2+
3+
case class TypeDesc[T](tpe: String)
4+
object TypeDesc {
5+
given nothing: TypeDesc[Nothing] = TypeDesc("Nothing")
6+
given string: TypeDesc[String] = TypeDesc("String")
7+
given int: TypeDesc[Int] = TypeDesc("Int")
8+
}
9+
10+
def exampleFn(s: String, i: Int): Unit = ()
11+
12+
inline def argumentTypesOf[R](fun: (_, _) => R): (TypeDesc[?], TypeDesc[?]) = {
13+
inline fun match {
14+
case x: ((a, b) => R) =>
15+
(scala.compiletime.summonInline[TypeDesc[a]], scala.compiletime.summonInline[TypeDesc[b]])
16+
}
17+
}
18+
inline def argumentTypesOfNoWildCard[A, B, R](fun: (A, B) => R): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
19+
inline def argumentTypesOfAllWildCard(fun: (?, ?) => ?): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
20+
21+
object Test {
22+
def main(args: Array[String]): Unit = {
23+
val expected = (TypeDesc.string, TypeDesc.int)
24+
assert(argumentTypesOf(exampleFn) == expected)
25+
assert(argumentTypesOf(exampleFn(_, _)) == expected)
26+
assert(argumentTypesOfNoWildCard(exampleFn) == expected)
27+
assert(argumentTypesOfNoWildCard(exampleFn(_, _)) == expected)
28+
assert(argumentTypesOfAllWildCard(exampleFn) == expected)
29+
assert(argumentTypesOfAllWildCard(exampleFn(_, _)) == expected)
30+
}
31+
}

0 commit comments

Comments
 (0)