Skip to content

Commit 563fab9

Browse files
authored
Improve type inference for functions like fold (#18780)
When calling a fold with an accumulator like `Nil` or `List()` one used to have add an explicit type ascription. This is now no longer necessary. When instantiating type variables that occur invariantly in the expected type of a lambda, we now replace covariant occurrences of `Nothing` in the (possibly widened) instance type of the type variable with fresh type variables. In the case of fold, the accumulator determines the instance type of a type variable that appears both in the parameter list and in the result type of the closure, which makes it invariant. So the accumulator type is improved in the way described above. The idea is that a fresh type variable in such places is always better than Nothing. For module values such as `Nil` we widen to `List[<fresh var>]`. This does possibly cause a new type error if the fold really wanted a `Nil` instance. But that case seems very rare, so it looks like a good bet in general to do the widening.
2 parents e4ba788 + 6f1a09a commit 563fab9

File tree

8 files changed

+177
-39
lines changed

8 files changed

+177
-39
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+3-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Flags.*
1010
import config.Config
1111
import config.Printers.typr
1212
import typer.ProtoTypes.{newTypeVar, representedParamRef}
13+
import transform.TypeUtils.isTransparent
1314
import UnificationDirection.*
1415
import NameKinds.AvoidNameKind
1516
import util.SimpleIdentitySet
@@ -566,13 +567,6 @@ trait ConstraintHandling {
566567
inst
567568
end approximation
568569

569-
private def isTransparent(tp: Type, traitOnly: Boolean)(using Context): Boolean = tp match
570-
case AndType(tp1, tp2) =>
571-
isTransparent(tp1, traitOnly) && isTransparent(tp2, traitOnly)
572-
case _ =>
573-
val cls = tp.underlyingClassRef(refinementOK = false).typeSymbol
574-
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
575-
576570
/** If `tp` is an intersection such that some operands are transparent trait instances
577571
* and others are not, replace as many transparent trait instances as possible with Any
578572
* as long as the result is still a subtype of `bound`. But fall back to the
@@ -585,7 +579,7 @@ trait ConstraintHandling {
585579
var dropped: List[Type] = List() // the types dropped so far, last one on top
586580

587581
def dropOneTransparentTrait(tp: Type): Type =
588-
if isTransparent(tp, traitOnly = true) && !kept.contains(tp) then
582+
if tp.isTransparent(traitOnly = true) && !kept.contains(tp) then
589583
dropped = tp :: dropped
590584
defn.AnyType
591585
else tp match
@@ -658,7 +652,7 @@ trait ConstraintHandling {
658652
def widenOr(tp: Type) =
659653
if widenUnions then
660654
val tpw = tp.widenUnion
661-
if (tpw ne tp) && !isTransparent(tpw, traitOnly = false) && (tpw <:< bound) then tpw else tp
655+
if (tpw ne tp) && !tpw.isTransparent() && (tpw <:< bound) then tpw else tp
662656
else tp.hardenUnions
663657

664658
def widenSingle(tp: Type) =

compiler/src/dotty/tools/dotc/core/Types.scala

+9-4
Original file line numberDiff line numberDiff line change
@@ -4908,6 +4908,9 @@ object Types {
49084908
tp
49094909
}
49104910

4911+
def typeToInstantiateWith(fromBelow: Boolean)(using Context): Type =
4912+
TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
4913+
49114914
/** Instantiate variable from the constraints over its `origin`.
49124915
* If `fromBelow` is true, the variable is instantiated to the lub
49134916
* of its lower bounds in the current constraint; otherwise it is
@@ -4916,7 +4919,7 @@ object Types {
49164919
* is also a singleton type.
49174920
*/
49184921
def instantiate(fromBelow: Boolean)(using Context): Type =
4919-
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
4922+
val tp = typeToInstantiateWith(fromBelow)
49204923
if myInst.exists then // The line above might have triggered instantiation of the current type variable
49214924
myInst
49224925
else
@@ -5812,11 +5815,13 @@ object Types {
58125815
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
58135816
tp.derivedLambdaType(tp.paramNames, formals, restpe)
58145817

5818+
protected def mapArg(arg: Type, tparam: ParamInfo): Type = arg match
5819+
case arg: TypeBounds => this(arg)
5820+
case arg => atVariance(variance * tparam.paramVarianceSign)(this(arg))
5821+
58155822
protected def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match
58165823
case arg :: otherArgs if tparams.nonEmpty =>
5817-
val arg1 = arg match
5818-
case arg: TypeBounds => this(arg)
5819-
case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg))
5824+
val arg1 = mapArg(arg, tparams.head)
58205825
val otherArgs1 = mapArgs(otherArgs, tparams.tail)
58215826
if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args
58225827
else arg1 :: otherArgs1

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

+11-5
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@ package transform
44

55
import core.*
66
import TypeErasure.ErasedValueType
7-
import Types.*
8-
import Contexts.*
9-
import Symbols.*
7+
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
108
import Names.Name
119

12-
import dotty.tools.dotc.core.Decorators.*
13-
1410
object TypeUtils {
1511
/** A decorator that provides methods on types
1612
* that are needed in the transformer pipeline.
@@ -98,5 +94,15 @@ object TypeUtils {
9894
def takesImplicitParams(using Context): Boolean = self.stripPoly match
9995
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
10096
case _ => false
97+
98+
/** Is this a type deriving only from transparent classes?
99+
* @param traitOnly if true, all class symbols must be transparent traits
100+
*/
101+
def isTransparent(traitOnly: Boolean = false)(using Context): Boolean = self match
102+
case AndType(tp1, tp2) =>
103+
tp1.isTransparent(traitOnly) && tp2.isTransparent(traitOnly)
104+
case _ =>
105+
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
106+
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
101107
}
102108
}

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+91-20
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import ProtoTypes.*
99
import NameKinds.UniqueName
1010
import util.Spans.*
1111
import util.{Stats, SimpleIdentityMap, SimpleIdentitySet, SrcPos}
12-
import Decorators.*
12+
import transform.TypeUtils.isTransparent
13+
import Decorators._
1314
import config.Printers.{gadts, typr}
1415
import annotation.tailrec
1516
import reporting.*
@@ -60,7 +61,9 @@ object Inferencing {
6061
def instantiateSelected(tp: Type, tvars: List[Type])(using Context): Unit =
6162
if (tvars.nonEmpty)
6263
IsFullyDefinedAccumulator(
63-
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
64+
new ForceDegree.Value(IfBottom.flip):
65+
override def appliesTo(tvar: TypeVar) = tvars.contains(tvar),
66+
minimizeSelected = true
6467
).process(tp)
6568

6669
/** Instantiate any type variables in `tp` whose bounds contain a reference to
@@ -154,15 +157,66 @@ object Inferencing {
154157
* their lower bound. Record whether successful.
155158
* 2nd Phase: If first phase was successful, instantiate all remaining type variables
156159
* to their upper bound.
160+
*
161+
* Instance types can be improved by replacing covariant occurrences of Nothing
162+
* with fresh type variables, if `force` allows this in its `canImprove` implementation.
157163
*/
158164
private class IsFullyDefinedAccumulator(force: ForceDegree.Value, minimizeSelected: Boolean = false)
159165
(using Context) extends TypeAccumulator[Boolean] {
160166

161-
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
167+
/** Replace toplevel-covariant occurrences (i.e. covariant without double flips)
168+
* of Nothing by fresh type variables. Double-flips are not covered to be
169+
* conservative and save a bit of time on traversals; we could probably
170+
* generalize that if we see use cases.
171+
* For singleton types and references to module classes: try to
172+
* improve the widened type. For module classes, the widened type
173+
* is the intersection of all its non-transparent parent types.
174+
*/
175+
private def improve(tvar: TypeVar) = new TypeMap:
176+
def apply(t: Type) = trace(i"improve $t", show = true):
177+
def tryWidened(widened: Type): Type =
178+
val improved = apply(widened)
179+
if improved ne widened then improved else mapOver(t)
180+
if variance > 0 then
181+
t match
182+
case t: TypeRef =>
183+
if t.symbol == defn.NothingClass then
184+
newTypeVar(TypeBounds.empty, nestingLevel = tvar.nestingLevel)
185+
else if t.symbol.is(ModuleClass) then
186+
tryWidened(t.parents.filter(!_.isTransparent())
187+
.foldLeft(defn.AnyType: Type)(TypeComparer.andType(_, _)))
188+
else
189+
mapOver(t)
190+
case t: TermRef =>
191+
tryWidened(t.widen)
192+
case _ =>
193+
mapOver(t)
194+
else t
195+
196+
// Don't map Nothing arguments for higher-kinded types; we'd get the wrong kind */
197+
override def mapArg(arg: Type, tparam: ParamInfo): Type =
198+
if tparam.paramInfo.isLambdaSub then arg
199+
else super.mapArg(arg, tparam)
200+
end improve
201+
202+
/** Instantiate type variable with possibly improved computed instance type.
203+
* @return true if variable was instantiated with improved type, which
204+
* in this case should not be instantiated further, false otherwise.
205+
*/
206+
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Boolean =
207+
if fromBelow && force.canImprove(tvar) then
208+
val inst = tvar.typeToInstantiateWith(fromBelow = true)
209+
if apply(true, inst) then
210+
// need to recursively check before improving, since improving adds type vars
211+
// which should not be instantiated at this point
212+
val better = improve(tvar)(inst)
213+
if better <:< TypeComparer.fullUpperBound(tvar.origin) then
214+
typr.println(i"forced instantiation of invariant ${tvar.origin} = $inst, improved to $better")
215+
tvar.instantiateWith(better)
216+
return true
162217
val inst = tvar.instantiate(fromBelow)
163218
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
164-
inst
165-
}
219+
false
166220

167221
private var toMaximize: List[TypeVar] = Nil
168222

@@ -178,26 +232,27 @@ object Inferencing {
178232
&& ctx.typerState.constraint.contains(tvar)
179233
&& {
180234
var fail = false
235+
var skip = false
181236
val direction = instDirection(tvar.origin)
182237
if minimizeSelected then
183238
if direction <= 0 && tvar.hasLowerBound then
184-
instantiate(tvar, fromBelow = true)
239+
skip = instantiate(tvar, fromBelow = true)
185240
else if direction >= 0 && tvar.hasUpperBound then
186-
instantiate(tvar, fromBelow = false)
241+
skip = instantiate(tvar, fromBelow = false)
187242
// else hold off instantiating unbounded unconstrained variable
188243
else if direction != 0 then
189-
instantiate(tvar, fromBelow = direction < 0)
244+
skip = instantiate(tvar, fromBelow = direction < 0)
190245
else if variance >= 0 && tvar.hasLowerBound then
191-
instantiate(tvar, fromBelow = true)
246+
skip = instantiate(tvar, fromBelow = true)
192247
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
193248
&& force.ifBottom == IfBottom.ok
194249
then // if variance == 0, prefer upper bound if one is given
195-
instantiate(tvar, fromBelow = true)
250+
skip = instantiate(tvar, fromBelow = true)
196251
else if variance >= 0 && force.ifBottom == IfBottom.fail then
197252
fail = true
198253
else
199254
toMaximize = tvar :: toMaximize
200-
!fail && foldOver(x, tvar)
255+
!fail && (skip || foldOver(x, tvar))
201256
}
202257
case tp => foldOver(x, tp)
203258
}
@@ -467,7 +522,7 @@ object Inferencing {
467522
*
468523
* we want to instantiate U to x.type right away. No need to wait further.
469524
*/
470-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
525+
def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
471526
Stats.record("variances")
472527
val constraint = ctx.typerState.constraint
473528

@@ -769,14 +824,30 @@ trait Inferencing { this: Typer =>
769824
}
770825

771826
/** An enumeration controlling the degree of forcing in "is-fully-defined" checks. */
772-
@sharable object ForceDegree {
773-
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom):
774-
override def toString = s"ForceDegree.Value(.., $ifBottom)"
775-
val none: Value = new Value(_ => false, IfBottom.ok) { override def toString = "ForceDegree.none" }
776-
val all: Value = new Value(_ => true, IfBottom.ok) { override def toString = "ForceDegree.all" }
777-
val failBottom: Value = new Value(_ => true, IfBottom.fail) { override def toString = "ForceDegree.failBottom" }
778-
val flipBottom: Value = new Value(_ => true, IfBottom.flip) { override def toString = "ForceDegree.flipBottom" }
779-
}
827+
@sharable object ForceDegree:
828+
class Value(val ifBottom: IfBottom):
829+
830+
/** Does `tv` need to be instantiated? */
831+
def appliesTo(tv: TypeVar): Boolean = true
832+
833+
/** Should we try to improve the computed instance type by replacing bottom types
834+
* with fresh type variables?
835+
*/
836+
def canImprove(tv: TypeVar): Boolean = false
837+
838+
override def toString = s"ForceDegree.Value($ifBottom)"
839+
end Value
840+
841+
val none: Value = new Value(IfBottom.ok):
842+
override def appliesTo(tv: TypeVar) = false
843+
override def toString = "ForceDegree.none"
844+
val all: Value = new Value(IfBottom.ok):
845+
override def toString = "ForceDegree.all"
846+
val failBottom: Value = new Value(IfBottom.fail):
847+
override def toString = "ForceDegree.failBottom"
848+
val flipBottom: Value = new Value(IfBottom.flip):
849+
override def toString = "ForceDegree.flipBottom"
850+
end ForceDegree
780851

781852
enum IfBottom:
782853
case ok, fail, flip

compiler/src/dotty/tools/dotc/typer/Typer.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -1634,14 +1634,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16341634
case _ =>
16351635

16361636
if desugared.isEmpty then
1637+
val forceDegree =
1638+
if pt.isValueType then
1639+
// Allow variables that appear invariantly in `pt` to be improved by mapping
1640+
// bottom types in their instance types to fresh type variables
1641+
new ForceDegree.Value(IfBottom.fail):
1642+
val tvmap = variances(pt)
1643+
override def canImprove(tvar: TypeVar) =
1644+
tvmap.computedVariance(tvar) == (0: Integer)
1645+
else
1646+
ForceDegree.failBottom
1647+
16371648
val inferredParams: List[untpd.ValDef] =
16381649
for ((param, i) <- params.zipWithIndex) yield
16391650
if (!param.tpt.isEmpty) param
16401651
else
16411652
val (formalBounds, isErased) = protoFormal(i)
16421653
val formal = formalBounds.loBound
16431654
val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing
1644-
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
1655+
val knownFormal = isFullyDefined(formal, forceDegree)
16451656
// If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound,
16461657
// try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala)
16471658
val paramType =

tests/neg/foldinf-ill-kinded.check

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-- [E007] Type Mismatch Error: tests/neg/foldinf-ill-kinded.scala:9:16 -------------------------------------------------
2+
9 | ys.combine(x) // error
3+
| ^^^^^^^^^^^^^
4+
| Found: Foo[List]
5+
| Required: Foo[Nothing]
6+
|
7+
| longer explanation available when compiling with `-explain`

tests/neg/foldinf-ill-kinded.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Foo[+T[_]]:
2+
def combine[T1[x] >: T[x]](x: T1[Int]): Foo[T1] = new Foo
3+
object Foo:
4+
def empty: Foo[Nothing] = new Foo
5+
6+
object X:
7+
def test(xs: List[List[Int]]): Unit =
8+
xs.foldLeft(Foo.empty)((ys, x) =>
9+
ys.combine(x) // error
10+
)

tests/pos/folds.scala

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
object Test:
3+
extension [A](xs: List[A])
4+
def foldl[B](acc: B)(f: (A, B) => B): B = ???
5+
6+
val xs = List(1, 2, 3)
7+
8+
val _ = xs.foldl(List())((y, ys) => y :: ys)
9+
10+
val _ = xs.foldl(Nil)((y, ys) => y :: ys)
11+
12+
def partition[a](xs: List[a], pred: a => Boolean): Tuple2[List[a], List[a]] = {
13+
xs.foldRight/*[Tuple2[List[a], List[a]]]*/((List(), List())) {
14+
(x, p) => if (pred (x)) (x :: p._1, p._2) else (p._1, x :: p._2)
15+
}
16+
}
17+
18+
def snoc[A](xs: List[A], x: A) = x :: xs
19+
20+
def reverse[A](xs: List[A]) =
21+
xs.foldLeft(Nil)(snoc)
22+
23+
def reverse2[A](xs: List[A]) =
24+
xs.foldLeft(List())(snoc)
25+
26+
val ys: Seq[Int] = xs
27+
ys.foldLeft(Seq())((ys, y) => y +: ys)
28+
ys.foldLeft(Nil)((ys, y) => y +: ys)
29+
30+
def dup[A](xs: List[A]) =
31+
xs.foldRight(Nil)((x, xs) => x :: x :: xs)
32+
33+
def toSet[A](xs: Seq[A]) =
34+
xs.foldLeft(Set.empty)(_ + _)

0 commit comments

Comments
 (0)