diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 79a4e9afbc7f..44ae2ac7a46c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1197,7 +1197,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer ) end typedIf - /** Decompose function prototype into a list of parameter prototypes and a result prototype + /** Decompose function prototype into a list of parameter prototypes, an optional list + * describing whether the parameter prototypes come from WildcardTypes, and a result prototype * tree, using WildcardTypes where a type is not known. * For the result type we do this even if the expected type is not fully * defined, which is a bit of a hack. But it's needed to make the following work @@ -1206,7 +1207,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer * def double(x: Char): String = s"$x$x" * "abc" flatMap double */ - private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = { + private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], Option[List[Boolean]], untpd.Tree) = { def typeTree(tp: Type) = tp match { case _: WildcardType => new untpd.InferredTypeTree() case _ => untpd.InferredTypeTree(tp) @@ -1234,18 +1235,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // if expected parameter type(s) are wildcards, approximate from below. // if expected result type is a wildcard, approximate from above. // this can type the greatest set of admissible closures. - (pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last))) + // However, we still keep the information on whether expected parameter types were + // wildcards, in case of types inferred from target being more specific + + val fromWildcards = pt1.argInfos.init.map{ + case bounds @ TypeBounds(nt, at) if nt == defn.NothingType && at == defn.AnyType => true + case bounds => false + } + + (pt1.argTypesLo.init, Some(fromWildcards), typeTree(interpolateWildcards(pt1.argTypesHi.last))) case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => - (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) + (formals, None, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) case SAMType(mt @ MethodTpe(_, formals, restpe)) => - (formals, + (formals, None, if (mt.isResultDependent) untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))) else typeTree(restpe)) case _ => - (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) + (List.tabulate(defaultArity)(alwaysWildcardType), None, untpd.TypeTree()) } } } @@ -1267,7 +1276,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer * If both attempts fail, return `NoType`. */ def inferredFromTarget( - param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type = + param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int, isWildcardParam: Boolean)(using Context): Type = val target = calleeType.widen match case mtpe: MethodType => val pos = paramIndex(param.name) @@ -1280,7 +1289,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else NoType case _ => NoType if target.exists then formal <:< target - if isFullyDefined(formal, ForceDegree.flipBottom) then formal + if !isWildcardParam && isFullyDefined(formal, ForceDegree.flipBottom) then formal else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target else NoType @@ -1457,7 +1466,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => } - val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) + val (protoFormals, areWildcardParams, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) def protoFormal(i: Int): Type = if (protoFormals.length == params.length) protoFormals(i) @@ -1500,13 +1509,21 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if (!param.tpt.isEmpty) param else val formal = protoFormal(i) + val isWildcardParam = areWildcardParams.map(list => if i < list.length then list(i) else false).getOrElse(false) val knownFormal = isFullyDefined(formal, ForceDegree.failBottom) - val paramType = - if knownFormal then formal - else inferredFromTarget(param, formal, calleeType, paramIndex) - .orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos)) + // Since decomposeProtoFunction eagerly approximates function arguments + // from below, then in the case that the argument was also identified as + // a wildcard type we try to prioritize inferring from target, if possible. + // See issue 16405 (tests/run/16405.scala) + val (usingFormal, paramType) = + if !isWildcardParam && knownFormal then (true, formal) + else + val fromTarget = inferredFromTarget(param, formal, calleeType, paramIndex, isWildcardParam) + if fromTarget.exists then (false, fromTarget) + else if knownFormal then (true, formal) + else (false, errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos)) val paramTpt = untpd.TypedSplice( - (if knownFormal then InferredTypeTree() else untpd.TypeTree()) + (if usingFormal then InferredTypeTree() else untpd.TypeTree()) .withType(paramType.translateFromRepeated(toArray = false)) .withSpan(param.span.endPos) ) @@ -1577,7 +1594,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt) } else { - val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos) + val (protoFormals, _, _) = decomposeProtoFunction(pt, 1, tree.srcPos) val checkMode = if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None else desugar.MatchCheck.Exhaustive diff --git a/tests/run/16405.scala b/tests/run/16405.scala new file mode 100644 index 000000000000..fa0681683c42 --- /dev/null +++ b/tests/run/16405.scala @@ -0,0 +1,31 @@ +import scala.compiletime.summonInline + +case class TypeDesc[T](tpe: String) +object TypeDesc { + given nothing: TypeDesc[Nothing] = TypeDesc("Nothing") + given string: TypeDesc[String] = TypeDesc("String") + given int: TypeDesc[Int] = TypeDesc("Int") +} + +def exampleFn(s: String, i: Int): Unit = () + +inline def argumentTypesOf[R](fun: (_, _) => R): (TypeDesc[?], TypeDesc[?]) = { + inline fun match { + case x: ((a, b) => R) => + (scala.compiletime.summonInline[TypeDesc[a]], scala.compiletime.summonInline[TypeDesc[b]]) + } +} +inline def argumentTypesOfNoWildCard[A, B, R](fun: (A, B) => R): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun) +inline def argumentTypesOfAllWildCard(fun: (?, ?) => ?): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun) + +object Test { + def main(args: Array[String]): Unit = { + val expected = (TypeDesc.string, TypeDesc.int) + assert(argumentTypesOf(exampleFn) == expected) + assert(argumentTypesOf(exampleFn(_, _)) == expected) + assert(argumentTypesOfNoWildCard(exampleFn) == expected) + assert(argumentTypesOfNoWildCard(exampleFn(_, _)) == expected) + assert(argumentTypesOfAllWildCard(exampleFn) == expected) + assert(argumentTypesOfAllWildCard(exampleFn(_, _)) == expected) + } +} \ No newline at end of file