Skip to content

Fix #16405 - wildcards prematurely resolving to Nothing #16625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
)
end typedIf

/** Decompose function prototype into a list of parameter prototypes and a result prototype
/** Decompose function prototype into a list of parameter prototypes, an optional list
* describing whether the parameter prototypes come from WildcardTypes, and a result prototype
* tree, using WildcardTypes where a type is not known.
* For the result type we do this even if the expected type is not fully
* defined, which is a bit of a hack. But it's needed to make the following work
Expand All @@ -1206,7 +1207,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
* def double(x: Char): String = s"$x$x"
* "abc" flatMap double
*/
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = {
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], Option[List[Boolean]], untpd.Tree) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I found another way that achieves the same outcome and that does not need a separate list of booleans.

def typeTree(tp: Type) = tp match {
case _: WildcardType => new untpd.InferredTypeTree()
case _ => untpd.InferredTypeTree(tp)
Expand Down Expand Up @@ -1234,18 +1235,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// if expected parameter type(s) are wildcards, approximate from below.
// if expected result type is a wildcard, approximate from above.
// this can type the greatest set of admissible closures.
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
// However, we still keep the information on whether expected parameter types were
// wildcards, in case of types inferred from target being more specific

val fromWildcards = pt1.argInfos.init.map{
case bounds @ TypeBounds(nt, at) if nt == defn.NothingType && at == defn.AnyType => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type comparison via == or != is not recommended. One can compare types with eq, or pattern match the structure and compare symbols. Or else use =:= which is semantic equality, equivalent to mutual subtyping, but this one is expensive.

In this specific case, there are detectors for the types your are interested in: isExactlyNothing and isAny.

case bounds => false
}

(pt1.argTypesLo.init, Some(fromWildcards), typeTree(interpolateWildcards(pt1.argTypesHi.last)))
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
(formals, None, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
case SAMType(mt @ MethodTpe(_, formals, restpe)) =>
(formals,
(formals, None,
if (mt.isResultDependent)
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
else
typeTree(restpe))
case _ =>
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
(List.tabulate(defaultArity)(alwaysWildcardType), None, untpd.TypeTree())
}
}
}
Expand All @@ -1267,7 +1276,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
* If both attempts fail, return `NoType`.
*/
def inferredFromTarget(
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type =
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int, isWildcardParam: Boolean)(using Context): Type =
val target = calleeType.widen match
case mtpe: MethodType =>
val pos = paramIndex(param.name)
Expand All @@ -1280,7 +1289,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
else NoType
case _ => NoType
if target.exists then formal <:< target
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
if !isWildcardParam && isFullyDefined(formal, ForceDegree.flipBottom) then formal
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
else NoType

Expand Down Expand Up @@ -1457,7 +1466,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case _ =>
}

val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
val (protoFormals, areWildcardParams, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)

def protoFormal(i: Int): Type =
if (protoFormals.length == params.length) protoFormals(i)
Expand Down Expand Up @@ -1500,13 +1509,21 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if (!param.tpt.isEmpty) param
else
val formal = protoFormal(i)
val isWildcardParam = areWildcardParams.map(list => if i < list.length then list(i) else false).getOrElse(false)
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
val paramType =
if knownFormal then formal
else inferredFromTarget(param, formal, calleeType, paramIndex)
.orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
// Since decomposeProtoFunction eagerly approximates function arguments
// from below, then in the case that the argument was also identified as
// a wildcard type we try to prioritize inferring from target, if possible.
// See issue 16405 (tests/run/16405.scala)
val (usingFormal, paramType) =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why usingFormal is needed here. Why isknownFormal not the right condition?

if !isWildcardParam && knownFormal then (true, formal)
else
val fromTarget = inferredFromTarget(param, formal, calleeType, paramIndex, isWildcardParam)
if fromTarget.exists then (false, fromTarget)
else if knownFormal then (true, formal)
else (false, errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
val paramTpt = untpd.TypedSplice(
(if knownFormal then InferredTypeTree() else untpd.TypeTree())
(if usingFormal then InferredTypeTree() else untpd.TypeTree())
.withType(paramType.translateFromRepeated(toArray = false))
.withSpan(param.span.endPos)
)
Expand Down Expand Up @@ -1577,7 +1594,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt)
}
else {
val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
val (protoFormals, _, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
val checkMode =
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
else desugar.MatchCheck.Exhaustive
Expand Down
31 changes: 31 additions & 0 deletions tests/run/16405.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import scala.compiletime.summonInline

case class TypeDesc[T](tpe: String)
object TypeDesc {
given nothing: TypeDesc[Nothing] = TypeDesc("Nothing")
given string: TypeDesc[String] = TypeDesc("String")
given int: TypeDesc[Int] = TypeDesc("Int")
}

def exampleFn(s: String, i: Int): Unit = ()

inline def argumentTypesOf[R](fun: (_, _) => R): (TypeDesc[?], TypeDesc[?]) = {
inline fun match {
case x: ((a, b) => R) =>
(scala.compiletime.summonInline[TypeDesc[a]], scala.compiletime.summonInline[TypeDesc[b]])
}
}
inline def argumentTypesOfNoWildCard[A, B, R](fun: (A, B) => R): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
inline def argumentTypesOfAllWildCard(fun: (?, ?) => ?): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)

object Test {
def main(args: Array[String]): Unit = {
val expected = (TypeDesc.string, TypeDesc.int)
assert(argumentTypesOf(exampleFn) == expected)
assert(argumentTypesOf(exampleFn(_, _)) == expected)
assert(argumentTypesOfNoWildCard(exampleFn) == expected)
assert(argumentTypesOfNoWildCard(exampleFn(_, _)) == expected)
assert(argumentTypesOfAllWildCard(exampleFn) == expected)
assert(argumentTypesOfAllWildCard(exampleFn(_, _)) == expected)
}
}