Skip to content

Commit e842810

Browse files
authored
Avoid incorrect simplifications when updating bounds in the constraint (#16410)
When combining an old and a new bound, we use `Type#&`/`Type#|` which perform simplifications. This is usually fine, but if the new bounds refer to the parameter currently being updated, we can run into cyclic reasoning issues which make the simplifications invalid after the update. We already have logic for handling self-references in parameter bounds: `updateEntry` calls `ensureNonCyclic` which sanitizes the type, but at this point the simplifications have already occured. This commit simply moves the logic out of `updateEntry` so that we can sanitize the new bounds before simplification. More precisely, we rename `ensureNonCyclic` to `validBoundsFor` which calls `validBoundFor` (singular). Both are used to sanitize bounds where needed in `addOneBound` and `unify`. Since all calls to `updateEntry` now have sanitized bounds, we no longer need to sanitize them in `updateEntry` itself, we document this change by adding a pre-condition to `updateEntry`. For the record, here's how `ConstraintsTest#validBoundsInit` used to fail. It defines a method: def foo[S >: T <: T | Int, T <: String]: Any Before this commit, when `foo` was added to the current constraints, the constraint `S <: T | Int` was propagated to the lower bound `T` of `S`. The updated upper bound of `T` was thus set to: String & (T | Int) But because `Type#&` performs simplifications, this became T | (String & Int) by relying on the fact that at this point, `T <: String`. But in fact this simplified bound no longer ensures that `T <: String`! The self-reference was then replaced by `Any` in `OrderingConstraint#ensureNonCyclic`. After this commit, the problematic simplification no longer occurs since the new `T | Int` is sanitized to `Any` before being intersected with the old bound.
2 parents 845105a + 50eb0e9 commit e842810

File tree

4 files changed

+71
-35
lines changed

4 files changed

+71
-35
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ abstract class Constraint extends Showable {
8888
* - Another type, indicating a solution for the parameter
8989
*
9090
* @pre `this contains param`.
91+
* @pre `tp` does not contain top-level references to `param`
92+
* (see `validBoundsFor`)
9193
*/
9294
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This
9395

@@ -172,6 +174,23 @@ abstract class Constraint extends Showable {
172174
*/
173175
def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean
174176

177+
/** Sanitize `bound` to make it either a valid upper or lower bound for
178+
* `param` depending on `isUpper`.
179+
*
180+
* Toplevel references to `param`, are replaced by `Any` if `isUpper` is true
181+
* and `Nothing` otherwise.
182+
*
183+
* @see `occursAtTopLevel` for a definition of "toplevel"
184+
* @see `validBoundsFor` to sanitize both the lower and upper bound at once.
185+
*/
186+
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type
187+
188+
/** Sanitize `bounds` to make them valid constraints for `param`.
189+
*
190+
* @see `validBoundFor` for details.
191+
*/
192+
def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type
193+
175194
/** A string that shows the reverse dependencies maintained by this constraint
176195
* (coDeps and contraDeps for OrderingConstraints).
177196
*/

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ trait ConstraintHandling {
257257
end LevelAvoidMap
258258

259259
/** Approximate `rawBound` if needed to make it a legal bound of `param` by
260-
* avoiding wildcards and types with a level strictly greater than its
260+
* avoiding cycles, wildcards and types with a level strictly greater than its
261261
* `nestingLevel`.
262262
*
263263
* Note that level-checking must be performed here and cannot be delayed
@@ -283,7 +283,7 @@ trait ConstraintHandling {
283283
// This is necessary for i8900-unflip.scala to typecheck.
284284
val v = if necessaryConstraintsOnly then -this.variance else this.variance
285285
atVariance(v)(super.legalVar(tp))
286-
approx(rawBound)
286+
constraint.validBoundFor(param, approx(rawBound), isUpper)
287287
end legalBound
288288

289289
protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
@@ -413,8 +413,10 @@ trait ConstraintHandling {
413413

414414
constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)
415415

416-
val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
417-
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
416+
val boundKept = constraint.validBoundsFor(pKept,
417+
constraint.nonParamBounds( pKept).substParam(pRemoved, pKept).bounds)
418+
var boundRemoved = constraint.validBoundsFor(pKept,
419+
constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept).bounds)
418420

419421
if level1 != level2 then
420422
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -525,20 +525,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
525525

526526
// ---------- Updates ------------------------------------------------------------
527527

528-
/** If `inst` is a TypeBounds, make sure it does not contain toplevel
529-
* references to `param` (see `Constraint#occursAtToplevel` for a definition
530-
* of "toplevel").
531-
* Any such references are replaced by `Nothing` in the lower bound and `Any`
532-
* in the upper bound.
533-
* References can be direct or indirect through instantiations of other
534-
* parameters in the constraint.
535-
*/
536-
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =
537-
538-
def recur(tp: Type, fromBelow: Boolean): Type = tp match
528+
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type =
529+
def recur(tp: Type): Type = tp match
539530
case tp: AndOrType =>
540-
val r1 = recur(tp.tp1, fromBelow)
541-
val r2 = recur(tp.tp2, fromBelow)
531+
val r1 = recur(tp.tp1)
532+
val r2 = recur(tp.tp2)
542533
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
543534
else tp.match
544535
case tp: OrType =>
@@ -547,35 +538,34 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
547538
r1 & r2
548539
case tp: TypeParamRef =>
549540
if tp eq param then
550-
if fromBelow then defn.NothingType else defn.AnyType
541+
if isUpper then defn.AnyType else defn.NothingType
551542
else entry(tp) match
552543
case NoType => tp
553-
case TypeBounds(lo, hi) => if lo eq hi then recur(lo, fromBelow) else tp
554-
case inst => recur(inst, fromBelow)
544+
case TypeBounds(lo, hi) => if lo eq hi then recur(lo) else tp
545+
case inst => recur(inst)
555546
case tp: TypeVar =>
556-
val underlying1 = recur(tp.underlying, fromBelow)
547+
val underlying1 = recur(tp.underlying)
557548
if underlying1 ne tp.underlying then underlying1 else tp
558549
case CapturingType(parent, refs) =>
559-
val parent1 = recur(parent, fromBelow)
550+
val parent1 = recur(parent)
560551
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
561552
case tp: AnnotatedType =>
562-
val parent1 = recur(tp.parent, fromBelow)
553+
val parent1 = recur(tp.parent)
563554
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
564555
case _ =>
565556
val tp1 = tp.dealiasKeepAnnots
566557
if tp1 ne tp then
567-
val tp2 = recur(tp1, fromBelow)
558+
val tp2 = recur(tp1)
568559
if tp2 ne tp1 then tp2 else tp
569560
else tp
570561

571-
inst match
572-
case bounds: TypeBounds =>
573-
bounds.derivedTypeBounds(
574-
recur(bounds.lo, fromBelow = true),
575-
recur(bounds.hi, fromBelow = false))
576-
case _ =>
577-
inst
578-
end ensureNonCyclic
562+
recur(bound)
563+
end validBoundFor
564+
565+
def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type =
566+
bounds.derivedTypeBounds(
567+
validBoundFor(param, bounds.lo, isUpper = false),
568+
validBoundFor(param, bounds.hi, isUpper = true))
579569

580570
/** Add the fact `param1 <: param2` to the constraint `current` and propagate
581571
* `<:<` relationships between parameters ("edges") but not bounds.
@@ -658,9 +648,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
658648
current1
659649
}
660650

661-
/** The public version of `updateEntry`. Guarantees that there are no cycles */
662651
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This =
663-
updateEntry(this, param, ensureNonCyclic(param, tp)).checkWellFormed()
652+
updateEntry(this, param, tp).checkWellFormed()
664653

665654
def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This =
666655
order(this, param1, param2, direction).checkWellFormed()
@@ -703,7 +692,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
703692

704693
def replaceParamIn(other: TypeParamRef) =
705694
val oldEntry = current.entry(other)
706-
val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement))
695+
val newEntry = oldEntry.substParam(param, replacement) match
696+
case tp: TypeBounds => validBoundsFor(other, tp)
697+
case tp => tp
707698
current = boundsLens.update(this, current, other, newEntry)
708699
var oldDepEntry = oldEntry
709700
var newDepEntry = newEntry

compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,27 @@ class ConstraintsTest:
5353
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
5454
}
5555
end mergeBoundsTransitivity
56+
57+
@Test def validBoundsInit: Unit = inCompilerContext(
58+
TestConfiguration.basicClasspath,
59+
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") {
60+
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
61+
val List(s, t) = tvars.tpes
62+
63+
val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
64+
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
65+
assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any
66+
}
67+
68+
@Test def validBoundsUnify: Unit = inCompilerContext(
69+
TestConfiguration.basicClasspath,
70+
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") {
71+
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
72+
val List(s, t) = tvars.tpes
73+
74+
s <:< t
75+
76+
val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
77+
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
78+
assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}")
79+
}

0 commit comments

Comments
 (0)