Skip to content

Commit 81fba98

Browse files
committed
Avoid infinite loop in type variable instantiation
Rename `checkNonCyclic` to `occursAtToplevel` and refactor it to return a boolean, use it in `ConstraintHandling#instanceType` to make sure we do not introduce a cycle when instantiating a type variable. Some alternatives I considered: - Run `widenInferred` inside frozen constraints: this prevents `Set[A] | Set[Int]` to be widened to `Set[Int]` after instantiating `A := Int` - Run `widenInferred` with the upper bound of `param` instead of `param` itself as a bound: I think this is still not safe because the upper bound of `param` might recursively refer to `param`, it also breaks type inference of one expression in ZIO.
1 parent 59587ba commit 81fba98

File tree

4 files changed

+55
-21
lines changed

4 files changed

+55
-21
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

+6
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ abstract class Constraint extends Showable {
153153
/** Check that no constrained parameter contains itself as a bound */
154154
def checkNonCyclic()(implicit ctx: Context): this.type
155155

156+
/** Does `param` occur at the toplevel in `tp` ?
157+
* Toplevel means: the type itself or a factor in some
158+
* combination of `&` or `|` types.
159+
*/
160+
def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean
161+
156162
/** Check that constraint only refers to TypeParamRefs bound by itself */
157163
def checkClosed()(implicit ctx: Context): Unit
158164

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+15-2
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,21 @@ trait ConstraintHandling[AbstractContext] {
361361
* is also a singleton type.
362362
*/
363363
def instanceType(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
364-
val inst = approximation(param, fromBelow).simplified
365-
if (fromBelow) widenInferred(inst, param) else inst
364+
val approx = approximation(param, fromBelow).simplified
365+
if (fromBelow)
366+
val widened = widenInferred(approx, param)
367+
// Widening can add extra constraints, in particular the widened type might
368+
// be a type variable which is now instantiated to `param`, and therefore
369+
// cannot be used as an instantiation of `param` without creating a loop.
370+
// If that happens, we run `instanceType` again to find a new instantation.
371+
// (we do not check for non-toplevel occurences: those should never occur
372+
// since `addOneBound` disallows recursive lower bounds).
373+
if constraint.occursAtToplevel(param, widened) then
374+
instanceType(param, fromBelow)
375+
else
376+
widened
377+
else
378+
approx
366379
}
367380

368381
/** Constraint `c1` subsumes constraint `c2`, if under `c2` as constraint we have

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

+22-19
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
300300
// ---------- Updates ------------------------------------------------------------
301301

302302
/** If `inst` is a TypeBounds, make sure it does not contain toplevel
303-
* references to `param`. Toplevel means: the term itself or a factor in some
304-
* combination of `&` or `|` types.
303+
* references to `param` (see `Constraint#occursAtToplevel` for a definition
304+
* of "toplevel").
305305
* Any such references are replace by `Nothing` in the lower bound and `Any`
306306
* in the upper bound.
307307
* References can be direct or indirect through instantiations of other
@@ -594,33 +594,36 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
594594
// ---------- Checking -----------------------------------------------
595595

596596
def checkNonCyclic()(implicit ctx: Context): this.type =
597-
if Config.checkConstraintsNonCyclic then domainParams.foreach(checkNonCyclic)
597+
if Config.checkConstraintsNonCyclic then
598+
domainParams.foreach { param =>
599+
val inst = entry(param)
600+
assert(!isLess(param, param),
601+
s"cyclic ordering involving $param in ${this.show}, upper = $inst")
602+
assert(occursAtToplevel(param, inst),
603+
s"cyclic bound for $param: ${inst.show} in ${this.show}")
604+
}
598605
this
599606

600-
private def checkNonCyclic(param: TypeParamRef)(implicit ctx: Context): Unit =
601-
assert(!isLess(param, param), i"cyclic ordering involving $param in $this, upper = ${upper(param)}")
607+
def occursAtToplevel(param: TypeParamRef, inst: Type)(implicit ctx: Context): Boolean =
602608

603-
def recur(tp: Type)(using Context): Unit = tp match
609+
def occurs(tp: Type)(using Context): Boolean = tp match
604610
case tp: AndOrType =>
605-
recur(tp.tp1)
606-
recur(tp.tp2)
611+
occurs(tp.tp1) || occurs(tp.tp2)
607612
case tp: TypeParamRef =>
608-
assert(tp ne param, i"cyclic bound for $param: ${entry(param)} in $this")
609-
entry(tp) match
610-
case NoType =>
611-
case TypeBounds(lo, hi) => if lo eq hi then recur(lo)
612-
case inst => recur(inst)
613+
(tp eq param) || entry(tp).match
614+
case NoType => false
615+
case TypeBounds(lo, hi) => (lo eq hi) && occurs(lo)
616+
case inst => occurs(inst)
613617
case tp: TypeVar =>
614-
recur(tp.underlying)
618+
occurs(tp.underlying)
615619
case TypeBounds(lo, hi) =>
616-
recur(lo)
617-
recur(hi)
620+
occurs(lo) || occurs(hi)
618621
case _ =>
619622
val tp1 = tp.dealias
620-
if tp1 ne tp then recur(tp1)
623+
(tp1 ne tp) && occurs(tp1)
621624

622-
recur(entry(param))
623-
end checkNonCyclic
625+
occurs(inst)
626+
end occursAtToplevel
624627

625628
override def checkClosed()(using Context): Unit =
626629

tests/neg/widenInst-cycle.scala

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import scala.reflect.ClassTag
2+
3+
class Test {
4+
def foo[N >: C | D <: C, C, D](implicit ct: ClassTag[N]): Unit = {}
5+
// This used to lead to an infinite loop, because:
6+
// widenInferred(?C | ?D, ?N)
7+
// returns ?C, with the following extra constraints:
8+
// ?C := ?N
9+
// ?D := ?N
10+
// So we ended up trying to instantiate ?N with ?N.
11+
foo // error
12+
}

0 commit comments

Comments
 (0)