Skip to content

Commit b9d8ca9

Browse files
committed
Reimplement constraint merging for correctness
The previous implementation simply combined the content of both constraints into one, but this is not enough since it meant that bounds were not propagated and so transitivity was violated. For example, when merging a constraint containing `?S <: ?T` and one containing `?T <: ?R`, the result did not verify `?S <:< ?R` (see the unit tests added in ConstraintsTest). The new implementation simply starts with one set of constraints and then adds the constraints from the other set one by one using `<:<` which takes care of propagating bounds. This is likely to be more expensive than the previous implementation but it turns out that `TyperState#mergeConstraintWith` is a rare operation (after the previous commit it is only called 43 times when compiling scala3-compiler), so the difference shouldn't be significant. This also incidentally fixes #12730 because the previous logic for checking if merging succeeded was flawed.
1 parent 3a082cb commit b9d8ca9

File tree

6 files changed

+128
-57
lines changed

6 files changed

+128
-57
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,6 @@ abstract class Constraint extends Showable {
152152
*/
153153
def uninstVars: collection.Seq[TypeVar]
154154

155-
/** The weakest constraint that subsumes both this constraint and `other`.
156-
* The constraints should be _compatible_, meaning that a type lambda
157-
* occurring in both constraints is associated with the same typevars in each.
158-
*
159-
* @param otherHasErrors If true, handle incompatible constraints by
160-
* returning an approximate constraint, instead of
161-
* failing with an exception
162-
*/
163-
def & (other: Constraint, otherHasErrors: Boolean)(using Context): Constraint
164-
165155
/** Whether `tl` is present in both `this` and `that` but is associated with
166156
* different TypeVars there, meaning that the constraints cannot be merged.
167157
*/

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -459,48 +459,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
459459

460460
// ----------- Joins -----------------------------------------------------
461461

462-
def & (other: Constraint, otherHasErrors: Boolean)(using Context): OrderingConstraint = {
463-
464-
def merge[T](m1: ArrayValuedMap[T], m2: ArrayValuedMap[T], join: (T, T) => T): ArrayValuedMap[T] = {
465-
var merged = m1
466-
def mergeArrays(xs1: Array[T], xs2: Array[T]) = {
467-
val xs = xs1.clone
468-
for (i <- xs.indices) xs(i) = join(xs1(i), xs2(i))
469-
xs
470-
}
471-
m2.foreachBinding { (poly, xs2) =>
472-
merged = merged.updated(poly,
473-
if (m1.contains(poly)) mergeArrays(m1(poly), xs2) else xs2)
474-
}
475-
merged
476-
}
477-
478-
def mergeParams(ps1: List[TypeParamRef], ps2: List[TypeParamRef]) =
479-
ps2.foldLeft(ps1)((ps1, p2) => if (ps1.contains(p2)) ps1 else p2 :: ps1)
480-
481-
// Must be symmetric
482-
def mergeEntries(e1: Type, e2: Type): Type =
483-
(e1, e2) match {
484-
case _ if e1 eq e2 => e1
485-
case (e1: TypeBounds, e2: TypeBounds) => e1 & e2
486-
case (e1: TypeBounds, _) if e1 contains e2 => e2
487-
case (_, e2: TypeBounds) if e2 contains e1 => e1
488-
case (tv1: TypeVar, tv2: TypeVar) if tv1 eq tv2 => e1
489-
case _ =>
490-
if (otherHasErrors)
491-
e1
492-
else
493-
throw new AssertionError(i"cannot merge $this with $other, mergeEntries($e1, $e2) failed")
494-
}
495-
496-
val that = other.asInstanceOf[OrderingConstraint]
497-
498-
new OrderingConstraint(
499-
merge(this.boundsMap, that.boundsMap, mergeEntries),
500-
merge(this.lowerMap, that.lowerMap, mergeParams),
501-
merge(this.upperMap, that.upperMap, mergeParams))
502-
}.showing(i"constraint merge $this with $other = $result", constr)
503-
504462
def hasConflictingTypeVarsFor(tl: TypeLambda, that: Constraint): Boolean =
505463
contains(tl) && that.contains(tl) &&
506464
// Since TypeVars are allocated in bulk for each type lambda, we only have

compiler/src/dotty/tools/dotc/core/TyperState.scala

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,43 @@ class TyperState() {
187187
*/
188188
def mergeConstraintWith(that: TyperState)(using Context): Unit =
189189
that.ensureNotConflicting(constraint)
190-
constraint = constraint & (that.constraint, otherHasErrors = that.reporter.errorsReported)
191-
for tvar <- constraint.uninstVars do
192-
if !isOwnedAnywhere(this, tvar) then includeVar(tvar)
190+
191+
val comparingCtx =
192+
if ctx.typerState == this then ctx
193+
else ctx.fresh.setTyperState(this)
194+
195+
comparing(typeComparer =>
196+
val other = that.constraint
197+
val res = other.domainLambdas.forall(tl =>
198+
// Integrate the type lambdas from `other`
199+
constraint.contains(tl) || other.isRemovable(tl) || {
200+
val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv }
201+
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
202+
typeComparer.addToConstraint(tl, tvars)
203+
}) &&
204+
// Integrate the additional constraints on type variables from `other`
205+
constraint.uninstVars.forall(tv =>
206+
val p = tv.origin
207+
val otherLos = other.lower(p)
208+
val otherHis = other.upper(p)
209+
val otherEntry = other.entry(p)
210+
( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) &&
211+
( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) &&
212+
((otherEntry eq constraint.entry(p)) || otherEntry.match
213+
case NoType =>
214+
true
215+
case tp: TypeBounds =>
216+
tp.contains(tv)
217+
case tp =>
218+
tv =:= tp
219+
)
220+
)
221+
assert(res || ctx.reporter.errorsReported, i"cannot merge $constraint with $other.")
222+
)(using comparingCtx)
223+
193224
for tl <- constraint.domainLambdas do
194225
if constraint.isRemovable(tl) then constraint = constraint.remove(tl)
226+
end mergeConstraintWith
195227

196228
/** Take ownership of `tvar`.
197229
*

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ object ProtoTypes {
376376
def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(using Context): List[Tree] =
377377
if state.typedArgs.size == args.length then state.typedArgs
378378
else
379+
val passedCtx = ctx
379380
val passedTyperState = ctx.typerState
380381
inContext(protoCtx.withUncommittedTyperState) {
381382
val protoTyperState = ctx.typerState
@@ -409,8 +410,7 @@ object ProtoTypes {
409410
tvar.instantiate(fromBelow = false)
410411
case _ =>
411412
}
412-
413-
passedTyperState.mergeConstraintWith(protoTyperState)
413+
passedTyperState.mergeConstraintWith(protoTyperState)(using passedCtx)
414414
end if
415415
args1
416416
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package dotty.tools
2+
package dotc.core
3+
4+
import vulpix.TestConfiguration
5+
6+
import dotty.tools.dotc.core.Contexts.{*, given}
7+
import dotty.tools.dotc.core.Decorators.{*, given}
8+
import dotty.tools.dotc.core.Symbols.*
9+
import dotty.tools.dotc.core.Types.*
10+
import dotty.tools.dotc.typer.ProtoTypes.constrained
11+
12+
import org.junit.Test
13+
14+
import dotty.tools.DottyTest
15+
16+
class ConstraintsTest:
17+
18+
@Test def mergeParamsTransitivity: Unit =
19+
inCompilerContext(TestConfiguration.basicClasspath,
20+
scalaSources = "trait A { def foo[S, T, R]: Any }") {
21+
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
22+
val List(s, t, r) = tp.paramRefs
23+
24+
val innerCtx = ctx.fresh.setExploreTyperState()
25+
inContext(innerCtx) {
26+
s <:< t
27+
}
28+
29+
t <:< r
30+
31+
ctx.typerState.mergeConstraintWith(innerCtx.typerState)
32+
assert(s frozen_<:< r,
33+
i"Merging constraints `?S <: ?T` and `?T <: ?R` should result in `?S <:< ?R`: ${ctx.typerState.constraint}")
34+
}
35+
end mergeParamsTransitivity
36+
37+
@Test def mergeBoundsTransitivity: Unit =
38+
inCompilerContext(TestConfiguration.basicClasspath,
39+
scalaSources = "trait A { def foo[S, T]: Any }") {
40+
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
41+
val List(s, t) = tp.paramRefs
42+
43+
val innerCtx = ctx.fresh.setExploreTyperState()
44+
inContext(innerCtx) {
45+
s <:< t
46+
}
47+
48+
defn.IntType <:< s
49+
50+
ctx.typerState.mergeConstraintWith(innerCtx.typerState)
51+
assert(defn.IntType frozen_<:< t,
52+
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
53+
}
54+
end mergeBoundsTransitivity

tests/pos/i12730.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
class ComponentSimple
2+
3+
class Props {
4+
def apply(props: Any): Any = ???
5+
}
6+
7+
class Foo[C] {
8+
def build: ComponentSimple = ???
9+
}
10+
11+
class Bar[E] {
12+
def render(r: E => Any): Unit = {}
13+
}
14+
15+
trait Conv[A, B] {
16+
def apply(a: A): B
17+
}
18+
19+
object Test {
20+
def toComponentCtor[F](c: ComponentSimple): Props = ???
21+
22+
def defaultToNoBackend[G, H](ev: G => Foo[H]): Conv[Foo[H], Bar[H]] = ???
23+
24+
def conforms[A]: A => A = ???
25+
26+
def problem = Main // crashes
27+
28+
def foo[H]: Foo[H] = ???
29+
30+
val NameChanger =
31+
foo
32+
.build
33+
34+
val Main =
35+
defaultToNoBackend(conforms).apply(foo)
36+
.render(_ => toComponentCtor(NameChanger)(13))
37+
}

0 commit comments

Comments
 (0)