Skip to content

Commit e15c580

Browse files
committed
[Prototype] Better type inference for lambdas (e.g., as used in folds)
No version of Scala has ever been able to infer the following: val xs = List(1, 2, 3) xs.foldLeft(Nil)((acc, x) => x :: acc) To understand why, let's have a look at the signature of `List[A]#foldLeft`: def foldLeft[B](z: B)(op: (B, A) => B): B When typing the foldLeft call in the previous expression, the compiler starts by creating an unconstrained type variable ?B, the challenge is then to successfully type the expression and instantiate `?B := List[Int]`. Typing the first argument is easy: `Nil` is a valid argument if we add a constraint: ?B >: Nil.type Typing the second argument is where we get stuck normally: we need to choose a type for the binding `acc`, but `?B` is a type variable and not a fully-defined type, this is solved by instantiating `?B` to one of its bound, but no matter what bound we choose, the rest of the expression won't typecheck: - if we instantiate `?B := Nil.type`, then the body of the lambda `x :: acc` is not a subtype of the expected result type `?B`. - if we instantiate `?B := Any`, then the body of the lambda does not typecheck since there is no method `::` on `Any`. But... what if we just let `acc` have type `?B` without instantiating it first? This is not completely meaningless: `?B` behaves like an abstract type except that its bounds might be refined as we typecheck code, as long as narrowing holds (#), this should be safe. The remaining challenge then is to type the body of the lambda `x :: acc` which desugars to `acc.::(x)`, this won't typecheck as-is since `::` is not defined on the upper bound of `?B`, so we need to refine this upper bound somehow, the heuristic we use is: 1) Look for `::` in the lower bound of `?B >: Nil.type`, Nil does have such a member! 2) Find the class where this member is defined: it's `List` 3) If the class has type parameters, create one fresh type variable for each parameter slot, the resulting type is our new upper bound, so here we get `?B <: List[?X]` where `?X` is a fresh type variable. We can then proceed to type the body of the lambda: acc.::(x) This first creates a type variable `?B2 >: ?X`, because `::` has type: def :: [B >: A](elem: B): List[B] Because the result type of the lambda is `?B`, we get an additional constraint: List[?B2] <: ?B We know that `?B <: List[?X]` so this means that `?B2 <: ?X`, but we also know that `B2 >: ?X`, so we can instantiate `?B2 := ?X` and `?B := List[?X]`. Finally, because `x` has type Int we have `?B2 >: Int` which simplifies to: ?X >: Int Therefore, the result type of the foldLeft is `List[?X]` where `?X >: Int`, because `List` is covariant, we instantiate `?X := Int` to get the most precise result type `List[Int]`. Note that the the use of fresh type variables in 3) was crucial here: if we had instead used wildcards and added an upper bound `?B <: List[_]`, then we would have been able to type `acc.::(x)`, but the result would have type `List[Any]`, meaning the result of the foldLeft call would be `List[Any]` when we wanted `List[Int]`. \# Status All the compiler tests pass, including bootstrapping, but one of third of the community build breaks currently. Even if this PR never makes it in, it has been very useful for stress-testing our constraint solver and lead to several PRs I made over the past few days: of this PR that would be worth getting in by themselves. \# Open questions - Is this actually sound? - Are there other compelling examples where this useful, besides folds? - Is the performance impact of this stuff acceptable? - How do we deal with overloads? - How do we deal with overrides? - How does this interact with implicit conversions? - How does this interact with implicit search in general, we might find one implicit at a given point, but then as we add more constraints to the same type variable, the same implicit search could find a different result. How big of a problem is that? (#): narrowing in fact does not hold when `@uncheckedVariance` is used, which is why we special-case it in `typedSelect` in this commit.
1 parent 0121d82 commit e15c580

File tree

9 files changed

+347
-12
lines changed

9 files changed

+347
-12
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,12 @@ object desugar {
681681
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
682682
else mods
683683
}
684-
val appParamss =
684+
// FIXME: This now infers `List[List[DefTree]]`, the issue
685+
// is that `withMods` is defined in `DefTree` so that becomes the
686+
// upper bound of the type variable (see logic in `constrainSelectionQualifier`),
687+
// but the result type of `withMods` is a type member which is
688+
// refined in `ValDef`.
689+
val appParamss: List[List[ValDef]] =
685690
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
686691
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
687692
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

+5
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
310310
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =
311311

312312
def recur(tp: Type, fromBelow: Boolean): Type = tp match
313+
case tp: NamedType =>
314+
val underlying1 = recur(tp.underlying, fromBelow)
315+
if underlying1 ne tp.underlying then underlying1 else tp
313316
case tp: AndOrType =>
314317
val r1 = recur(tp.tp1, fromBelow)
315318
val r2 = recur(tp.tp2, fromBelow)
@@ -613,6 +616,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
613616
def occursAtToplevel(param: TypeParamRef, inst: Type)(implicit ctx: Context): Boolean =
614617

615618
def occurs(tp: Type)(using Context): Boolean = tp match
619+
case tp: NamedType =>
620+
occurs(tp.underlying)
616621
case tp: AndOrType =>
617622
occurs(tp.tp1) || occurs(tp.tp2)
618623
case tp: TypeParamRef =>

compiler/src/dotty/tools/dotc/core/Types.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3927,7 +3927,7 @@ object Types {
39273927
NoType
39283928
}
39293929

3930-
def tyconTypeParams(implicit ctx: Context): List[ParamInfo] = {
3930+
def tyconTypeParams(implicit ctx: Context): List[TypeApplications.TypeParamInfo] = {
39313931
val tparams = tycon.typeParams
39323932
if (tparams.isEmpty) HKTypeLambda.any(args.length).typeParams else tparams
39333933
}

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ object Inferencing {
381381
*
382382
* we want to instantiate U to x.type right away. No need to wait further.
383383
*/
384-
private def variances(tp: Type)(using Context): VarianceMap = {
384+
def variances(tp: Type)(using Context): VarianceMap = {
385385
Stats.record("variances")
386386
val constraint = ctx.typerState.constraint
387387

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+21-6
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,12 @@ object ProtoTypes {
135135
* or as an upper bound of a prefix or underlying type.
136136
*/
137137
private def hasUnknownMembers(tp: Type)(using Context): Boolean = tp match {
138-
case tp: TypeVar => !tp.isInstantiated
138+
case tp: TypeVar =>
139+
// FIXME: This used to be `!tp.isInstantiated` but that prevents
140+
// extension methods from being selected with the changes in this PR.
141+
// This change doesn't break any testcase, can we construct a testcase
142+
// where this matters?
143+
false
139144
case tp: WildcardType => true
140145
case NoType => true
141146
case tp: TypeRef =>
@@ -152,20 +157,30 @@ object ProtoTypes {
152157
case _ => false
153158
}
154159

155-
override def isMatchedBy(tp1: Type, keepConstraint: Boolean)(using Context): Boolean =
156-
name == nme.WILDCARD || hasUnknownMembers(tp1) ||
157-
{
158-
val mbr = if (privateOK) tp1.member(name) else tp1.nonPrivateMember(name)
160+
override def isMatchedBy(tp1: Type, keepConstraint: Boolean)(using Context): Boolean = {
161+
if name == nme.WILDCARD || hasUnknownMembers(tp1) then
162+
return true
163+
164+
def go(pre: Type): Boolean = {
165+
val mbr = if (privateOK) pre.member(name) else pre.nonPrivateMember(name)
159166
def qualifies(m: SingleDenotation) =
160167
memberProto.isRef(defn.UnitClass) ||
161-
tp1.isValueType && compat.normalizedCompatible(NamedType(tp1, name, m), memberProto, keepConstraint)
168+
pre.isValueType && compat.normalizedCompatible(NamedType(pre, name, m), memberProto, keepConstraint)
162169
// Note: can't use `m.info` here because if `m` is a method, `m.info`
163170
// loses knowledge about `m`'s default arguments.
164171
mbr match { // hasAltWith inlined for performance
165172
case mbr: SingleDenotation => mbr.exists && qualifies(mbr)
166173
case _ => mbr hasAltWith qualifies
167174
}
168175
}
176+
tp1.widenDealias.stripTypeVar match {
177+
case tp: TypeParamRef =>
178+
val bounds = ctx.typeComparer.bounds(tp)
179+
go(bounds.hi) || go(bounds.lo)
180+
case _ =>
181+
go(tp1)
182+
}
183+
}
169184

170185
def underlying(using Context): Type = WildcardType
171186

compiler/src/dotty/tools/dotc/typer/Typer.scala

+233-3
Original file line numberDiff line numberDiff line change
@@ -516,13 +516,220 @@ class Typer extends Namer
516516
tree
517517
}
518518

519+
/** Try to add constraints to type a selection where the qualifier is a type variable.
520+
*
521+
* Currently, this should only happen with lambdas, for example when typechecking:
522+
*
523+
* def foo[T <: List[Any]](x: T => T)
524+
* foo(x => x.head: Int)
525+
*
526+
* In the past, `typedFunctionValue` would have instantiated the type
527+
* variable corresponding to the type parameter `T` to `List[Any]` before
528+
* typing the lambda, which would then fail because `x.head` has type `Any`.
529+
* But we now leave such type variables uninstantiated, which means we need
530+
* to figure out how to type a selection where the prefix is an
531+
* uninstantiated type variable, and in particular how to propagate
532+
* constraints from typing this selection back to that type variable.
533+
*
534+
* @param qual The type of the qualifier of the selection
535+
* @param name The name of the member being selected
536+
* @param underlyingVar The uninstantiated type variable underlying the type of the qualifier
537+
* @param pt The expected type of the selection
538+
*/
539+
private def constrainSelectionQualifier(
540+
qual: Type, name: Name, underlyingVar: TypeVar, pt: Type)(using Context): Boolean = {
541+
542+
/** Return `tycon[?A, ?B, ...]` where `?A`, `?B`, ... are fresh type variables
543+
* conforming to the corresponding type parameter in `tparams`.
544+
*/
545+
def appliedWithVars(tycon: Type, tparams: List[TypeApplications.TypeParamInfo]): Type = {
546+
if (tparams.isEmpty)
547+
tycon
548+
else {
549+
val tl = tycon.EtaExpand(tparams).asInstanceOf[HKTypeLambda]
550+
val tvars = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)._2.map(_.tpe)
551+
tycon.appliedTo(tvars)
552+
}
553+
}
554+
555+
/** Replace all applied types `tycon[T, S, ...]` by `tycon[?A, ?B, ...]`
556+
* where `?A`, `?B`, ... are fresh type variables.
557+
*/
558+
def replaceArgsByVars = new TypeMap {
559+
def apply(t: Type): Type = t match {
560+
case tp: TypeLambda =>
561+
tp
562+
case tp @ AppliedType(tycon, args) =>
563+
// Note that we don't constrain the fresh type variables
564+
// such that the mapped type is a subtype of `tp`, we let
565+
// the caller deal with that.
566+
appliedWithVars(tycon, tp.tyconTypeParams)
567+
case _ =>
568+
mapOver(t)
569+
}
570+
}
571+
572+
/** Does `@uncheckedVariance` appears somewhere in the type of `d` ? */
573+
def hasUncheckedVariance(d: SingleDenotation) = d.info.widen.existsPart {
574+
case tp @ AnnotatedType(_, annot) =>
575+
annot.symbol eq defn.UncheckedVarianceAnnot
576+
case tp =>
577+
false
578+
}
579+
580+
/** The members of one of the bound of `underlyingVar` which the selection
581+
* could resolve to.
582+
*
583+
* @param isUpper If true, look for candidates in the upper bound,
584+
* otherwise look in the lower bound.
585+
*/
586+
def candidatesInBound(isUpper: Boolean): List[SingleDenotation] = {
587+
val bounds = ctx.typeComparer.bounds(underlyingVar.origin)
588+
val bound = if (isUpper) bounds.hi else bounds.lo
589+
val d = bound.member(name)
590+
d.alternatives
591+
}
592+
593+
/** Try to add additional constraints on `underlyingVar`
594+
* to allow a selection of `candidate` to typecheck.
595+
*
596+
* @param isUpper Does `candidate` come from the upper bound
597+
* of the qualifier type?
598+
*/
599+
def constrainTo(candidate: SingleDenotation, isUpper: Boolean): Boolean = {
600+
if (hasUncheckedVariance(candidate)) {
601+
// If `@uncheckedVariance` appears in the type of the candidate, give up
602+
// on delaying instantiation and just instantiate the type variable at
603+
// that point. If we don't do that, the type variable might later
604+
// be constrained in a way that prevents the selection from typechecking,
605+
// because narrowing does not hold with unchecked variance.
606+
// See tests/pos/fold-infer-uncheckedVariance.scala for an example
607+
// which did not pass `-Ytest-pickler` before.
608+
// FIXME: `-Ycheck:all` did not pick up on this issue in
609+
// fold-infer-uncheckedVariance.scala because the ReTyper never retypes
610+
// selections, I think it's important to get TreeChecker to actually
611+
// verify this stuff, especially with everything going on in this PR.
612+
underlyingVar.instantiate(fromBelow = !isUpper)
613+
return true
614+
}
615+
616+
val owner = candidate.symbol.maybeOwner
617+
// TODO: Deal with methods in structural types?
618+
if (!owner.exists || !owner.isClass)
619+
return false
620+
621+
if (isUpper) {
622+
// The candidate comes from the upper bound of the qualifier.
623+
// In that case we replace type arguments in the upper bound
624+
// by fresh type variables to make it more flexible.
625+
//
626+
// For example, we might have `qual: ?T` where `?T <: List[AnyVal]`.
627+
// in which case `qual.head` will have type `qual.A` where `A`
628+
// is an abstract type >: Nothing <: AnyVal.
629+
// Therefore, when typechecking `qual.head: Int`, we get:
630+
//
631+
// qual.A <:< Int
632+
// AnyVal <:< Int
633+
// false
634+
//
635+
// The problem is that subtype checks on `qual.A` do
636+
// not allow us to constraint `?T` further.
637+
//
638+
// To fix this, we need a more precise upper-bound for `?T`:
639+
// we can safely rewrite the constraint:
640+
//
641+
// ?T <: List[AnyVal]
642+
//
643+
// as:
644+
//
645+
// ?T <: List[?X]
646+
// ?X <: AnyVal
647+
//
648+
// Now, if we try to typecheck `qual.head: Int`, we get:
649+
//
650+
// qual.A <:< Int
651+
// ?X <:< Int
652+
// true, with extra constraint `?X <: Int`
653+
//
654+
// And at a later point, `?T` will be instantiated to a
655+
// subtype of `List[Int]` as expected.
656+
657+
val base = qual.baseType(owner)
658+
// FIXME: this is wasteful: if we have multiple selections with the
659+
// same qualifier, we'll create fresh type variables every time.
660+
val newUpperBound = replaceArgsByVars(base)
661+
662+
if newUpperBound ne base then
663+
underlyingVar <:< newUpperBound
664+
} else {
665+
// The candidate comes from the lower bound of the qualifier.
666+
// In that case, we need to constrain the upper bound of the
667+
// qualifier to be able to typecheck the selection at all,
668+
// and like in the isUpper case, we want type variables in
669+
// the arguments of that upper bound for flexibility.
670+
//
671+
// For example, if we have `qual: ?T` where `?T >: Nil`, then
672+
// `qual.::` will fail as there is no member named `::`
673+
// defined on `Any`, so we need to further constrain the upper
674+
// bound. We know that `::` is defined on `List`, so we can add
675+
// a constraint:
676+
//
677+
// ?T <: List[?X]
678+
// ?X
679+
//
680+
// (in this example, the fresh type variable `?X` can stay
681+
// unconstrained since `Nil <:< List[?X]` is true for all `?X`)
682+
683+
// FIXME: better handling of overrides: `candidate` might be an
684+
// override of some member defined in a parent class, in which
685+
// case we're overconstraining the upper bound.
686+
val newUpperBound = appliedWithVars(owner.typeRef, owner.typeParams)
687+
underlyingVar <:< newUpperBound
688+
}
689+
690+
// FIXME: it would be nice if we could use the expected type
691+
// to filter out some candidates, but it's hard to rule out
692+
// anything since some implicit conversion might kick in
693+
// during adaptation.
694+
true
695+
}
696+
697+
/** Try to add additional constraints on `underlyingVar`
698+
* to allow a selection based on the candidates found
699+
* in one of its own bound.
700+
*
701+
* @param isUpper If true, look for candidates in the upper bound,
702+
* otherwise look in the lower bound.
703+
*/
704+
def constrainInBound(isUpper: Boolean): Boolean = {
705+
// FIXME: we just stop after finding a matching candidate, should we
706+
// take the union of the constraints they add instead?
707+
candidatesInBound(isUpper).exists(constrainTo(_, isUpper))
708+
}
709+
710+
// FIXME: We currently only look at the lower bound if we don't find a
711+
// matching member in the upper bound, but that could exclude
712+
// the right candidate.
713+
constrainInBound(isUpper = true) || constrainInBound(isUpper = false)
714+
}
715+
519716
def typedSelect(tree: untpd.Select, pt: Type, qual: Tree)(using Context): Tree = qual match {
520717
case qual @ IntegratedTypeArgs(app) =>
521718
pt.revealIgnored match {
522719
case _: PolyProto => qual // keep the IntegratedTypeArgs to strip at next typedTypeApply
523720
case _ => app
524721
}
525722
case qual =>
723+
qual.tpe.widenDealias.stripTypeVar match {
724+
case tp: TypeParamRef =>
725+
ctx.typerState.constraint.typeVarOfParam(tp) match {
726+
case tvar: TypeVar =>
727+
constrainSelectionQualifier(qual.tpe, tree.name, tvar, pt)
728+
case _ =>
729+
}
730+
case _ =>
731+
}
732+
526733
val select = assignType(cpy.Select(tree)(qual, tree.name), qual)
527734
val select1 = toNotNullTermRef(select, pt)
528735

@@ -1096,6 +1303,21 @@ class Typer extends Namer
10961303
case _ =>
10971304
}
10981305

1306+
// The set of type variables in the prototype which appear only in covariant or
1307+
// contravariant positions. These should be instantiatable without
1308+
// preventing the body of the lambda from typechecking (...except in situations
1309+
// like `def foo[T, U <: T](x: T => U)`, where instantiating `T` to a specific
1310+
// type might overconstrain `U`).
1311+
//
1312+
// This doesn't exclude type variables which appear with a different
1313+
// variance at a later point in the same method call, or a subsequent chained
1314+
// call.
1315+
//
1316+
// TODO: try to replace this by an empty list and see how that affects
1317+
// inference and performance (we would end up creating a lot more type
1318+
// variables in `typedSelect`).
1319+
lazy val protoVariantVars = variances(pt).toList.filter(_._2 != 0).map(_._1)
1320+
10991321
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
11001322

11011323
/** The inferred parameter type for a parameter in a lambda that does
@@ -1118,7 +1340,7 @@ class Typer extends Namer
11181340
* If all attempts fail, issue a "missing parameter type" error.
11191341
*/
11201342
def inferredParamType(param: untpd.ValDef, formal: Type): Type =
1121-
if isFullyDefined(formal, ForceDegree.failBottom) then return formal
1343+
if isFullyDefined(formal, ForceDegree.none) then return formal
11221344
val target = calleeType.widen match
11231345
case mtpe: MethodType =>
11241346
val pos = paramIndex(param.name)
@@ -1128,9 +1350,17 @@ class Typer extends Namer
11281350
else NoType
11291351
case _ => NoType
11301352
if target.exists then formal <:< target
1131-
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1353+
// if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1354+
if isFullyDefined(formal, ForceDegree.none) then formal
11321355
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
1133-
else errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
1356+
else if !formal.isInstanceOf[WildcardType] then
1357+
instantiateSelected(formal, protoVariantVars)
1358+
// Intentionally leave uninstantiated type variables in the types of parameters,
1359+
// this works because `typedSelect` special cases the handling of qualifiers
1360+
// whose type is a type variable.
1361+
formal
1362+
else
1363+
errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
11341364

11351365
def protoFormal(i: Int): Type =
11361366
if (protoFormals.length == params.length) protoFormals(i)

0 commit comments

Comments
 (0)