Skip to content

Commit d490d13

Browse files
authored
Help implement Metals' infer expected type feature (#21390)
2 parents afcb0ad + 43fc10c commit d490d13

File tree

15 files changed

+666
-61
lines changed

15 files changed

+666
-61
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+8-4
Original file line numberDiff line numberDiff line change
@@ -3268,9 +3268,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
32683268

32693269
/** The trace of comparison operations when performing `op` */
32703270
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean)(using Context): String =
3271-
val cmp = explainingTypeComparer(short)
3272-
inSubComparer(cmp)(op)
3273-
cmp.lastTrace(header)
3271+
explaining(cmp => { op(cmp); cmp.lastTrace(header) }, short)
3272+
3273+
def explaining[T](op: ExplainingTypeComparer => T, short: Boolean)(using Context): T =
3274+
inSubComparer(explainingTypeComparer(short))(op)
32743275

32753276
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
32763277
inSubComparer(matchReducer)(op)
@@ -3440,6 +3441,9 @@ object TypeComparer {
34403441
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
34413442
comparing(_.explained(op, header, short))
34423443

3444+
def explaining[T](op: ExplainingTypeComparer => T, short: Boolean = false)(using Context): T =
3445+
comparing(_.explaining(op, short))
3446+
34433447
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
34443448
comparing(_.reduceMatchWith(op))
34453449

@@ -3871,7 +3875,7 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa
38713875
override def recur(tp1: Type, tp2: Type): Boolean =
38723876
def moreInfo =
38733877
if Config.verboseExplainSubtype || ctx.settings.verbose.value
3874-
then s" ${tp1.getClass} ${tp2.getClass}"
3878+
then s" ${tp1.className} ${tp2.className}"
38753879
else ""
38763880
val approx = approxState
38773881
def approxStr = if short then "" else approx.show

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+15-4
Original file line numberDiff line numberDiff line change
@@ -691,11 +691,22 @@ object TypeOps:
691691
val hiBound = instantiate(bounds.hi, skolemizedArgTypes)
692692
val loBound = instantiate(bounds.lo, skolemizedArgTypes)
693693

694-
def check(using Context) = {
695-
if (!(lo <:< hiBound)) violations += ((arg, "upper", hiBound))
696-
if (!(loBound <:< hi)) violations += ((arg, "lower", loBound))
694+
def check(tp1: Type, tp2: Type, which: String, bound: Type)(using Context) = {
695+
val isSub = TypeComparer.explaining { cmp =>
696+
val isSub = cmp.isSubType(tp1, tp2)
697+
if !isSub then
698+
if !ctx.typerState.constraint.domainLambdas.isEmpty then
699+
typr.println(i"${ctx.typerState.constraint}")
700+
if !ctx.gadt.symbols.isEmpty then
701+
typr.println(i"${ctx.gadt}")
702+
typr.println(cmp.lastTrace(i"checkOverlapsBounds($lo, $hi, $arg, $bounds)($which)"))
703+
//trace.dumpStack()
704+
isSub
705+
}//(using ctx.fresh.setSetting(ctx.settings.verbose, true)) // uncomment to enable moreInfo in ExplainingTypeComparer recur
706+
if !isSub then violations += ((arg, which, bound))
697707
}
698-
check(using checkCtx)
708+
check(lo, hiBound, "upper", hiBound)(using checkCtx)
709+
check(loBound, hi, "lower", loBound)(using checkCtx)
699710
}
700711

701712
def loop(args: List[Tree], boundss: List[TypeBounds]): Unit = args match

compiler/src/dotty/tools/dotc/reporting/trace.scala

+12
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@ object trace extends TraceSyntax:
2727
object log extends TraceSyntax:
2828
inline def isEnabled: true = true
2929
protected val isForced = false
30+
31+
def dumpStack(limit: Int = -1): Unit = {
32+
val out = Console.out
33+
val exc = new Exception("Dump Stack")
34+
var stack = exc.getStackTrace
35+
.filter(e => !e.getClassName.startsWith("dotty.tools.dotc.reporting.TraceSyntax"))
36+
.filter(e => !e.getClassName.startsWith("dotty.tools.dotc.reporting.trace"))
37+
if limit >= 0 then
38+
stack = stack.take(limit)
39+
exc.setStackTrace(stack)
40+
exc.printStackTrace(out)
41+
}
3042
end trace
3143

3244
/** This module is carefully optimized to give zero overhead if Config.tracingEnabled

compiler/src/dotty/tools/dotc/typer/Applications.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ trait Applications extends Compatibility {
571571
fail(TypeMismatch(methType.resultType, resultType, None))
572572

573573
// match all arguments with corresponding formal parameters
574-
matchArgs(orderedArgs, methType.paramInfos, 0)
574+
if success then matchArgs(orderedArgs, methType.paramInfos, 0)
575575
case _ =>
576576
if (methType.isError) ok = false
577577
else fail(em"$methString does not take parameters")
@@ -666,7 +666,7 @@ trait Applications extends Compatibility {
666666
* @param n The position of the first parameter in formals in `methType`.
667667
*/
668668
def matchArgs(args: List[Arg], formals: List[Type], n: Int): Unit =
669-
if (success) formals match {
669+
formals match {
670670
case formal :: formals1 =>
671671

672672
def checkNoVarArg(arg: Arg) =
@@ -878,7 +878,9 @@ trait Applications extends Compatibility {
878878
init()
879879

880880
def addArg(arg: Tree, formal: Type): Unit =
881-
typedArgBuf += adapt(arg, formal.widenExpr)
881+
val typedArg = adapt(arg, formal.widenExpr)
882+
typedArgBuf += typedArg
883+
ok = ok & !typedArg.tpe.isError
882884

883885
def makeVarArg(n: Int, elemFormal: Type): Unit = {
884886
val args = typedArgBuf.takeRight(n).toList
@@ -943,7 +945,7 @@ trait Applications extends Compatibility {
943945
var typedArgs = typedArgBuf.toList
944946
def app0 = cpy.Apply(app)(normalizedFun, typedArgs) // needs to be a `def` because typedArgs can change later
945947
val app1 =
946-
if (!success || typedArgs.exists(_.tpe.isError)) app0.withType(UnspecifiedErrorType)
948+
if !success then app0.withType(UnspecifiedErrorType)
947949
else {
948950
if isJavaAnnotConstr(methRef.symbol) then
949951
// #19951 Make sure all arguments are NamedArgs for Java annotations

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+35-20
Original file line numberDiff line numberDiff line change
@@ -240,25 +240,12 @@ object Inferencing {
240240
&& {
241241
var fail = false
242242
var skip = false
243-
val direction = instDirection(tvar.origin)
244-
if minimizeSelected then
245-
if direction <= 0 && tvar.hasLowerBound then
246-
skip = instantiate(tvar, fromBelow = true)
247-
else if direction >= 0 && tvar.hasUpperBound then
248-
skip = instantiate(tvar, fromBelow = false)
249-
// else hold off instantiating unbounded unconstrained variable
250-
else if direction != 0 then
251-
skip = instantiate(tvar, fromBelow = direction < 0)
252-
else if variance >= 0 && tvar.hasLowerBound then
253-
skip = instantiate(tvar, fromBelow = true)
254-
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
255-
&& force.ifBottom == IfBottom.ok
256-
then // if variance == 0, prefer upper bound if one is given
257-
skip = instantiate(tvar, fromBelow = true)
258-
else if variance >= 0 && force.ifBottom == IfBottom.fail then
259-
fail = true
260-
else
261-
toMaximize = tvar :: toMaximize
243+
instDecision(tvar, variance, minimizeSelected, force.ifBottom) match
244+
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
245+
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
246+
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
247+
case Decision.Fail => fail = true
248+
case Decision.ToMax => toMaximize ::= tvar
262249
!fail && (skip || foldOver(x, tvar))
263250
}
264251
case tp => foldOver(x, tp)
@@ -452,9 +439,32 @@ object Inferencing {
452439
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
453440
val approxAbove =
454441
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
442+
//println(i"instDirection($param) = $approxAbove - $approxBelow original=[$original] constrained=[$constrained]")
455443
approxAbove - approxBelow
456444
}
457445

446+
/** The instantiation decision for given poly param computed from the constraint. */
447+
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
448+
private def instDecision(tvar: TypeVar, v: Int, minimizeSelected: Boolean, ifBottom: IfBottom)(using Context): Decision =
449+
import Decision.*
450+
val direction = instDirection(tvar.origin)
451+
val dec = if minimizeSelected then
452+
if direction <= 0 && tvar.hasLowerBound then Min
453+
else if direction >= 0 && tvar.hasUpperBound then Max
454+
else Skip
455+
else if direction != 0 then if direction < 0 then Min else Max
456+
else if tvar.hasLowerBound then if v >= 0 then Min else ToMax
457+
else ifBottom match
458+
// What's left are unconstrained tvars with at most a non-Any param upperbound:
459+
// * IfBottom.flip will always maximise to the param upperbound, for all variances
460+
// * IfBottom.fail will fail the IFD check, for covariant or invariant tvars, maximise contravariant tvars
461+
// * IfBottom.ok will minimise to Nothing covariant and unbounded invariant tvars, and max to Any the others
462+
case IfBottom.ok => if v > 0 || v == 0 && !tvar.hasUpperBound then Min else ToMax // prefer upper bound if one is given
463+
case IfBottom.fail => if v >= 0 then Fail else ToMax
464+
case ifBottom_flip => ToMax
465+
//println(i"instDecision($tvar, v=v, minimizedSelected=$minimizeSelected, $ifBottom) dir=$direction = $dec")
466+
dec
467+
458468
/** Following type aliases and stripping refinements and annotations, if one arrives at a
459469
* class type reference where the class has a companion module, a reference to
460470
* that companion module. Otherwise NoType
@@ -651,7 +661,7 @@ trait Inferencing { this: Typer =>
651661

652662
val ownedVars = state.ownedVars
653663
if (ownedVars ne locked) && !ownedVars.isEmpty then
654-
val qualifying = ownedVars -- locked
664+
val qualifying = (ownedVars -- locked).toList
655665
if (!qualifying.isEmpty) {
656666
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
657667
val resultAlreadyConstrained =
@@ -687,6 +697,10 @@ trait Inferencing { this: Typer =>
687697

688698
def constraint = state.constraint
689699

700+
trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying\n$constraint\n${ctx.gadt}") {
701+
//println(i"$constraint")
702+
//println(i"${ctx.gadt}")
703+
690704
/** Values of this type report type variables to instantiate with variance indication:
691705
* +1 variable appears covariantly, can be instantiated from lower bound
692706
* -1 variable appears contravariantly, can be instantiated from upper bound
@@ -804,6 +818,7 @@ trait Inferencing { this: Typer =>
804818
end doInstantiate
805819

806820
doInstantiate(filterByDeps(toInstantiate))
821+
}
807822
}
808823
end if
809824
tree

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ import config.Printers.typr
1818
import Inferencing.*
1919
import ErrorReporting.*
2020
import util.SourceFile
21+
import util.Spans.{NoSpan, Span}
2122
import TypeComparer.necessarySubType
23+
import reporting.*
2224

2325
import scala.annotation.internal.sharable
24-
import dotty.tools.dotc.util.Spans.{NoSpan, Span}
2526

2627
object ProtoTypes {
2728

@@ -83,6 +84,7 @@ object ProtoTypes {
8384
* fits the given expected result type.
8485
*/
8586
def constrainResult(mt: Type, pt: Type)(using Context): Boolean =
87+
trace(i"constrainResult($mt, $pt)", typr):
8688
val savedConstraint = ctx.typerState.constraint
8789
val res = pt.widenExpr match {
8890
case pt: FunProto =>

compiler/src/dotty/tools/dotc/util/Signatures.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ object Signatures {
651651
*
652652
* @param err The error message to inspect.
653653
* @param params The parameters that were given at the call site.
654-
* @param alreadyCurried Index of paramss we are currently in.
654+
* @param paramssIndex Index of paramss we are currently in.
655655
*
656656
* @return A pair composed of the index of the best alternative (0 if no alternatives
657657
* were found), and the list of alternatives.

compiler/test/dotty/tools/dotc/typer/InstantiateModel.scala

+23-26
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,16 @@ package typer
44

55
// Modelling the decision in IsFullyDefined
66
object InstantiateModel:
7-
enum LB { case NN; case LL; case L1 }; import LB.*
8-
enum UB { case AA; case UU; case U1 }; import UB.*
9-
enum Var { case V; case NotV }; import Var.*
10-
enum MSe { case M; case NotM }; import MSe.*
11-
enum Bot { case Fail; case Ok; case Flip }; import Bot.*
12-
enum Act { case Min; case Max; case ToMax; case Skip; case False }; import Act.*
7+
enum LB { case NN; case LL; case L1 }; import LB.*
8+
enum UB { case AA; case UU; case U1 }; import UB.*
9+
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }; import Decision.*
1310

1411
// NN/AA = Nothing/Any
1512
// LL/UU = the original bounds, on the type parameter
1613
// L1/U1 = the constrained bounds, on the type variable
17-
// V = variance >= 0 ("non-contravariant")
18-
// MSe = minimisedSelected
19-
// Bot = IfBottom
2014
// ToMax = delayed maximisation, via addition to toMaximize
2115
// Skip = minimisedSelected "hold off instantiating"
22-
// False = return false
16+
// Fail = IfBottom.fail's bail option
2317

2418
// there are 9 combinations:
2519
// # | LB | UB | d | // d = direction
@@ -34,24 +28,27 @@ object InstantiateModel:
3428
// 8 | NN | UU | 0 | T <: UU
3529
// 9 | NN | AA | 0 | T
3630

37-
def decide(lb: LB, ub: UB, v: Var, bot: Bot, m: MSe): Act = (lb, ub) match
31+
def instDecision(lb: LB, ub: UB, v: Int, ifBottom: IfBottom, min: Boolean) = (lb, ub) match
3832
case (L1, AA) => Min
3933
case (L1, UU) => Min
4034
case (LL, U1) => Max
4135
case (NN, U1) => Max
4236

43-
case (L1, U1) => if m==M || v==V then Min else ToMax
44-
case (LL, UU) => if m==M || v==V then Min else ToMax
45-
case (LL, AA) => if m==M || v==V then Min else ToMax
46-
47-
case (NN, UU) => bot match
48-
case _ if m==M => Max
49-
//case Ok if v==V => Min // removed, i14218 fix
50-
case Fail if v==V => False
51-
case _ => ToMax
52-
53-
case (NN, AA) => bot match
54-
case _ if m==M => Skip
55-
case Ok if v==V => Min
56-
case Fail if v==V => False
57-
case _ => ToMax
37+
case (L1, U1) => if min then Min else pickVar(v, Min, Min, ToMax)
38+
case (LL, UU) => if min then Min else pickVar(v, Min, Min, ToMax)
39+
case (LL, AA) => if min then Min else pickVar(v, Min, Min, ToMax)
40+
41+
case (NN, UU) => ifBottom match
42+
case _ if min => Max
43+
case IfBottom.ok => pickVar(v, Min, ToMax, ToMax)
44+
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
45+
case IfBottom.flip => ToMax
46+
47+
case (NN, AA) => ifBottom match
48+
case _ if min => Skip
49+
case IfBottom.ok => pickVar(v, Min, Min, ToMax)
50+
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
51+
case IfBottom.flip => ToMax
52+
53+
def pickVar[A](v: Int, cov: A, inv: A, con: A) =
54+
if v > 0 then cov else if v == 0 then inv else con

0 commit comments

Comments
 (0)