From 4523a6f4c7acc2811da3643d95b27768c08635b5 Mon Sep 17 00:00:00 2001 From: Jan Chyb Date: Thu, 5 Jan 2023 13:23:55 +0100 Subject: [PATCH] Fix #16405 - wildcards prematurely resolving to Nothing This was a problem because it could it get in the way of some metaprogramming techniques. The main issue was the fact that when typing functions, the type inference would first look at the types from the source method (resolving type wildcards to Nothing) and only after that, it could look at the target method. Now, in the case of wildcards we save that fact for later (while still resolving the prototype parameter to Nothing) and we in that case we prioritize according to the target method, after which we fallback to the default procedure. --- .../src/dotty/tools/dotc/typer/Typer.scala | 47 +++++++++++++------ tests/run/16405.scala | 31 ++++++++++++ 2 files changed, 63 insertions(+), 15 deletions(-) create mode 100644 tests/run/16405.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 79a4e9afbc7f..44ae2ac7a46c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1197,7 +1197,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer ) end typedIf - /** Decompose function prototype into a list of parameter prototypes and a result prototype + /** Decompose function prototype into a list of parameter prototypes, an optional list + * describing whether the parameter prototypes come from WildcardTypes, and a result prototype * tree, using WildcardTypes where a type is not known. * For the result type we do this even if the expected type is not fully * defined, which is a bit of a hack. But it's needed to make the following work @@ -1206,7 +1207,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer * def double(x: Char): String = s"$x$x" * "abc" flatMap double */ - private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = { + private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], Option[List[Boolean]], untpd.Tree) = { def typeTree(tp: Type) = tp match { case _: WildcardType => new untpd.InferredTypeTree() case _ => untpd.InferredTypeTree(tp) @@ -1234,18 +1235,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // if expected parameter type(s) are wildcards, approximate from below. // if expected result type is a wildcard, approximate from above. // this can type the greatest set of admissible closures. - (pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last))) + // However, we still keep the information on whether expected parameter types were + // wildcards, in case of types inferred from target being more specific + + val fromWildcards = pt1.argInfos.init.map{ + case bounds @ TypeBounds(nt, at) if nt == defn.NothingType && at == defn.AnyType => true + case bounds => false + } + + (pt1.argTypesLo.init, Some(fromWildcards), typeTree(interpolateWildcards(pt1.argTypesHi.last))) case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => - (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) + (formals, None, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) case SAMType(mt @ MethodTpe(_, formals, restpe)) => - (formals, + (formals, None, if (mt.isResultDependent) untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))) else typeTree(restpe)) case _ => - (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) + (List.tabulate(defaultArity)(alwaysWildcardType), None, untpd.TypeTree()) } } } @@ -1267,7 +1276,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer * If both attempts fail, return `NoType`. */ def inferredFromTarget( - param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type = + param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int, isWildcardParam: Boolean)(using Context): Type = val target = calleeType.widen match case mtpe: MethodType => val pos = paramIndex(param.name) @@ -1280,7 +1289,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else NoType case _ => NoType if target.exists then formal <:< target - if isFullyDefined(formal, ForceDegree.flipBottom) then formal + if !isWildcardParam && isFullyDefined(formal, ForceDegree.flipBottom) then formal else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target else NoType @@ -1457,7 +1466,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => } - val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) + val (protoFormals, areWildcardParams, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) def protoFormal(i: Int): Type = if (protoFormals.length == params.length) protoFormals(i) @@ -1500,13 +1509,21 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if (!param.tpt.isEmpty) param else val formal = protoFormal(i) + val isWildcardParam = areWildcardParams.map(list => if i < list.length then list(i) else false).getOrElse(false) val knownFormal = isFullyDefined(formal, ForceDegree.failBottom) - val paramType = - if knownFormal then formal - else inferredFromTarget(param, formal, calleeType, paramIndex) - .orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos)) + // Since decomposeProtoFunction eagerly approximates function arguments + // from below, then in the case that the argument was also identified as + // a wildcard type we try to prioritize inferring from target, if possible. + // See issue 16405 (tests/run/16405.scala) + val (usingFormal, paramType) = + if !isWildcardParam && knownFormal then (true, formal) + else + val fromTarget = inferredFromTarget(param, formal, calleeType, paramIndex, isWildcardParam) + if fromTarget.exists then (false, fromTarget) + else if knownFormal then (true, formal) + else (false, errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos)) val paramTpt = untpd.TypedSplice( - (if knownFormal then InferredTypeTree() else untpd.TypeTree()) + (if usingFormal then InferredTypeTree() else untpd.TypeTree()) .withType(paramType.translateFromRepeated(toArray = false)) .withSpan(param.span.endPos) ) @@ -1577,7 +1594,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt) } else { - val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos) + val (protoFormals, _, _) = decomposeProtoFunction(pt, 1, tree.srcPos) val checkMode = if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None else desugar.MatchCheck.Exhaustive diff --git a/tests/run/16405.scala b/tests/run/16405.scala new file mode 100644 index 000000000000..fa0681683c42 --- /dev/null +++ b/tests/run/16405.scala @@ -0,0 +1,31 @@ +import scala.compiletime.summonInline + +case class TypeDesc[T](tpe: String) +object TypeDesc { + given nothing: TypeDesc[Nothing] = TypeDesc("Nothing") + given string: TypeDesc[String] = TypeDesc("String") + given int: TypeDesc[Int] = TypeDesc("Int") +} + +def exampleFn(s: String, i: Int): Unit = () + +inline def argumentTypesOf[R](fun: (_, _) => R): (TypeDesc[?], TypeDesc[?]) = { + inline fun match { + case x: ((a, b) => R) => + (scala.compiletime.summonInline[TypeDesc[a]], scala.compiletime.summonInline[TypeDesc[b]]) + } +} +inline def argumentTypesOfNoWildCard[A, B, R](fun: (A, B) => R): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun) +inline def argumentTypesOfAllWildCard(fun: (?, ?) => ?): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun) + +object Test { + def main(args: Array[String]): Unit = { + val expected = (TypeDesc.string, TypeDesc.int) + assert(argumentTypesOf(exampleFn) == expected) + assert(argumentTypesOf(exampleFn(_, _)) == expected) + assert(argumentTypesOfNoWildCard(exampleFn) == expected) + assert(argumentTypesOfNoWildCard(exampleFn(_, _)) == expected) + assert(argumentTypesOfAllWildCard(exampleFn) == expected) + assert(argumentTypesOfAllWildCard(exampleFn(_, _)) == expected) + } +} \ No newline at end of file