Skip to content

Commit 28a29ea

Browse files
committed
Improve type inference for dependent function types
Given a dependently typed function value like this one: def f: (x: C) => D => x.T => E we did not propagate information about the subsequent types `D` and `x.T` to the result type of the closure with parameter `(x: C)`. Doing so is a bit tricky because of the dependency. But it's necessary to infer the types of subsequent parameters. Test case: eff-dependent.scala
1 parent b6642e6 commit 28a29ea

File tree

4 files changed

+75
-27
lines changed

4 files changed

+75
-27
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
8686
case class GenAlias(pat: Tree, expr: Tree) extends Tree
8787
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree]) extends TypTree
8888
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree) extends DefTree
89+
case class DependentTypeTree(tp: List[Symbol] => Type) extends Tree
8990

9091
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY) with WithoutTypeOrPos[Untyped] {
9192
override def isEmpty = true

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,8 @@ class Namer { typer: Typer =>
10961096
WildcardType
10971097
case TypeTree() =>
10981098
inferredType
1099+
case DependentTypeTree(tpFun) =>
1100+
tpFun(paramss.head)
10991101
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
11001102
val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe
11011103
mdef match {

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -683,18 +683,37 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
683683
assignType(cpy.If(tree)(cond1, thenp2, elsep2), thenp2, elsep2)
684684
}
685685

686-
private def decomposeProtoFunction(pt: Type, defaultArity: Int)(implicit ctx: Context): (List[Type], Type) = pt match {
687-
case _ if defn.isNonDepFunctionType(pt) =>
688-
// if expected parameter type(s) are wildcards, approximate from below.
689-
// if expected result type is a wildcard, approximate from above.
690-
// this can type the greatest set of admissible closures.
691-
val funType = pt.dealias
692-
(funType.argTypesLo.init, funType.argTypesHi.last)
693-
case SAMType(meth) =>
694-
val mt @ MethodTpe(_, formals, restpe) = meth.info
695-
(formals, if (mt.isDependent) WildcardType else restpe)
696-
case _ =>
697-
(List.tabulate(defaultArity)(alwaysWildcardType), WildcardType)
686+
/** Decompose function prototype into a list of parameter prototypes and a result prototype
687+
* tree, using WildcardTypes where a type is not known.
688+
* For the result type we do this even if the expected type is not fully
689+
* defined, which is a bit of a hack. But it's needed to make the following work
690+
* (see typers.scala and printers/PlainPrinter.scala for examples).
691+
*
692+
* def double(x: Char): String = s"$x$x"
693+
* "abc" flatMap double
694+
*/
695+
private def decomposeProtoFunction(pt: Type, defaultArity: Int)(implicit ctx: Context): (List[Type], untpd.Tree) = {
696+
def typeTree(tp: Type) = tp match {
697+
case _: WildcardType => untpd.TypeTree()
698+
case _ => untpd.TypeTree(tp)
699+
}
700+
pt match {
701+
case _ if defn.isNonDepFunctionType(pt) =>
702+
// if expected parameter type(s) are wildcards, approximate from below.
703+
// if expected result type is a wildcard, approximate from above.
704+
// this can type the greatest set of admissible closures.
705+
val funType = pt.dealias
706+
(funType.argTypesLo.init, typeTree(funType.argTypesHi.last))
707+
case SAMType(meth) =>
708+
val mt @ MethodTpe(_, formals, restpe) = meth.info
709+
(formals,
710+
if (mt.isDependent)
711+
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
712+
else
713+
typeTree(restpe))
714+
case _ =>
715+
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
716+
}
698717
}
699718

700719
def typedFunction(tree: untpd.Function, pt: Type)(implicit ctx: Context) = track("typedFunction") {
@@ -756,7 +775,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
756775
case _ =>
757776
}
758777

759-
val (protoFormals, protoResult) = decomposeProtoFunction(pt, params.length)
778+
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
760779

761780
def refersTo(arg: untpd.Tree, param: untpd.ValDef): Boolean = arg match {
762781
case Ident(name) => name == param.name
@@ -865,19 +884,6 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
865884
else cpy.ValDef(param)(
866885
tpt = untpd.TypeTree(
867886
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
868-
869-
// Define result type of closure as the expected type, thereby pushing
870-
// down any implicit searches. We do this even if the expected type is not fully
871-
// defined, which is a bit of a hack. But it's needed to make the following work
872-
// (see typers.scala and printers/PlainPrinter.scala for examples).
873-
//
874-
// def double(x: Char): String = s"$x$x"
875-
// "abc" flatMap double
876-
//
877-
val resultTpt = protoResult match {
878-
case WildcardType(_) => untpd.TypeTree()
879-
case _ => untpd.TypeTree(protoResult)
880-
}
881887
val inlineable = pt.hasAnnotation(defn.InlineParamAnnot)
882888
desugar.makeClosure(inferredParams, fnBody, resultTpt, inlineable)
883889
}
@@ -1700,7 +1706,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
17001706
case tree: untpd.PackageDef => typedPackageDef(tree)
17011707
case tree: untpd.Annotated => typedAnnotated(tree, pt)
17021708
case tree: untpd.TypedSplice => typedTypedSplice(tree)
1703-
case tree: untpd.UnApply => typedUnApply(tree, pt)
1709+
case tree: untpd.UnApply => typedUnApply(tree, pt)
1710+
case tree: untpd.DependentTypeTree => typed(untpd.TypeTree().withPos(tree.pos), pt)
17041711
case tree @ untpd.PostfixOp(qual, Ident(nme.WILDCARD)) => typedAsFunction(tree, pt)
17051712
case untpd.EmptyTree => tpd.EmptyTree
17061713
case _ => typedUnadapted(desugar(tree), pt)

tests/run/eff-dependent.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
object Test extends App {
2+
3+
trait Effect
4+
5+
// Type X => Y
6+
abstract class Fun[-X, +Y] {
7+
type Eff <: Effect
8+
def apply(x: X): implicit Eff => Y
9+
}
10+
11+
class CanThrow extends Effect
12+
class CanIO extends Effect
13+
14+
val i2s = new Fun[Int, String] { type Eff = CanThrow; def apply(x: Int) = x.toString }
15+
val s2i = new Fun[String, Int] { type Eff = CanIO; def apply(x: String) = x.length }
16+
17+
implicit val ct: CanThrow = new CanThrow
18+
implicit val ci: CanIO = new CanIO
19+
20+
// def map(f: A => B)(xs: List[A]): List[B]
21+
def map[A, B](f: Fun[A, B])(xs: List[A]): implicit f.Eff => List[B] =
22+
xs.map(f.apply)
23+
24+
// def mapFn[A, B]: (A => B) -> List[A] -> List[B]
25+
def mapFn[A, B]: (f: Fun[A, B]) => List[A] => implicit f.Eff => List[B] =
26+
f => xs => map(f)(xs)
27+
28+
// def compose(f: A => B)(g: B => C)(x: A): C
29+
def compose[A, B, C](f: Fun[A, B])(g: Fun[B, C])(x: A): implicit f.Eff => implicit g.Eff => C = g(f(x))
30+
31+
// def composeFn: (A => B) -> (B => C) -> A -> C
32+
def composeFn[A, B, C]: (f: Fun[A, B]) => (g: Fun[B, C]) => A => implicit f.Eff => implicit g.Eff => C =
33+
f => g => x => compose(f)(g)(x)
34+
35+
assert(mapFn(i2s)(List(1, 2, 3)).mkString == "123")
36+
assert(composeFn(i2s)(s2i)(22) == 2)
37+
38+
}

0 commit comments

Comments
 (0)