Skip to content

Commit 817fe72

Browse files
authored
Merge pull request #3464 from dotty-staging/add-depfuns
Add dependent function types
2 parents 2070b25 + 28a29ea commit 817fe72

19 files changed

+424
-211
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
560560
}
561561
}
562562

563+
override def Inlined(tree: Tree)(call: Tree, bindings: List[MemberDef], expansion: Tree)(implicit ctx: Context): Inlined = {
564+
val tree1 = untpd.cpy.Inlined(tree)(call, bindings, expansion)
565+
tree match {
566+
case tree: Inlined if sameTypes(bindings, tree.bindings) && (expansion.tpe eq tree.expansion.tpe) =>
567+
tree1.withTypeUnchecked(tree.tpe)
568+
case _ => ta.assignType(tree1, bindings, expansion)
569+
}
570+
}
571+
563572
override def SeqLiteral(tree: Tree)(elems: List[Tree], elemtpt: Tree)(implicit ctx: Context): SeqLiteral = {
564573
val tree1 = untpd.cpy.SeqLiteral(tree)(elems, elemtpt)
565574
tree match {

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
8686
case class GenAlias(pat: Tree, expr: Tree) extends Tree
8787
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree]) extends TypTree
8888
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree) extends DefTree
89+
case class DependentTypeTree(tp: List[Symbol] => Type) extends Tree
8990

9091
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY) with WithoutTypeOrPos[Untyped] {
9192
override def isEmpty = true

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ class Definitions {
702702
val tsym = ft.typeSymbol
703703
if (isFunctionClass(tsym)) {
704704
val targs = ft.dealias.argInfos
705-
Some(targs.init, targs.last, tsym.name.isImplicitFunction)
705+
if (targs.isEmpty) None
706+
else Some(targs.init, targs.last, tsym.name.isImplicitFunction)
706707
}
707708
else None
708709
}
@@ -914,13 +915,19 @@ class Definitions {
914915
def isProductSubType(tp: Type)(implicit ctx: Context) =
915916
tp.derivesFrom(ProductType.symbol)
916917

917-
/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ImplicitFunctionN? */
918-
def isFunctionType(tp: Type)(implicit ctx: Context) = {
918+
/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ImplicitFunctionN
919+
* instance?
920+
*/
921+
def isNonDepFunctionType(tp: Type)(implicit ctx: Context) = {
919922
val arity = functionArity(tp)
920923
val sym = tp.dealias.typeSymbol
921924
arity >= 0 && isFunctionClass(sym) && tp.isRef(FunctionType(arity, sym.name.isImplicitFunction).typeSymbol)
922925
}
923926

927+
/** Is `tp` a representation of a (possibly depenent) function type or an alias of such? */
928+
def isFunctionType(tp: Type)(implicit ctx: Context) =
929+
isNonDepFunctionType(tp.dropDependentRefinement)
930+
924931
// Specialized type parameters defined for scala.Function{0,1,2}.
925932
private lazy val Function1SpecializedParams: collection.Set[Type] =
926933
Set(IntType, LongType, FloatType, DoubleType)

compiler/src/dotty/tools/dotc/core/Mode.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,6 @@ object Mode {
4848
/** Allow GADTFlexType labelled types to have their bounds adjusted */
4949
val GADTflexible = newMode(8, "GADTflexible")
5050

51-
/** Allow dependent functions. This is currently necessary for unpickling, because
52-
* some dependent functions are passed through from the front end(s?), even though they
53-
* are technically speaking illegal.
54-
*/
55-
val AllowDependentFunctions = newMode(9, "AllowDependentFunctions")
56-
5751
/** We are currently printing something: avoid to produce more logs about
5852
* the printing
5953
*/

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ trait Symbols { this: Context =>
130130
newClassSymbol(owner, name, flags, completer, privateWithin, coord, assocFile)
131131
}
132132

133+
def newRefinedClassSymbol = newCompleteClassSymbol(
134+
ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil)
135+
133136
/** Create a module symbol with associated module class
134137
* from its non-info fields and a function producing the info
135138
* of the module class (this info may be lazy).

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,12 @@ object Types {
10141014
case _ => this
10151015
}
10161016

1017+
/** Dealias, and if result is a dependent function type, drop the `apply` refinement. */
1018+
final def dropDependentRefinement(implicit ctx: Context): Type = dealias match {
1019+
case RefinedType(parent, nme.apply, _) => parent
1020+
case tp => tp
1021+
}
1022+
10171023
/** The type constructor of an applied type, otherwise the type itself */
10181024
final def typeConstructor(implicit ctx: Context): Type = this match {
10191025
case AppliedType(tycon, _) => tycon
@@ -1312,15 +1318,18 @@ object Types {
13121318
// ----- misc -----------------------------------------------------------
13131319

13141320
/** Turn type into a function type.
1315-
* @pre this is a non-dependent method type.
1321+
* @pre this is a method type without parameter dependencies.
13161322
* @param dropLast The number of trailing parameters that should be dropped
13171323
* when forming the function type.
13181324
*/
13191325
def toFunctionType(dropLast: Int = 0)(implicit ctx: Context): Type = this match {
1320-
case mt: MethodType if !mt.isDependent || ctx.mode.is(Mode.AllowDependentFunctions) =>
1326+
case mt: MethodType if !mt.isParamDependent =>
13211327
val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast
1322-
defn.FunctionOf(
1323-
formals1 mapConserve (_.underlyingIfRepeated(mt.isJavaMethod)), mt.resultType, mt.isImplicitMethod && !ctx.erasedTypes)
1328+
val funType = defn.FunctionOf(
1329+
formals1 mapConserve (_.underlyingIfRepeated(mt.isJavaMethod)),
1330+
mt.nonDependentResultApprox, mt.isImplicitMethod && !ctx.erasedTypes)
1331+
if (mt.isDependent) RefinedType(funType, nme.apply, mt)
1332+
else funType
13241333
}
13251334

13261335
/** The signature of this type. This is by default NotAMethod,
@@ -2581,7 +2590,7 @@ object Types {
25812590
def integrate(tparams: List[ParamInfo], tp: Type)(implicit ctx: Context): Type =
25822591
tparams match {
25832592
case LambdaParam(lam, _) :: _ => tp.subst(lam, this)
2584-
case tparams: List[Symbol @unchecked] => tp.subst(tparams, paramRefs)
2593+
case params: List[Symbol @unchecked] => tp.subst(params, paramRefs)
25852594
}
25862595

25872596
final def derivedLambdaType(paramNames: List[ThisName] = this.paramNames,
@@ -2688,7 +2697,7 @@ object Types {
26882697
* def f(x: C)(y: x.S) // dependencyStatus = TrueDeps
26892698
* def f(x: C)(y: x.T) // dependencyStatus = FalseDeps, i.e.
26902699
* // dependency can be eliminated by dealiasing.
2691-
*/
2700+
*/
26922701
private def dependencyStatus(implicit ctx: Context): DependencyStatus = {
26932702
if (myDependencyStatus != Unknown) myDependencyStatus
26942703
else {
@@ -2723,6 +2732,20 @@ object Types {
27232732
def isParamDependent(implicit ctx: Context): Boolean = paramDependencyStatus == TrueDeps
27242733

27252734
def newParamRef(n: Int) = new TermParamRef(this, n) {}
2735+
2736+
/** The least supertype of `resultType` that does not contain parameter dependencies */
2737+
def nonDependentResultApprox(implicit ctx: Context): Type =
2738+
if (isDependent) {
2739+
val dropDependencies = new ApproximatingTypeMap {
2740+
def apply(tp: Type) = tp match {
2741+
case tp @ TermParamRef(thisLambdaType, _) =>
2742+
range(tp.bottomType, atVariance(1)(apply(tp.underlying)))
2743+
case _ => mapOver(tp)
2744+
}
2745+
}
2746+
dropDependencies(resultType)
2747+
}
2748+
else resultType
27262749
}
27272750

27282751
abstract case class MethodType(paramNames: List[TermName])(
@@ -3197,8 +3220,10 @@ object Types {
31973220
case _ => false
31983221
}
31993222

3223+
protected def kindString: String
3224+
32003225
override def toString =
3201-
try s"ParamRef($paramName)"
3226+
try s"${kindString}ParamRef($paramName)"
32023227
catch {
32033228
case ex: IndexOutOfBoundsException => s"ParamRef(<bad index: $paramNum>)"
32043229
}
@@ -3207,8 +3232,9 @@ object Types {
32073232
/** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to
32083233
* refer to `TermParamRef(binder, paramNum)`.
32093234
*/
3210-
abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef {
3235+
abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef with SingletonType {
32113236
type BT = TermLambda
3237+
def kindString = "Term"
32123238
def copyBoundType(bt: BT) = bt.paramRefs(paramNum)
32133239
}
32143240

@@ -3217,6 +3243,7 @@ object Types {
32173243
*/
32183244
abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int) extends ParamRef {
32193245
type BT = TypeLambda
3246+
def kindString = "Type"
32203247
def copyBoundType(bt: BT) = bt.paramRefs(paramNum)
32213248

32223249
/** Looking only at the structure of `bound`, is one of the following true?
@@ -3731,7 +3758,7 @@ object Types {
37313758
// println(s"absMems: ${absMems map (_.show) mkString ", "}")
37323759
if (absMems.size == 1)
37333760
absMems.head.info match {
3734-
case mt: MethodType if !mt.isDependent => Some(absMems.head)
3761+
case mt: MethodType if !mt.isParamDependent => Some(absMems.head)
37353762
case _ => None
37363763
}
37373764
else if (tp isRef defn.PartialFunctionClass)

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
7474
/** The unpickled trees */
7575
def unpickle()(implicit ctx: Context): List[Tree] = {
7676
assert(roots != null, "unpickle without previous enterTopLevel")
77-
new TreeReader(reader).readTopLevel()(ctx.addMode(Mode.AllowDependentFunctions))
77+
new TreeReader(reader).readTopLevel()
7878
}
7979

8080
class Completer(owner: Symbol, reader: TastyReader) extends LazyType {
@@ -999,8 +999,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
999999
val argPats = until(end)(readTerm())
10001000
UnApply(fn, implicitArgs, argPats, patType)
10011001
case REFINEDtpt =>
1002-
val refineCls = ctx.newCompleteClassSymbol(
1003-
ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil)
1002+
val refineCls = ctx.newRefinedClassSymbol
10041003
typeAtAddr(start) = refineCls.typeRef
10051004
val parent = readTpt()
10061005
val refinements = readStats(refineCls, end)(localContext(refineCls))
@@ -1096,7 +1095,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
10961095
class LazyReader[T <: AnyRef](reader: TreeReader, op: TreeReader => Context => T) extends Trees.Lazy[T] {
10971096
def complete(implicit ctx: Context): T = {
10981097
pickling.println(i"starting to read at ${reader.reader.currentAddr}")
1099-
op(reader)(ctx.addMode(Mode.AllowDependentFunctions).withPhaseNoLater(ctx.picklerPhase))
1098+
op(reader)(ctx.withPhaseNoLater(ctx.picklerPhase))
11001099
}
11011100
}
11021101

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ object Parsers {
735735
* | InfixType
736736
* FunArgTypes ::= InfixType
737737
* | `(' [ FunArgType {`,' FunArgType } ] `)'
738+
* | '(' TypedFunParam {',' TypedFunParam } ')'
738739
*/
739740
def typ(): Tree = {
740741
val start = in.offset
@@ -745,6 +746,16 @@ object Parsers {
745746
val t = typ()
746747
if (isImplicit) new ImplicitFunction(params, t) else Function(params, t)
747748
}
749+
def funArgTypesRest(first: Tree, following: () => Tree) = {
750+
val buf = new ListBuffer[Tree] += first
751+
while (in.token == COMMA) {
752+
in.nextToken()
753+
buf += following()
754+
}
755+
buf.toList
756+
}
757+
var isValParamList = false
758+
748759
val t =
749760
if (in.token == LPAREN) {
750761
in.nextToken()
@@ -754,10 +765,19 @@ object Parsers {
754765
}
755766
else {
756767
openParens.change(LPAREN, 1)
757-
val ts = commaSeparated(funArgType)
768+
val paramStart = in.offset
769+
val ts = funArgType() match {
770+
case Ident(name) if name != tpnme.WILDCARD && in.token == COLON =>
771+
isValParamList = true
772+
funArgTypesRest(
773+
typedFunParam(paramStart, name.toTermName),
774+
() => typedFunParam(in.offset, ident()))
775+
case t =>
776+
funArgTypesRest(t, funArgType)
777+
}
758778
openParens.change(LPAREN, -1)
759779
accept(RPAREN)
760-
if (isImplicit || in.token == ARROW) functionRest(ts)
780+
if (isImplicit || isValParamList || in.token == ARROW) functionRest(ts)
761781
else {
762782
for (t <- ts)
763783
if (t.isInstanceOf[ByNameTypeTree])
@@ -790,6 +810,12 @@ object Parsers {
790810
}
791811
}
792812

813+
/** TypedFunParam ::= id ':' Type */
814+
def typedFunParam(start: Offset, name: TermName): Tree = atPos(start) {
815+
accept(COLON)
816+
makeParameter(name, typ(), Modifiers(Param))
817+
}
818+
793819
/** InfixType ::= RefinedType {id [nl] refinedType}
794820
*/
795821
def infixType(): Tree = infixTypeRest(refinedType())

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,14 @@ class PlainPrinter(_ctx: Context) extends Printer {
146146
toTextRef(tp) ~ ".type"
147147
case tp: TermRef if tp.denot.isOverloaded =>
148148
"<overloaded " ~ toTextRef(tp) ~ ">"
149-
case tp: SingletonType =>
150-
toTextLocal(tp.underlying) ~ "(" ~ toTextRef(tp) ~ ")"
151149
case tp: TypeRef =>
152150
toTextPrefix(tp.prefix) ~ selectionString(tp)
151+
case tp: TermParamRef =>
152+
ParamRefNameString(tp) ~ ".type"
153+
case tp: TypeParamRef =>
154+
ParamRefNameString(tp) ~ lambdaHash(tp.binder)
155+
case tp: SingletonType =>
156+
toTextLocal(tp.underlying) ~ "(" ~ toTextRef(tp) ~ ")"
153157
case AppliedType(tycon, args) =>
154158
(toTextLocal(tycon) ~ "[" ~ Text(args map argText, ", ") ~ "]").close
155159
case tp: RefinedType =>
@@ -180,26 +184,19 @@ class PlainPrinter(_ctx: Context) extends Printer {
180184
case NoPrefix =>
181185
"<noprefix>"
182186
case tp: MethodType =>
183-
def paramText(name: TermName, tp: Type) = toText(name) ~ ": " ~ toText(tp)
184187
changePrec(GlobalPrec) {
185-
(if (tp.isImplicitMethod) "(implicit " else "(") ~
186-
Text((tp.paramNames, tp.paramInfos).zipped map paramText, ", ") ~
188+
(if (tp.isImplicitMethod) "(implicit " else "(") ~ paramsText(tp) ~
187189
(if (tp.resultType.isInstanceOf[MethodType]) ")" else "): ") ~
188190
toText(tp.resultType)
189191
}
190192
case tp: ExprType =>
191193
changePrec(GlobalPrec) { "=> " ~ toText(tp.resultType) }
192194
case tp: TypeLambda =>
193-
def paramText(name: Name, bounds: TypeBounds): Text = name.unexpandedName.toString ~ toText(bounds)
194195
changePrec(GlobalPrec) {
195-
"[" ~ Text((tp.paramNames, tp.paramInfos).zipped.map(paramText), ", ") ~
196-
"]" ~ lambdaHash(tp) ~ (" => " provided !tp.resultType.isInstanceOf[MethodType]) ~
196+
"[" ~ paramsText(tp) ~ "]" ~ lambdaHash(tp) ~
197+
(" => " provided !tp.resultType.isInstanceOf[MethodType]) ~
197198
toTextGlobal(tp.resultType)
198199
}
199-
case tp: TypeParamRef =>
200-
ParamRefNameString(tp) ~ lambdaHash(tp.binder)
201-
case tp: TermParamRef =>
202-
ParamRefNameString(tp) ~ ".type"
203200
case AnnotatedType(tpe, annot) =>
204201
toTextLocal(tpe) ~ " " ~ toText(annot)
205202
case tp: TypeVar =>
@@ -221,6 +218,11 @@ class PlainPrinter(_ctx: Context) extends Printer {
221218
}
222219
}.close
223220

221+
protected def paramsText(tp: LambdaType): Text = {
222+
def paramText(name: Name, tp: Type) = toText(name) ~ toTextRHS(tp)
223+
Text((tp.paramNames, tp.paramInfos).zipped.map(paramText), ", ")
224+
}
225+
224226
protected def ParamRefNameString(name: Name): String = name.toString
225227

226228
protected def ParamRefNameString(param: ParamRef): String =

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
116116
override def toText(tp: Type): Text = controlled {
117117
def toTextTuple(args: List[Type]): Text =
118118
"(" ~ Text(args.map(argText), ", ") ~ ")"
119+
119120
def toTextFunction(args: List[Type], isImplicit: Boolean): Text =
120121
changePrec(GlobalPrec) {
121122
val argStr: Text =
@@ -126,6 +127,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
126127
("implicit " provided isImplicit) ~ argStr ~ " => " ~ argText(args.last)
127128
}
128129

130+
def toTextDependentFunction(appType: MethodType): Text = {
131+
("implicit " provided appType.isImplicitMethod) ~
132+
"(" ~ paramsText(appType) ~ ") => " ~ toText(appType.resultType)
133+
}
134+
129135
def isInfixType(tp: Type): Boolean = tp match {
130136
case AppliedType(tycon, args) =>
131137
args.length == 2 &&
@@ -158,6 +164,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
158164
if (isInfixType(tp)) return toTextInfixType(tycon, args)
159165
case EtaExpansion(tycon) =>
160166
return toText(tycon)
167+
case tp: RefinedType if defn.isFunctionType(tp) =>
168+
return toTextDependentFunction(tp.refinedInfo.asInstanceOf[MethodType])
161169
case tp: TypeRef =>
162170
if (tp.symbol.isAnonymousClass && !ctx.settings.uniqid.value)
163171
return toText(tp.info)

compiler/src/dotty/tools/dotc/typer/Dynamic.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ trait Dynamic { self: Typer with Applications =>
144144

145145
tree.tpe.widen match {
146146
case tpe: MethodType =>
147-
if (tpe.isDependent)
148-
fail(i"has a dependent method type")
147+
if (tpe.isParamDependent)
148+
fail(i"has a method type with inter-parameter dependencies")
149149
else if (tpe.paramNames.length > Definitions.MaxStructuralMethodArity)
150150
fail(i"""takes too many parameters.
151151
|Structural types only support methods taking up to ${Definitions.MaxStructuralMethodArity} arguments""")

0 commit comments

Comments
 (0)