Skip to content

Commit fe5be59

Browse files
committed
Avoid inference getting stuck when the expected type contains a union/intersection
When we type a method call, we infer constraints based on its expected type before typing its arguments. This way, we can type these arguments with a precise expected type. This works fine as long as the constraints we infer based on the expected type are _necessary_ constraints, but in general type inference can go further and infer _sufficient_ constraints, meaning that we might get stuck with a set of constraints which does not allow the method arguments to be typed at all. Since 8067b95 we work around the problem by simply not propagating any constraint when the expected type is a union, but this solution is incomplete: - It only handles unions at the top-level, but the same problem can happen with unions in any covariant position (method b of or-inf.scala) as well as intersections in contravariant positions (and-inf.scala, i8378.scala) - Even when a union appear at the top-level, there might be constraints we can propagate, for example if only one branch can possibly match (method c of or-inf.scala) Thankfully, we already have a solution that works for all these problems: `TypeComparer#either` is capable of inferring only necessary constraints. So far, this was only done when inferring GADT bounds to preserve soundness, this commit extends this to use the same logic when constraining a method based on its expected type (as determined by the ConstrainResult mode). Additionally, `ConstraintHandling#addConstraint` needs to also be taught to only keep necessary constraints under this mode. Fixes #8378 which I previously thought was unfixable :).
1 parent 0bd0fcf commit fe5be59

File tree

13 files changed

+151
-58
lines changed

13 files changed

+151
-58
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,10 @@ trait ConstraintHandling[AbstractContext] {
484484
* recording an isLess relationship instead (even though this is not implied
485485
* by the bound).
486486
*
487-
* Narrowing a constraint is better than widening it, because narrowing leads
488-
* to incompleteness (which we face anyway, see for instance eitherIsSubType)
489-
* but widening leads to unsoundness.
487+
* Normally, narrowing a constraint is better than widening it, because
488+
* narrowing leads to incompleteness (which we face anyway, see for
489+
* instance `TypeComparer#either`) but widening leads to unsoundness,
490+
* but note the special handling in `ConstrainResult` mode below.
490491
*
491492
* A test case that demonstrates the problem is i864.scala.
492493
* Turn Config.checkConstraintsSeparated on to get an accurate diagnostic
@@ -544,10 +545,23 @@ trait ConstraintHandling[AbstractContext] {
544545
case bound: TypeParamRef if constraint contains bound =>
545546
addParamBound(bound)
546547
case _ =>
548+
val savedConstraint = constraint
547549
val pbound = prune(bound)
548-
pbound.exists
549-
&& kindCompatible(param, pbound)
550-
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
550+
val constraintsNarrowed = constraint ne savedConstraint
551+
552+
val res =
553+
pbound.exists
554+
&& kindCompatible(param, pbound)
555+
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
556+
// If we're in `ConstrainResult` mode, we don't want to commit to a
557+
// set of constraints that would later prevent us from typechecking
558+
// arguments, so if `pruneParams` had to narrow the constraints, we
559+
// simply do not record any new constraint.
560+
// Unlike in `TypeComparer#either`, the same reasoning does not apply
561+
// to GADT mode because this code is never run on GADT constraints.
562+
if ctx.mode.is(Mode.ConstrainResult) && constraintsNarrowed then
563+
constraint = savedConstraint
564+
res
551565
}
552566
finally addConstraintInvocations -= 1
553567
}

compiler/src/dotty/tools/dotc/core/Mode.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ object Mode {
6060
*/
6161
val Printing: Mode = newMode(10, "Printing")
6262

63+
/** We are constraining a method based on its expected type. */
64+
val ConstrainResult: Mode = newMode(11, "ConstrainResult")
65+
6366
/** We are currently in a `viewExists` check. In that case, ambiguous
6467
* implicits checks are disabled and we succeed with the first implicit
6568
* found.

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,14 +1364,26 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
13641364

13651365
/** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints.
13661366
*
1367-
* If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints.
1368-
* If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all.
1367+
* If we're inferring GADT bounds or constraining a method based on its
1368+
* expected type, we infer only the _necessary_ constraints, this means we
1369+
* keep the smaller constraint if any, or no constraint at all. This is
1370+
* necessary for GADT bounds inference to be sound. When constraining a
1371+
* method, this avoid committing of constraints that would later prevent us
1372+
* from typechecking method arguments, see or-inf.scala and and-inf.scala for
1373+
* examples.
13691374
*
1375+
* Otherwise, we infer _sufficient_ constraints: we try to keep the smaller of
1376+
* the two constraints, but if never is smaller than the other, we just pick
1377+
* the first one.
1378+
*
1379+
* @see [[necessaryEither]] for the GADT / result type case
13701380
* @see [[sufficientEither]] for the normal case
1371-
* @see [[necessaryEither]] for the GADTFlexible case
13721381
*/
13731382
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
1374-
if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
1383+
if ctx.mode.is(Mode.GadtConstraintInference) || ctx.mode.is(Mode.ConstrainResult) then
1384+
necessaryEither(op1, op2)
1385+
else
1386+
sufficientEither(op1, op2)
13751387

13761388
/** Returns true iff the result of evaluating either `op1` or `op2` is true,
13771389
* trying at the same time to keep the constraint as wide as possible.
@@ -1438,8 +1450,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14381450
* T1 & T2 <:< T3
14391451
* T1 <:< T2 | T3
14401452
*
1441-
* Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT
1442-
* constraints that necessarily follow from the subtyping relationship. For instance, if we have
1453+
* Unlike [[sufficientEither]], this method is used in GADTConstraintInference mode, when we are attempting
1454+
* to infer GADT constraints that necessarily follow from the subtyping relationship. For instance, if we have
14431455
*
14441456
* enum Expr[T] {
14451457
* case IntExpr(i: Int) extends Expr[Int]
@@ -1466,48 +1478,49 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14661478
*
14671479
* then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive.
14681480
*
1481+
* This method is also used in ConstrainResult mode
1482+
* to avoid inference getting stuck due to lack of backtracking,
1483+
* see or-inf.scala and and-inf.scala for examples.
1484+
*
14691485
* Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both
14701486
* subtyping relationships.
14711487
*/
1472-
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = {
1488+
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean =
14731489
val preConstraint = constraint
1474-
14751490
val preGadt = ctx.gadt.fresh
1476-
// if GADTflexible mode is on, we expect to always have a ProperGadtConstraint
1477-
val pre = preGadt.asInstanceOf[ProperGadtConstraint]
1478-
if (op1) {
1479-
val leftConstraint = constraint
1480-
val leftGadt = ctx.gadt.fresh
1491+
1492+
def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
1493+
subsumes(left, right, preConstraint) && preGadt.match
1494+
case preGadt: ProperGadtConstraint =>
1495+
preGadt.subsumes(leftGadt, rightGadt, preGadt)
1496+
case _ =>
1497+
true
1498+
1499+
if op1 then
1500+
val op1Constraint = constraint
1501+
val op1Gadt = ctx.gadt.fresh
14811502
constraint = preConstraint
14821503
ctx.gadt.restore(preGadt)
1483-
if (op2)
1484-
if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) {
1485-
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt")
1486-
constr.println(i"CUT - prefer $constraint over $leftConstraint")
1487-
true
1488-
}
1489-
else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) {
1490-
gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}")
1491-
constr.println(i"CUT - prefer $leftConstraint over $constraint")
1492-
constraint = leftConstraint
1493-
ctx.gadt.restore(leftGadt)
1494-
true
1495-
}
1496-
else {
1504+
if op2 then
1505+
if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then
1506+
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt")
1507+
constr.println(i"CUT - prefer $constraint over $op1Constraint")
1508+
else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then
1509+
gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}")
1510+
constr.println(i"CUT - prefer $op1Constraint over $constraint")
1511+
constraint = op1Constraint
1512+
ctx.gadt.restore(op1Gadt)
1513+
else
14971514
gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt")
14981515
constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint")
14991516
constraint = preConstraint
15001517
ctx.gadt.restore(preGadt)
1501-
true
1502-
}
1503-
else {
1504-
constraint = leftConstraint
1505-
ctx.gadt.restore(leftGadt)
1506-
true
1507-
}
1508-
}
1518+
else
1519+
constraint = op1Constraint
1520+
ctx.gadt.restore(op1Gadt)
1521+
true
15091522
else op2
1510-
}
1523+
end necessaryEither
15111524

15121525
/** Does type `tp1` have a member with name `name` whose normalized type is a subtype of
15131526
* the normalized type of the refinement `tp2`?

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,14 @@ object ProtoTypes {
5959
else ctx.test(testCompat)
6060
}
6161

62-
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = pt.dealias match {
63-
case _: OrType => true
64-
// Don't constrain results with union types, since comparison with a union
65-
// type on the right might commit too early into one side.
66-
case pt => pt.isRef(defn.UnitClass)
67-
}
62+
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean =
63+
pt.dealias.isRef(defn.UnitClass)
6864

6965
/** Check that the result type of the current method
7066
* fits the given expected result type.
7167
*/
72-
def constrainResult(mt: Type, pt: Type)(implicit ctx: Context): Boolean = {
68+
def constrainResult(mt: Type, pt: Type)(implicit parentCtx: Context): Boolean = {
69+
given ctx as Context = parentCtx.addMode(Mode.ConstrainResult)
7370
val savedConstraint = ctx.typerState.constraint
7471
val res = pt.widenExpr match {
7572
case pt: FunProto =>

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CompilationTests extends ParallelTesting {
137137
compileFile("tests/neg-custom-args/i3882.scala", allowDeepSubtypes),
138138
compileFile("tests/neg-custom-args/i4372.scala", allowDeepSubtypes),
139139
compileFile("tests/neg-custom-args/i1754.scala", allowDeepSubtypes),
140+
compileFile("tests/neg-custom-args/interop-polytypes.scala", allowDeepSubtypes.and("-Yexplicit-nulls")),
140141
compileFile("tests/neg-custom-args/conditionalWarnings.scala", allowDeepSubtypes.and("-deprecation").and("-Xfatal-warnings")),
141142
compileFilesInDir("tests/neg-custom-args/isInstanceOf", allowDeepSubtypes and "-Xfatal-warnings"),
142143
compileFile("tests/neg-custom-args/i3627.scala", allowDeepSubtypes),

tests/neg/i6565.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ def (o: Lifted[O]) flatMap [O, U] (f: O => Lifted[U]): Lifted[U] = ???
99
val error: Err = Err()
1010

1111
lazy val ok: Lifted[String] = { // ok despite map returning a union
12-
point("a").map(_ => if true then "foo" else error) // error
12+
point("a").map(_ => if true then "foo" else error) // ok
1313
}
1414

1515
lazy val bad: Lifted[String] = { // found Lifted[Object]
1616
point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // error
17-
}
17+
}

tests/neg/union.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ object O {
1717

1818
val x: A = f(new A { }, new A)
1919

20-
val y1: A | B = f(new A { }, new B) // error
20+
val y1: A | B = f(new A { }, new B) // ok
2121
val y2: A | B = f[A | B](new A { }, new B) // ok
2222

2323
val z = if (???) new A{} else new B

tests/pos/and-inf.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class A
2+
class B
3+
4+
class Inv[T]
5+
class Contra[-T]
6+
7+
class Test {
8+
def foo[T, S](x: T, y: S): Contra[Inv[T] & Inv[S]] = ???
9+
val a: A = new A
10+
val b: B = new B
11+
12+
val x: Contra[Inv[A] & Inv[B]] = foo(a, b)
13+
}

tests/pos/i7829.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
class X
2+
class Y
3+
4+
object Test {
5+
type Id[T] = T
6+
7+
val a: 1 = identity(1)
8+
val b: Id[1] = identity(1)
9+
10+
val c: X | Y = identity(if (true) new X else new Y)
11+
val d: Id[X | Y] = identity(if (true) new X else new Y)
12+
13+
def impUnion: Unit = {
14+
class Base
15+
class A extends Base
16+
class B extends Base
17+
class Inv[T]
18+
19+
implicit def invBase: Inv[Base] = new Inv[Base]
20+
21+
def getInv[T](x: T)(implicit inv: Inv[T]): Int = 1
22+
23+
val a: Int = getInv(if (true) new A else new B)
24+
// If we keep unions when doing the implicit search, this would give us: "no implicit argument of type Inv[X | Y]"
25+
val b: Int | Any = getInv(if (true) new A else new B)
26+
}
27+
}

tests/pos/i8378.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
trait Has[A]
2+
3+
trait A
4+
trait B
5+
trait C
6+
7+
trait ZLayer[-RIn, +E, +ROut]
8+
9+
object ZLayer {
10+
def fromServices[A0, A1, B](f: (A0, A1) => B): ZLayer[Has[A0] with Has[A1], Nothing, Has[B]] =
11+
???
12+
}
13+
14+
val live: ZLayer[Has[A] & Has[B], Nothing, Has[C]] =
15+
ZLayer.fromServices { (a: A, b: B) =>
16+
new C {}
17+
}

tests/pos/or-inf.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
3+
def a(lis: Set[Int] | Set[String]) = {}
4+
a(Set(1))
5+
a(Set(""))
6+
7+
def b(lis: List[Set[Int] | Set[String]]) = {}
8+
b(List(Set(1)))
9+
b(List(Set("")))
10+
11+
def c(x: Set[Any] | Array[Any]) = {}
12+
c(Set(1))
13+
c(Array(1))
14+
}

tests/pos/orinf.scala

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)