Skip to content

Commit bf03086

Browse files
authored
Merge pull request #15632 from dotty-staging/fix-14770
Instantiate more type variables to hard unions
2 parents 398b72e + dfcfb6b commit bf03086

12 files changed

+140
-58
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ abstract class Constraint extends Showable {
126126
*/
127127
def subst(from: TypeLambda, to: TypeLambda)(using Context): This
128128

129+
/** Is `tv` marked as hard in the constraint? */
130+
def isHard(tv: TypeVar): Boolean
131+
132+
/** The same as this constraint, but with `tv` marked as hard. */
133+
def withHard(tv: TypeVar)(using Context): This
134+
129135
/** Gives for each instantiated type var that does not yet have its `inst` field
130136
* set, the instance value stored in the constraint. Storing instances in constraints
131137
* is done only in a temporary way for contexts that may be retracted

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import typer.ProtoTypes.{newTypeVar, representedParamRef}
1313
import UnificationDirection.*
1414
import NameKinds.AvoidNameKind
1515
import util.SimpleIdentitySet
16+
import NullOpsDecorator.stripNull
1617

1718
/** Methods for adding constraints and solving them.
1819
*
@@ -613,8 +614,11 @@ trait ConstraintHandling {
613614
* 1. If `inst` is a singleton type, or a union containing some singleton types,
614615
* widen (all) the singleton type(s), provided the result is a subtype of `bound`.
615616
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
616-
* 2. If `inst` is a union type, approximate the union type from above by an intersection
617-
* of all common base types, provided the result is a subtype of `bound`.
617+
* 2a. If `inst` is a union type and `widenUnions` is true, approximate the union type
618+
* from above by an intersection of all common base types, provided the result
619+
* is a subtype of `bound`.
620+
* 2b. If `inst` is a union type and `widenUnions` is false, turn it into a hard
621+
* union type (except for unions | Null, which are kept in the state they were).
618622
* 3. Widen some irreducible applications of higher-kinded types to wildcard arguments
619623
* (see @widenIrreducible).
620624
* 4. Drop transparent traits from intersections (see @dropTransparentTraits).
@@ -627,10 +631,12 @@ trait ConstraintHandling {
627631
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
628632
* as those could leak the annotation to users (see run/inferred-repeated-result).
629633
*/
630-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
634+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
631635
def widenOr(tp: Type) =
632-
val tpw = tp.widenUnion
633-
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
636+
if widenUnions then
637+
val tpw = tp.widenUnion
638+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
639+
else tp.hardenUnions
634640

635641
def widenSingle(tp: Type) =
636642
val tpw = tp.widenSingletons
@@ -650,6 +656,23 @@ trait ConstraintHandling {
650656
wideInst.dropRepeatedAnnot
651657
end widenInferred
652658

659+
/** Convert all toplevel union types in `tp` to hard unions */
660+
extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
661+
case tp: AndType =>
662+
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
663+
case tp: RefinedType =>
664+
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
665+
case tp: RecType =>
666+
tp.rebind(tp.parent.hardenUnions)
667+
case tp: HKTypeLambda =>
668+
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
669+
case tp: OrType =>
670+
val tp1 = tp.stripNull
671+
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
672+
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
673+
case _ =>
674+
tp
675+
653676
/** The instance type of `param` in the current constraint (which contains `param`).
654677
* If `fromBelow` is true, the instance type is the lub of the parameter's
655678
* lower bounds; otherwise it is the glb of its upper bounds. However,
@@ -658,18 +681,18 @@ trait ConstraintHandling {
658681
* The instance type is not allowed to contain references to types nested deeper
659682
* than `maxLevel`.
660683
*/
661-
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
684+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int)(using Context): Type = {
662685
val approx = approximation(param, fromBelow, maxLevel).simplified
663686
if fromBelow then
664-
val widened = widenInferred(approx, param)
687+
val widened = widenInferred(approx, param, widenUnions)
665688
// Widening can add extra constraints, in particular the widened type might
666689
// be a type variable which is now instantiated to `param`, and therefore
667690
// cannot be used as an instantiation of `param` without creating a loop.
668691
// If that happens, we run `instanceType` again to find a new instantation.
669692
// (we do not check for non-toplevel occurences: those should never occur
670693
// since `addOneBound` disallows recursive lower bounds).
671694
if constraint.occursAtToplevel(param, widened) then
672-
instanceType(param, fromBelow, maxLevel)
695+
instanceType(param, fromBelow, widenUnions, maxLevel)
673696
else
674697
widened
675698
else

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Decorators._
66
import Contexts._
77
import Types._
88
import Symbols._
9-
import util.SimpleIdentityMap
9+
import util.{SimpleIdentitySet, SimpleIdentityMap}
1010
import collection.mutable
1111
import printing._
1212

@@ -68,7 +68,7 @@ final class ProperGadtConstraint private(
6868
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
6969

7070
def this() = this(
71-
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty),
71+
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty),
7272
mapping = SimpleIdentityMap.empty,
7373
reverseMapping = SimpleIdentityMap.empty,
7474
wasConstrained = false

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package dotc
33
package core
44

55
import Types._, Contexts._, Symbols._, Decorators._, TypeApplications._
6-
import util.SimpleIdentityMap
6+
import util.{SimpleIdentitySet, SimpleIdentityMap}
77
import collection.mutable
88
import printing.Printer
99
import printing.Texts._
@@ -24,12 +24,14 @@ object OrderingConstraint {
2424
/** The type of `OrderingConstraint#lowerMap`, `OrderingConstraint#upperMap` */
2525
type ParamOrdering = ArrayValuedMap[List[TypeParamRef]]
2626

27-
/** A new constraint with given maps */
28-
private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint =
27+
/** A new constraint with given maps and given set of hard typevars */
28+
private def newConstraint(
29+
boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering,
30+
hardVars: TypeVars)(using Context) : OrderingConstraint =
2931
if boundsMap.isEmpty && lowerMap.isEmpty && upperMap.isEmpty then
3032
empty
3133
else
32-
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap)
34+
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap, hardVars)
3335
if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size)
3436
result
3537

@@ -91,28 +93,28 @@ object OrderingConstraint {
9193
def entries(c: OrderingConstraint, poly: TypeLambda): Array[Type] | Null =
9294
c.boundsMap(poly)
9395
def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[Type])(using Context): OrderingConstraint =
94-
newConstraint(c.boundsMap.updated(poly, entries), c.lowerMap, c.upperMap)
96+
newConstraint(c.boundsMap.updated(poly, entries), c.lowerMap, c.upperMap, c.hardVars)
9597
def initial = NoType
9698
}
9799

98100
val lowerLens: ConstraintLens[List[TypeParamRef]] = new ConstraintLens[List[TypeParamRef]] {
99101
def entries(c: OrderingConstraint, poly: TypeLambda): Array[List[TypeParamRef]] | Null =
100102
c.lowerMap(poly)
101103
def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[List[TypeParamRef]])(using Context): OrderingConstraint =
102-
newConstraint(c.boundsMap, c.lowerMap.updated(poly, entries), c.upperMap)
104+
newConstraint(c.boundsMap, c.lowerMap.updated(poly, entries), c.upperMap, c.hardVars)
103105
def initial = Nil
104106
}
105107

106108
val upperLens: ConstraintLens[List[TypeParamRef]] = new ConstraintLens[List[TypeParamRef]] {
107109
def entries(c: OrderingConstraint, poly: TypeLambda): Array[List[TypeParamRef]] | Null =
108110
c.upperMap(poly)
109111
def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[List[TypeParamRef]])(using Context): OrderingConstraint =
110-
newConstraint(c.boundsMap, c.lowerMap, c.upperMap.updated(poly, entries))
112+
newConstraint(c.boundsMap, c.lowerMap, c.upperMap.updated(poly, entries), c.hardVars)
111113
def initial = Nil
112114
}
113115

114116
@sharable
115-
val empty = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty)
117+
val empty = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty)
116118
}
117119

118120
import OrderingConstraint._
@@ -134,10 +136,13 @@ import OrderingConstraint._
134136
* @param upperMap a map from TypeLambdas to arrays. Each array entry corresponds
135137
* to a parameter P of the type lambda; it contains all constrained parameters
136138
* Q that are known to be greater than P, i.e. P <: Q.
139+
* @param hardVars a set of type variables that are marked as hard and therefore will not
140+
* undergo a `widenUnion` when instantiated to their lower bound.
137141
*/
138142
class OrderingConstraint(private val boundsMap: ParamBounds,
139143
private val lowerMap : ParamOrdering,
140-
private val upperMap : ParamOrdering) extends Constraint {
144+
private val upperMap : ParamOrdering,
145+
private val hardVars : TypeVars) extends Constraint {
141146

142147
import UnificationDirection.*
143148

@@ -277,7 +282,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
277282
val entries1 = new Array[Type](nparams * 2)
278283
poly.paramInfos.copyToArray(entries1, 0)
279284
tvars.copyToArray(entries1, nparams)
280-
newConstraint(boundsMap.updated(poly, entries1), lowerMap, upperMap).init(poly)
285+
newConstraint(boundsMap.updated(poly, entries1), lowerMap, upperMap, hardVars).init(poly)
281286
}
282287

283288
/** Split dependent parameters off the bounds for parameters in `poly`.
@@ -478,7 +483,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
478483
}
479484
po.remove(pt).mapValuesNow(removeFromBoundss)
480485
}
481-
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap))
486+
val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param))
487+
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1)
482488
.checkNonCyclic()
483489
}
484490

@@ -505,7 +511,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
505511
def swapKey[T](m: ArrayValuedMap[T]) =
506512
val info = m(from)
507513
if info == null then m else m.remove(from).updated(to, info)
508-
var current = newConstraint(swapKey(boundsMap), swapKey(lowerMap), swapKey(upperMap))
514+
var current = newConstraint(swapKey(boundsMap), swapKey(lowerMap), swapKey(upperMap), hardVars)
509515
def subst[T <: Type](x: T): T = x.subst(from, to).asInstanceOf[T]
510516
current.foreachParam {(p, i) =>
511517
current = boundsLens.map(this, current, p, i, subst)
@@ -515,6 +521,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
515521
constr.println(i"renamed $this to $current")
516522
current.checkNonCyclic()
517523

524+
def isHard(tv: TypeVar) = hardVars.contains(tv)
525+
526+
def withHard(tv: TypeVar)(using Context) =
527+
newConstraint(boundsMap, lowerMap, upperMap, hardVars + tv)
528+
518529
def instType(tvar: TypeVar): Type = entry(tvar.origin) match
519530
case _: TypeBounds => NoType
520531
case tp: TypeParamRef => typeVarOfParam(tp).orElse(tp)

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -485,33 +485,42 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
485485
false
486486
}
487487

488-
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
489-
// before splitting the LHS into its constituents. That way, the RHS variables are
490-
// constraint by the hard union and can be instantiated to it. If we just split and add
491-
// the two parts of the LHS separately to the constraint, the lower bound would become
492-
// a soft union.
493-
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
494-
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
495-
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
496-
case _ => true
488+
/** Mark toplevel type vars in `tp2` as hard in the current constraint */
489+
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
490+
case tvar: TypeVar if constraint.contains(tvar.origin) =>
491+
constraint = constraint.withHard(tvar)
492+
case tp2: TypeParamRef if constraint.contains(tp2) =>
493+
hardenTypeVars(constraint.typeVarOfParam(tp2))
494+
case tp2: AndOrType =>
495+
hardenTypeVars(tp2.tp1)
496+
hardenTypeVars(tp2.tp2)
497+
case _ =>
497498

498-
widenOK
499-
|| joinOK
500-
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
501-
|| containsAnd(tp1)
502-
&& !joined
503-
&& {
504-
joined = true
505-
try inFrozenGadt(recur(tp1.join, tp2))
506-
finally joined = false
507-
}
508-
// An & on the left side loses information. We compensate by also trying the join.
509-
// This is less ad-hoc than it looks since we produce joins in type inference,
510-
// and then need to check that they are indeed supertypes of the original types
511-
// under -Ycheck. Test case is i7965.scala.
512-
// On the other hand, we could get a combinatorial explosion by applying such joins
513-
// recursively, so we do it only once. See i14870.scala as a test case, which would
514-
// loop for a very long time without the recursion brake.
499+
val res = widenOK || joinOK
500+
|| recur(tp11, tp2) && recur(tp12, tp2)
501+
|| containsAnd(tp1)
502+
&& !joined
503+
&& {
504+
joined = true
505+
try inFrozenGadt(recur(tp1.join, tp2))
506+
finally joined = false
507+
}
508+
// An & on the left side loses information. We compensate by also trying the join.
509+
// This is less ad-hoc than it looks since we produce joins in type inference,
510+
// and then need to check that they are indeed supertypes of the original types
511+
// under -Ycheck. Test case is i7965.scala.
512+
// On the other hand, we could get a combinatorial explosion by applying such joins
513+
// recursively, so we do it only once. See i14870.scala as a test case, which would
514+
// loop for a very long time without the recursion brake.
515+
516+
if res && !tp1.isSoft && state.isCommittable then
517+
// We use a heuristic here where every toplevel type variable on the right hand side
518+
// is marked so that it converts all soft unions in its lower bound to hard unions
519+
// before it is instantiated. The reason is that the variable's instance type will
520+
// be a supertype of (decomposed and reconstituted) `tp1`.
521+
hardenTypeVars(tp2)
522+
523+
res
515524

516525
case CapturingType(parent1, refs1) =>
517526
if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK && sameBoxed(tp1, tp2, refs1)
@@ -2960,8 +2969,8 @@ object TypeComparer {
29602969
def subtypeCheckInProgress(using Context): Boolean =
29612970
comparing(_.subtypeCheckInProgress)
29622971

2963-
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
2964-
comparing(_.instanceType(param, fromBelow, maxLevel))
2972+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
2973+
comparing(_.instanceType(param, fromBelow, widenUnions, maxLevel))
29652974

29662975
def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
29672976
comparing(_.approximation(param, fromBelow, maxLevel))
@@ -2981,8 +2990,8 @@ object TypeComparer {
29812990
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
29822991
comparing(_.addToConstraint(tl, tvars))
29832992

2984-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
2985-
comparing(_.widenInferred(inst, bound))
2993+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
2994+
comparing(_.widenInferred(inst, bound, widenUnions))
29862995

29872996
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
29882997
comparing(_.dropTransparentTraits(tp, bound))

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,9 @@ object TypeOps:
537537
override def apply(tp: Type): Type = tp match
538538
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
539539
val lo = TypeComparer.instanceType(
540-
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
540+
tp.origin,
541+
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
542+
widenUnions = tp.widenUnions)(using mapCtx)
541543
val lo1 = apply(lo)
542544
if (lo1 ne lo) lo1 else tp
543545
case _ =>

compiler/src/dotty/tools/dotc/core/TyperState.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,14 @@ class TyperState() {
229229
constraint.contains(tl) || other.isRemovable(tl) || {
230230
val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv }
231231
if this.isCommittable then
232-
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
232+
tvars.foreach(tvar =>
233+
if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
233234
typeComparer.addToConstraint(tl, tvars)
234235
}) &&
235236
// Integrate the additional constraints on type variables from `other`
237+
// and merge hardness markers
236238
constraint.uninstVars.forall(tv =>
239+
if other.isHard(tv) then constraint = constraint.withHard(tv)
237240
val p = tv.origin
238241
val otherLos = other.lower(p)
239242
val otherHis = other.upper(p)

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4714,12 +4714,15 @@ object Types {
47144714
* is also a singleton type.
47154715
*/
47164716
def instantiate(fromBelow: Boolean)(using Context): Type =
4717-
val tp = TypeComparer.instanceType(origin, fromBelow, nestingLevel)
4717+
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
47184718
if myInst.exists then // The line above might have triggered instantiation of the current type variable
47194719
myInst
47204720
else
47214721
instantiateWith(tp)
47224722

4723+
/** Widen unions when instantiating this variable in the current context? */
4724+
def widenUnions(using Context): Boolean = !ctx.typerState.constraint.isHard(this)
4725+
47234726
/** For uninstantiated type variables: the entry in the constraint (either bounds or
47244727
* provisional instance value)
47254728
*/

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1888,7 +1888,7 @@ class Namer { typer: Typer =>
18881888
TypeOps.simplify(tp.widenTermRefExpr,
18891889
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
18901890
case ctp: ConstantType if sym.isInlineVal => ctp
1891-
case tp => TypeComparer.widenInferred(tp, pt)
1891+
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)
18921892

18931893
// Replace aliases to Unit by Unit itself. If we leave the alias in
18941894
// it would be erased to BoxedUnit.

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
514514
val tparams = poly.paramRefs
515515
val variances = childClass.typeParams.map(_.paramVarianceSign)
516516
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
517-
TypeComparer.instanceType(tparam, fromBelow = variance < 0)
517+
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true)
518518
)
519519
val instanceType = resType.substParams(poly, instanceTypes)
520520
// this is broken in tests/run/i13332intersection.scala,

0 commit comments

Comments
 (0)