Skip to content

Commit 17fb4f9

Browse files
committed
[Prototype] Better type inference for lambdas (e.g., as used in folds)
No version of Scala has ever been able to infer the following: val xs = List(1, 2, 3) xs.foldLeft(Nil)((acc, x) => x :: acc) To understand why, let's have a look at the signature of `List[A]#foldLeft`: def foldLeft[B](z: B)(op: (B, A) => B): B When typing the foldLeft call in the previous expression, the compiler starts by creating an unconstrained type variable ?B, the challenge is then to successfully type the expression and instantiate `?B := List[Int]`. Typing the first argument is easy: `Nil` is a valid argument if we add a constraint: ?B >: Nil.type Typing the second argument is where we get stuck normally: we need to choose a type for the binding `acc`, but `?B` is a type variable and not a fully-defined type, this is solved by instantiating `?B` to one of its bound, but no matter what bound we choose, the rest of the expression won't typecheck: - if we instantiate `?B := Nil.type`, then the body of the lambda `x :: acc` is not a subtype of the expected result type `?B`. - if we instantiate `?B := Any`, then the body of the lambda does not typecheck since there is no method `::` on `Any`. But... what if we just let `acc` have type `?B` without instantiating it first? This is not completely meaningless: `?B` behaves like an abstract type except that its bounds might be refined as we typecheck code, as long as narrowing holds (#), this should be safe. The remaining challenge then is to type the body of the lambda `x :: acc` which desugars to `acc.::(x)`, this won't typecheck as-is since `::` is not defined on the upper bound of `?B`, so we need to refine this upper bound somehow, the heuristic we use is: 1) Look for `::` in the lower bound of `?B >: Nil.type`, Nil does have such a member! 2) Find the class where this member is defined: it's `List` 3) If the class has type parameters, create one fresh type variable for each parameter slot, the resulting type is our new upper bound, so here we get `?B <: List[?X]` where `?X` is a fresh type variable. We can then proceed to type the body of the lambda: acc.::(x) This first creates a type variable `?B2 >: ?X`, because `::` has type: def :: [B >: A](elem: B): List[B] Because the result type of the lambda is `?B`, we get an additional constraint: List[?B2] <: ?B We know that `?B <: List[?X]` so this means that `?B2 <: ?X`, but we also know that `B2 >: ?X`, so we can instantiate `?B2 := ?X` and `?B := List[?X]`. Finally, because `x` has type Int we have `?B2 >: Int` which simplifies to: ?X >: Int Therefore, the result type of the foldLeft is `List[?X]` where `?X >: Int`, because `List` is covariant, we instantiate `?X := Int` to get the most precise result type `List[Int]`. Note that the the use of fresh type variables in 3) was crucial here: if we had instead used wildcards and added an upper bound `?B <: List[_]`, then we would have been able to type `acc.::(x)`, but the result would have type `List[Any]`, meaning the result of the foldLeft call would be `List[Any]` when we wanted `List[Int]`. - Is this actually sound? - Are there other compelling examples where this useful, besides folds? - Is the performance impact of this stuff acceptable? - How do we deal with overloads? - How do we deal with overrides? - How does this interact with implicit conversions? - How does this interact with implicit search in general, we might find one implicit at a given point, but then as we add more constraints to the same type variable, the same implicit search could find a different result. How big of a problem is that? (#): narrowing in fact does not hold when `@uncheckedVariance` is used, which is why we special-case it in `typedSelect` in this commit.
1 parent 0121d82 commit 17fb4f9

File tree

9 files changed

+347
-12
lines changed

9 files changed

+347
-12
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,12 @@ object desugar {
681681
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
682682
else mods
683683
}
684-
val appParamss =
684+
// FIXME: This now infers `List[List[DefTree]]`, the issue
685+
// is that `withMods` is defined in `DefTree` so that becomes the
686+
// upper bound of the type variable (see logic in `constrainSelectionQualifier`),
687+
// but the result type of `withMods` is a type member which is
688+
// refined in `ValDef`.
689+
val appParamss: List[List[ValDef]] =
685690
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
686691
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
687692
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

+5
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
310310
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =
311311

312312
def recur(tp: Type, fromBelow: Boolean): Type = tp match
313+
case tp: NamedType =>
314+
val underlying1 = recur(tp.underlying, fromBelow)
315+
if underlying1 ne tp.underlying then underlying1 else tp
313316
case tp: AndOrType =>
314317
val r1 = recur(tp.tp1, fromBelow)
315318
val r2 = recur(tp.tp2, fromBelow)
@@ -613,6 +616,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
613616
def occursAtToplevel(param: TypeParamRef, inst: Type)(implicit ctx: Context): Boolean =
614617

615618
def occurs(tp: Type)(using Context): Boolean = tp match
619+
case tp: NamedType =>
620+
occurs(tp.underlying)
616621
case tp: AndOrType =>
617622
occurs(tp.tp1) || occurs(tp.tp2)
618623
case tp: TypeParamRef =>

compiler/src/dotty/tools/dotc/core/Types.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3927,7 +3927,7 @@ object Types {
39273927
NoType
39283928
}
39293929

3930-
def tyconTypeParams(implicit ctx: Context): List[ParamInfo] = {
3930+
def tyconTypeParams(implicit ctx: Context): List[TypeApplications.TypeParamInfo] = {
39313931
val tparams = tycon.typeParams
39323932
if (tparams.isEmpty) HKTypeLambda.any(args.length).typeParams else tparams
39333933
}

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ object Inferencing {
381381
*
382382
* we want to instantiate U to x.type right away. No need to wait further.
383383
*/
384-
private def variances(tp: Type)(using Context): VarianceMap = {
384+
def variances(tp: Type)(using Context): VarianceMap = {
385385
Stats.record("variances")
386386
val constraint = ctx.typerState.constraint
387387

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+21-6
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,12 @@ object ProtoTypes {
135135
* or as an upper bound of a prefix or underlying type.
136136
*/
137137
private def hasUnknownMembers(tp: Type)(using Context): Boolean = tp match {
138-
case tp: TypeVar => !tp.isInstantiated
138+
case tp: TypeVar =>
139+
// FIXME: This used to be `!tp.isInstantiated` but that prevents
140+
// extension methods from being selected with the changes in this PR.
141+
// This change doesn't break any testcase, can we construct a testcase
142+
// where this matters?
143+
false
139144
case tp: WildcardType => true
140145
case NoType => true
141146
case tp: TypeRef =>
@@ -152,20 +157,30 @@ object ProtoTypes {
152157
case _ => false
153158
}
154159

155-
override def isMatchedBy(tp1: Type, keepConstraint: Boolean)(using Context): Boolean =
156-
name == nme.WILDCARD || hasUnknownMembers(tp1) ||
157-
{
158-
val mbr = if (privateOK) tp1.member(name) else tp1.nonPrivateMember(name)
160+
override def isMatchedBy(tp1: Type, keepConstraint: Boolean)(using Context): Boolean = {
161+
if name == nme.WILDCARD || hasUnknownMembers(tp1) then
162+
return true
163+
164+
def go(pre: Type): Boolean = {
165+
val mbr = if (privateOK) pre.member(name) else pre.nonPrivateMember(name)
159166
def qualifies(m: SingleDenotation) =
160167
memberProto.isRef(defn.UnitClass) ||
161-
tp1.isValueType && compat.normalizedCompatible(NamedType(tp1, name, m), memberProto, keepConstraint)
168+
pre.isValueType && compat.normalizedCompatible(NamedType(pre, name, m), memberProto, keepConstraint)
162169
// Note: can't use `m.info` here because if `m` is a method, `m.info`
163170
// loses knowledge about `m`'s default arguments.
164171
mbr match { // hasAltWith inlined for performance
165172
case mbr: SingleDenotation => mbr.exists && qualifies(mbr)
166173
case _ => mbr hasAltWith qualifies
167174
}
168175
}
176+
tp1.widenDealias.stripTypeVar match {
177+
case tp: TypeParamRef =>
178+
val bounds = ctx.typeComparer.bounds(tp)
179+
go(bounds.hi) || go(bounds.lo)
180+
case _ =>
181+
go(tp1)
182+
}
183+
}
169184

170185
def underlying(using Context): Type = WildcardType
171186

compiler/src/dotty/tools/dotc/typer/Typer.scala

+233-3
Original file line numberDiff line numberDiff line change
@@ -516,13 +516,220 @@ class Typer extends Namer
516516
tree
517517
}
518518

519+
/** Try to add constraints to type a selection where the qualifier is a type variable.
520+
*
521+
* Currently, this should only happen with lambdas, for example when typechecking:
522+
*
523+
* def foo[T <: List[Any]](x: T => T)
524+
* foo(x => x.head: Int)
525+
*
526+
* In the past, `typedFunctionValue` would have instantiated the type
527+
* variable corresponding to the type parameter `T` to `List[Any]` before
528+
* typing the lambda, which would then fail because `x.head` has type `Any`.
529+
* But we now leave such type variables uninstantiated, which means we need
530+
* to figure out how to type a selection where the prefix is an
531+
* uninstantiated type variable, and in particular how to propagate
532+
* constraints from typing this selection back to that type variable.
533+
*
534+
* @param qual The type of the qualifier of the selection
535+
* @param name The name of the member being selected
536+
* @param underlyingVar The uninstantiated type variable underlying the type of the qualifier
537+
* @param pt The expected type of the selection
538+
*/
539+
private def constrainSelectionQualifier(
540+
qual: Type, name: Name, underlyingVar: TypeVar, pt: Type)(using Context): Boolean = {
541+
542+
/** Return `tycon[?A, ?B, ...]` where `?A`, `?B`, ... are fresh type variables
543+
* conforming to the corresponding type parameter in `tparams`.
544+
*/
545+
def appliedWithVars(tycon: Type, tparams: List[TypeApplications.TypeParamInfo]): Type = {
546+
if (tparams.isEmpty)
547+
tycon
548+
else {
549+
val tl = tycon.EtaExpand(tparams).asInstanceOf[HKTypeLambda]
550+
val tvars = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)._2.map(_.tpe)
551+
tycon.appliedTo(tvars)
552+
}
553+
}
554+
555+
/** Replace all applied types `tycon[T, S, ...]` by `tycon[?A, ?B, ...]`
556+
* where `?A`, `?B`, ... are fresh type variables.
557+
*/
558+
def replaceArgsByVars = new TypeMap {
559+
def apply(t: Type): Type = t match {
560+
case tp: TypeLambda =>
561+
tp
562+
case tp @ AppliedType(tycon, args) =>
563+
// Note that we don't constrain the fresh type variables
564+
// such that the mapped type is a subtype of `tp`, we let
565+
// the caller deal with that.
566+
appliedWithVars(tycon, tp.tyconTypeParams)
567+
case _ =>
568+
mapOver(t)
569+
}
570+
}
571+
572+
/** Does `@uncheckedVariance` appears somewhere in the type of `d` ? */
573+
def hasUncheckedVariance(d: SingleDenotation) = d.info.widen.existsPart {
574+
case tp @ AnnotatedType(_, annot) =>
575+
annot.symbol eq defn.UncheckedVarianceAnnot
576+
case tp =>
577+
false
578+
}
579+
580+
/** The members of one of the bound of `underlyingVar` which the selection
581+
* could resolve to.
582+
*
583+
* @param isUpper If true, look for candidates in the upper bound,
584+
* otherwise look in the lower bound.
585+
*/
586+
def candidatesInBound(isUpper: Boolean): List[SingleDenotation] = {
587+
val bounds = ctx.typeComparer.bounds(underlyingVar.origin)
588+
val bound = if (isUpper) bounds.hi else bounds.lo
589+
val d = bound.member(name)
590+
d.alternatives
591+
}
592+
593+
/** Try to add additional constraints on `underlyingVar`
594+
* to allow a selection of `candidate` to typecheck.
595+
*
596+
* @param isUpper Does `candidate` come from the upper bound
597+
* of the qualifier type?
598+
*/
599+
def constrainTo(candidate: SingleDenotation, isUpper: Boolean): Boolean = {
600+
if (hasUncheckedVariance(candidate)) {
601+
// If `@uncheckedVariance` appears in the type of the candidate, give up
602+
// on delaying instantiation and just instantiate the type variable at
603+
// that point. If we don't do that, the type variable might later
604+
// be constrained in a way that prevents the selection from typechecking,
605+
// because narrowing does not hold with unchecked variance.
606+
// See tests/pos/fold-infer-uncheckedVariance.scala for an example
607+
// which did not pass `-Ytest-pickler` before.
608+
// FIXME: `-Ycheck:all` did not pick up on this issue in
609+
// fold-infer-uncheckedVariance.scala because the ReTyper never retypes
610+
// selections, I think it's important to get TreeChecker to actually
611+
// verify this stuff, especially with everything going on in this PR.
612+
underlyingVar.instantiate(fromBelow = !isUpper)
613+
return true
614+
}
615+
616+
val owner = candidate.symbol.maybeOwner
617+
// TODO: Deal with methods in structural types?
618+
if (!owner.exists || !owner.isClass)
619+
return false
620+
621+
if (isUpper) {
622+
// The candidate comes from the upper bound of the qualifier.
623+
// In that case we replace type arguments in the upper bound
624+
// by fresh type variables to make it more flexible.
625+
//
626+
// For example, we might have `qual: ?T` where `?T <: List[AnyVal]`.
627+
// in which case `qual.head` will have type `qual.A` where `A`
628+
// is an abstract type >: Nothing <: AnyVal.
629+
// Therefore, when typechecking `qual.head: Int`, we get:
630+
//
631+
// qual.A <:< Int
632+
// AnyVal <:< Int
633+
// false
634+
//
635+
// The problem is that subtype checks on `qual.A` do
636+
// not allow us to constraint `?T` further.
637+
//
638+
// To fix this, we need a more precise upper-bound for `?T`:
639+
// we can safely rewrite the constraint:
640+
//
641+
// ?T <: List[AnyVal]
642+
//
643+
// as:
644+
//
645+
// ?T <: List[?X]
646+
// ?X <: AnyVal
647+
//
648+
// Now, if we try to typecheck `qual.head: Int`, we get:
649+
//
650+
// qual.A <:< Int
651+
// ?X <:< Int
652+
// true, with extra constraint `?X <: Int`
653+
//
654+
// And at a later point, `?T` will be instantiated to a
655+
// subtype of `List[Int]` as expected.
656+
657+
val base = qual.baseType(owner)
658+
// FIXME: this is wasteful: if we have multiple selections with the
659+
// same qualifier, we'll create fresh type variables every time.
660+
val newUpperBound = replaceArgsByVars(base)
661+
662+
if newUpperBound ne base then
663+
underlyingVar <:< newUpperBound
664+
} else {
665+
// The candidate comes from the lower bound of the qualifier.
666+
// In that case, we need to constrain the upper bound of the
667+
// qualifier to be able to typecheck the selection at all,
668+
// and like in the isUpper case, we want type variables in
669+
// the arguments of that upper bound for flexibility.
670+
//
671+
// For example, if we have `qual: ?T` where `?T >: Nil`, then
672+
// `qual.::` will fail as there is no member named `::`
673+
// defined on `Any`, so we need to further constrain the upper
674+
// bound. We know that `::` is defined on `List`, so we can add
675+
// a constraint:
676+
//
677+
// ?T <: List[?X]
678+
// ?X
679+
//
680+
// (in this example, the fresh type variable `?X` can stay
681+
// unconstrained since `Nil <:< List[?X]` is true for all `?X`)
682+
683+
// FIXME: better handling of overrides: `candidate` might be an
684+
// override of some member defined in a parent class, in which
685+
// case we're overconstraining the upper bound.
686+
val newUpperBound = appliedWithVars(owner.typeRef, owner.typeParams)
687+
underlyingVar <:< newUpperBound
688+
}
689+
690+
// FIXME: it would be nice if we could use the expected type
691+
// to filter out some candidates, but it's hard to rule out
692+
// anything since some implicit conversion might kick in
693+
// during adaptation.
694+
true
695+
}
696+
697+
/** Try to add additional constraints on `underlyingVar`
698+
* to allow a selection based on the candidates found
699+
* in one of its own bound.
700+
*
701+
* @param isUpper If true, look for candidates in the upper bound,
702+
* otherwise look in the lower bound.
703+
*/
704+
def constrainInBound(isUpper: Boolean): Boolean = {
705+
// FIXME: we just stop after finding a matching candidate, should we
706+
// take the union of the constraints they add instead?
707+
candidatesInBound(isUpper).exists(constrainTo(_, isUpper))
708+
}
709+
710+
// FIXME: We currently only look at the lower bound if we don't find a
711+
// matching member in the upper bound, but that could exclude
712+
// the right candidate.
713+
constrainInBound(isUpper = true) || constrainInBound(isUpper = false)
714+
}
715+
519716
def typedSelect(tree: untpd.Select, pt: Type, qual: Tree)(using Context): Tree = qual match {
520717
case qual @ IntegratedTypeArgs(app) =>
521718
pt.revealIgnored match {
522719
case _: PolyProto => qual // keep the IntegratedTypeArgs to strip at next typedTypeApply
523720
case _ => app
524721
}
525722
case qual =>
723+
qual.tpe.widenDealias.stripTypeVar match {
724+
case tp: TypeParamRef =>
725+
ctx.typerState.constraint.typeVarOfParam(tp) match {
726+
case tvar: TypeVar =>
727+
constrainSelectionQualifier(qual.tpe, tree.name, tvar, pt)
728+
case _ =>
729+
}
730+
case _ =>
731+
}
732+
526733
val select = assignType(cpy.Select(tree)(qual, tree.name), qual)
527734
val select1 = toNotNullTermRef(select, pt)
528735

@@ -1096,6 +1303,21 @@ class Typer extends Namer
10961303
case _ =>
10971304
}
10981305

1306+
// The set of type variables in the prototype which appear only in covariant or
1307+
// contravariant positions. These should be instantiatable without
1308+
// preventing the body of the lambda from typechecking (...except in situations
1309+
// like `def foo[T, U <: T](x: T => U)`, where instantiating `T` to a specific
1310+
// type might overconstrain `U`).
1311+
//
1312+
// This doesn't exclude type variables which appear with a different
1313+
// variance at a later point in the same method call, or a subsequent chained
1314+
// call.
1315+
//
1316+
// TODO: try to replace this by an empty list and see how that affects
1317+
// inference and performance (we would end up creating a lot more type
1318+
// variables in `typedSelect`).
1319+
lazy val protoVariantVars = variances(pt).toList.filter(_._2 != 0).map(_._1)
1320+
10991321
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
11001322

11011323
/** The inferred parameter type for a parameter in a lambda that does
@@ -1118,7 +1340,7 @@ class Typer extends Namer
11181340
* If all attempts fail, issue a "missing parameter type" error.
11191341
*/
11201342
def inferredParamType(param: untpd.ValDef, formal: Type): Type =
1121-
if isFullyDefined(formal, ForceDegree.failBottom) then return formal
1343+
if isFullyDefined(formal, ForceDegree.none) then return formal
11221344
val target = calleeType.widen match
11231345
case mtpe: MethodType =>
11241346
val pos = paramIndex(param.name)
@@ -1128,9 +1350,17 @@ class Typer extends Namer
11281350
else NoType
11291351
case _ => NoType
11301352
if target.exists then formal <:< target
1131-
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1353+
// if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1354+
if isFullyDefined(formal, ForceDegree.none) then formal
11321355
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
1133-
else errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
1356+
else if !formal.isInstanceOf[WildcardType] then
1357+
instantiateSelected(formal, protoVariantVars)
1358+
// Intentionally leave uninstantiated type variables in the types of parameters,
1359+
// this works because `typedSelect` special cases the handling of qualifiers
1360+
// whose type is a type variable.
1361+
formal
1362+
else
1363+
errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
11341364

11351365
def protoFormal(i: Int): Type =
11361366
if (protoFormals.length == params.length) protoFormals(i)

0 commit comments

Comments
 (0)