Skip to content

Commit 4523a6f

Browse files
committed
Fix #16405 - wildcards prematurely resolving to Nothing
This was a problem because it could it get in the way of some metaprogramming techniques. The main issue was the fact that when typing functions, the type inference would first look at the types from the source method (resolving type wildcards to Nothing) and only after that, it could look at the target method. Now, in the case of wildcards we save that fact for later (while still resolving the prototype parameter to Nothing) and we in that case we prioritize according to the target method, after which we fallback to the default procedure.
1 parent 4afb0fc commit 4523a6f

File tree

2 files changed

+63
-15
lines changed

2 files changed

+63
-15
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
11971197
)
11981198
end typedIf
11991199

1200-
/** Decompose function prototype into a list of parameter prototypes and a result prototype
1200+
/** Decompose function prototype into a list of parameter prototypes, an optional list
1201+
* describing whether the parameter prototypes come from WildcardTypes, and a result prototype
12011202
* tree, using WildcardTypes where a type is not known.
12021203
* For the result type we do this even if the expected type is not fully
12031204
* defined, which is a bit of a hack. But it's needed to make the following work
@@ -1206,7 +1207,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12061207
* def double(x: Char): String = s"$x$x"
12071208
* "abc" flatMap double
12081209
*/
1209-
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = {
1210+
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], Option[List[Boolean]], untpd.Tree) = {
12101211
def typeTree(tp: Type) = tp match {
12111212
case _: WildcardType => new untpd.InferredTypeTree()
12121213
case _ => untpd.InferredTypeTree(tp)
@@ -1234,18 +1235,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12341235
// if expected parameter type(s) are wildcards, approximate from below.
12351236
// if expected result type is a wildcard, approximate from above.
12361237
// this can type the greatest set of admissible closures.
1237-
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
1238+
// However, we still keep the information on whether expected parameter types were
1239+
// wildcards, in case of types inferred from target being more specific
1240+
1241+
val fromWildcards = pt1.argInfos.init.map{
1242+
case bounds @ TypeBounds(nt, at) if nt == defn.NothingType && at == defn.AnyType => true
1243+
case bounds => false
1244+
}
1245+
1246+
(pt1.argTypesLo.init, Some(fromWildcards), typeTree(interpolateWildcards(pt1.argTypesHi.last)))
12381247
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
12391248
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
1240-
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
1249+
(formals, None, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
12411250
case SAMType(mt @ MethodTpe(_, formals, restpe)) =>
1242-
(formals,
1251+
(formals, None,
12431252
if (mt.isResultDependent)
12441253
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
12451254
else
12461255
typeTree(restpe))
12471256
case _ =>
1248-
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1257+
(List.tabulate(defaultArity)(alwaysWildcardType), None, untpd.TypeTree())
12491258
}
12501259
}
12511260
}
@@ -1267,7 +1276,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12671276
* If both attempts fail, return `NoType`.
12681277
*/
12691278
def inferredFromTarget(
1270-
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type =
1279+
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int, isWildcardParam: Boolean)(using Context): Type =
12711280
val target = calleeType.widen match
12721281
case mtpe: MethodType =>
12731282
val pos = paramIndex(param.name)
@@ -1280,7 +1289,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12801289
else NoType
12811290
case _ => NoType
12821291
if target.exists then formal <:< target
1283-
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1292+
if !isWildcardParam && isFullyDefined(formal, ForceDegree.flipBottom) then formal
12841293
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
12851294
else NoType
12861295

@@ -1457,7 +1466,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
14571466
case _ =>
14581467
}
14591468

1460-
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
1469+
val (protoFormals, areWildcardParams, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
14611470

14621471
def protoFormal(i: Int): Type =
14631472
if (protoFormals.length == params.length) protoFormals(i)
@@ -1500,13 +1509,21 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15001509
if (!param.tpt.isEmpty) param
15011510
else
15021511
val formal = protoFormal(i)
1512+
val isWildcardParam = areWildcardParams.map(list => if i < list.length then list(i) else false).getOrElse(false)
15031513
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
1504-
val paramType =
1505-
if knownFormal then formal
1506-
else inferredFromTarget(param, formal, calleeType, paramIndex)
1507-
.orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
1514+
// Since decomposeProtoFunction eagerly approximates function arguments
1515+
// from below, then in the case that the argument was also identified as
1516+
// a wildcard type we try to prioritize inferring from target, if possible.
1517+
// See issue 16405 (tests/run/16405.scala)
1518+
val (usingFormal, paramType) =
1519+
if !isWildcardParam && knownFormal then (true, formal)
1520+
else
1521+
val fromTarget = inferredFromTarget(param, formal, calleeType, paramIndex, isWildcardParam)
1522+
if fromTarget.exists then (false, fromTarget)
1523+
else if knownFormal then (true, formal)
1524+
else (false, errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
15081525
val paramTpt = untpd.TypedSplice(
1509-
(if knownFormal then InferredTypeTree() else untpd.TypeTree())
1526+
(if usingFormal then InferredTypeTree() else untpd.TypeTree())
15101527
.withType(paramType.translateFromRepeated(toArray = false))
15111528
.withSpan(param.span.endPos)
15121529
)
@@ -1577,7 +1594,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15771594
typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt)
15781595
}
15791596
else {
1580-
val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
1597+
val (protoFormals, _, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
15811598
val checkMode =
15821599
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
15831600
else desugar.MatchCheck.Exhaustive

tests/run/16405.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import scala.compiletime.summonInline
2+
3+
case class TypeDesc[T](tpe: String)
4+
object TypeDesc {
5+
given nothing: TypeDesc[Nothing] = TypeDesc("Nothing")
6+
given string: TypeDesc[String] = TypeDesc("String")
7+
given int: TypeDesc[Int] = TypeDesc("Int")
8+
}
9+
10+
def exampleFn(s: String, i: Int): Unit = ()
11+
12+
inline def argumentTypesOf[R](fun: (_, _) => R): (TypeDesc[?], TypeDesc[?]) = {
13+
inline fun match {
14+
case x: ((a, b) => R) =>
15+
(scala.compiletime.summonInline[TypeDesc[a]], scala.compiletime.summonInline[TypeDesc[b]])
16+
}
17+
}
18+
inline def argumentTypesOfNoWildCard[A, B, R](fun: (A, B) => R): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
19+
inline def argumentTypesOfAllWildCard(fun: (?, ?) => ?): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
20+
21+
object Test {
22+
def main(args: Array[String]): Unit = {
23+
val expected = (TypeDesc.string, TypeDesc.int)
24+
assert(argumentTypesOf(exampleFn) == expected)
25+
assert(argumentTypesOf(exampleFn(_, _)) == expected)
26+
assert(argumentTypesOfNoWildCard(exampleFn) == expected)
27+
assert(argumentTypesOfNoWildCard(exampleFn(_, _)) == expected)
28+
assert(argumentTypesOfAllWildCard(exampleFn) == expected)
29+
assert(argumentTypesOfAllWildCard(exampleFn(_, _)) == expected)
30+
}
31+
}

0 commit comments

Comments
 (0)