Skip to content

Commit 89d90e6

Browse files
committed
Trial: New ElimByName phase
1 parent f8d060d commit 89d90e6

File tree

6 files changed

+168
-1
lines changed

6 files changed

+168
-1
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class Compiler {
6969
new InlineVals, // Check right hand-sides of an `inline val`s
7070
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
7171
new ElimRepeated) :: // Rewrite vararg parameters and arguments
72+
List(new ElimByNameParams) ::
7273
List(new init.Checker) :: // Check initialization of objects
7374
List(new ProtectedAccessors, // Add accessors for protected members
7475
new ExtensionMethods, // Expand methods of value classes with extension methods

compiler/src/dotty/tools/dotc/core/Definitions.scala

+20
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,24 @@ class Definitions {
10811081
}
10821082
}
10831083

1084+
object ByNameFunction:
1085+
def apply(tp: Type)(using Context): Type =
1086+
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
1087+
def unapply(tp: Type)(using Context): Option[Type] = tp match
1088+
case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) =>
1089+
Some(arg)
1090+
case tp @ AnnotatedType(parent, _) =>
1091+
unapply(parent)
1092+
case _ =>
1093+
None
1094+
1095+
final def isByNameFunctionClass(sym: Symbol): Boolean =
1096+
sym eq ContextFunction0
1097+
1098+
def isByNameFunction(tp: Type)(using Context): Boolean = tp match
1099+
case ByNameFunction(_) => true
1100+
case _ => false
1101+
10841102
final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
10851103
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass
10861104

@@ -1294,10 +1312,12 @@ class Definitions {
12941312
).symbol.asClass
12951313

12961314
@tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply)
1315+
@tu lazy val ContextFunction0_apply: Symbol = ContextFunction0.requiredMethod(nme.apply)
12971316

12981317
@tu lazy val Function0: Symbol = FunctionClass(0)
12991318
@tu lazy val Function1: Symbol = FunctionClass(1)
13001319
@tu lazy val Function2: Symbol = FunctionClass(2)
1320+
@tu lazy val ContextFunction0: Symbol = FunctionClass(0, isContextual = true)
13011321

13021322
def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): TypeRef =
13031323
FunctionClass(n, isContextual && !ctx.erasedTypes, isErased).typeRef

compiler/src/dotty/tools/dotc/core/Phases.scala

+4
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ object Phases {
207207
private var myRefChecksPhase: Phase = _
208208
private var myPatmatPhase: Phase = _
209209
private var myElimRepeatedPhase: Phase = _
210+
private var myElimByNamePhase: Phase = _
210211
private var myExtensionMethodsPhase: Phase = _
211212
private var myExplicitOuterPhase: Phase = _
212213
private var myGettersPhase: Phase = _
@@ -229,6 +230,7 @@ object Phases {
229230
final def refchecksPhase: Phase = myRefChecksPhase
230231
final def patmatPhase: Phase = myPatmatPhase
231232
final def elimRepeatedPhase: Phase = myElimRepeatedPhase
233+
final def elimByNamePhase: Phase = myElimByNamePhase
232234
final def extensionMethodsPhase: Phase = myExtensionMethodsPhase
233235
final def explicitOuterPhase: Phase = myExplicitOuterPhase
234236
final def gettersPhase: Phase = myGettersPhase
@@ -253,6 +255,7 @@ object Phases {
253255
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
254256
myRefChecksPhase = phaseOfClass(classOf[RefChecks])
255257
myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated])
258+
myElimByNamePhase = phaseOfClass(classOf[ElimByNameParams])
256259
myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods])
257260
myErasurePhase = phaseOfClass(classOf[Erasure])
258261
myElimErasedValueTypePhase = phaseOfClass(classOf[ElimErasedValueType])
@@ -427,6 +430,7 @@ object Phases {
427430
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
428431
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
429432
def elimRepeatedPhase(using Context): Phase = ctx.base.elimRepeatedPhase
433+
def elimByNamePhase(using Context): Phase = ctx.base.elimByNamePhase
430434
def extensionMethodsPhase(using Context): Phase = ctx.base.extensionMethodsPhase
431435
def explicitOuterPhase(using Context): Phase = ctx.base.explicitOuterPhase
432436
def gettersPhase(using Context): Phase = ctx.base.gettersPhase

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ object StdNames {
184184
// ----- Type names -----------------------------------------
185185

186186
final val BYNAME_PARAM_CLASS: N = "<byname>"
187+
final val BYNAME_PARAM_FUN: N = "<function0-byname>"
187188
final val EQUALS_PATTERN: N = "<equals>"
188189
final val LOCAL_CHILD: N = "<local child>"
189190
final val REPEATED_PARAM_CLASS: N = "<repeated>"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
846846
case _ => tp2.isAnyRef
847847
}
848848
compareJavaArray
849-
case tp1: ExprType if ctx.phase.id > gettersPhase.id =>
849+
case tp1: ExprType if ctx.phaseId > gettersPhase.id =>
850850
// getters might have converted T to => T, need to compensate.
851851
recur(tp1.widenExpr, tp2)
852852
case _ =>
@@ -1510,6 +1510,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
15101510
case _ => arg1
15111511
}
15121512
arg2.contains(arg1norm)
1513+
case ExprType(arg2res)
1514+
if ctx.phaseId > ctx.base.elimByNamePhase.id && !ctx.erasedTypes
1515+
&& defn.isByNameFunction(arg1) =>
1516+
// ElimByName maps `=> T` to `()? => T`, but only in method parameters. It leaves
1517+
// embedded `=> T` alone. This clause needs to compensate for that.
1518+
isSubArg(arg1.argInfos.head, arg2res)
15131519
case _ =>
15141520
arg1 match {
15151521
case arg1: TypeBounds =>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package dotty.tools
2+
package dotc
3+
package transform
4+
5+
import core._
6+
import Contexts._
7+
import Symbols._
8+
import Types._
9+
import Flags._
10+
import SymDenotations.*
11+
import DenotTransformers.InfoTransformer
12+
import NameKinds.SuperArgName
13+
import core.StdNames.nme
14+
import MegaPhase.*
15+
import Decorators.*
16+
import reporting.trace
17+
18+
/** This phase translates arguments to call-by-name parameters, using the rules
19+
*
20+
* x ==> x if x is a => parameter
21+
* e.apply() ==> <cbn-arg>(e) if e is pure
22+
* e ==> <cbn-arg>(() => e) for all other arguments
23+
*
24+
* where
25+
*
26+
* <cbn-arg>: [T](() => T): T
27+
*
28+
* is a synthetic method defined in Definitions. Erasure will later strip the <cbn-arg> wrappers.
29+
*/
30+
class ElimByNameParams extends MiniPhase, InfoTransformer:
31+
thisPhase =>
32+
33+
import ast.tpd._
34+
35+
override def phaseName: String = ElimByNameParams.name
36+
37+
override def runsAfterGroupsOf: Set[String] = Set(ExpandSAMs.name, ElimRepeated.name)
38+
// - ExpanSAMs applied to partial functions creates methods that need
39+
// to be fully defined before converting. Test case is pos/i9391.scala.
40+
// - ByNameLambda needs to run in a group after ElimRepeated since ElimRepeated
41+
// works on simple arguments but not converted closures, and it sees the arguments
42+
// after transformations by subsequent miniphases in the same group.
43+
44+
override def changesParents: Boolean = true
45+
// Expr types in parent type arguments are changed to function types.
46+
47+
/** If denotation had an ExprType before, it now gets a function type */
48+
private def exprBecomesFunction(symd: SymDenotation)(using Context): Boolean =
49+
symd.is(Param) || symd.is(ParamAccessor, butNot = Method)
50+
51+
def transformInfo(tp: Type, sym: Symbol)(using Context): Type = tp match {
52+
case ExprType(rt) if exprBecomesFunction(sym) =>
53+
defn.ByNameFunction(rt)
54+
case tp: MethodType =>
55+
def exprToFun(tp: Type) = tp match
56+
case ExprType(rt) => defn.ByNameFunction(rt)
57+
case tp => tp
58+
tp.derivedLambdaType(
59+
paramInfos = tp.paramInfos.mapConserve(exprToFun),
60+
resType = transformInfo(tp.resType, sym))
61+
case tp: PolyType =>
62+
tp.derivedLambdaType(resType = transformInfo(tp.resType, sym))
63+
case _ => tp
64+
}
65+
66+
override def infoMayChange(sym: Symbol)(using Context): Boolean =
67+
sym.is(Method) || exprBecomesFunction(sym)
68+
69+
def byNameClosure(arg: Tree, argType: Type)(using Context): Tree =
70+
val meth = newAnonFun(ctx.owner, MethodType(Nil, argType), coord = arg.span)
71+
Closure(meth,
72+
_ => arg.changeOwnerAfter(ctx.owner, meth, thisPhase),
73+
targetType = defn.ByNameFunction(argType)
74+
).withSpan(arg.span)
75+
76+
private def isByNameRef(tree: Tree)(using Context): Boolean =
77+
defn.isByNameFunction(tree.tpe.widen)
78+
79+
/** Map `tree` to `tree.apply()` is `tree` is of type `() ?=> T` */
80+
private def applyIfFunction(tree: Tree)(using Context) =
81+
if isByNameRef(tree) then
82+
val tree0 = transformFollowing(tree)
83+
atPhase(next) { tree0.select(defn.ContextFunction0_apply).appliedToNone }
84+
else tree
85+
86+
override def transformIdent(tree: Ident)(using Context): Tree =
87+
applyIfFunction(tree)
88+
89+
override def transformSelect(tree: Select)(using Context): Tree =
90+
applyIfFunction(tree)
91+
92+
override def transformTypeApply(tree: TypeApply)(using Context): Tree = tree match {
93+
case TypeApply(Select(_, nme.asInstanceOf_), arg :: Nil) =>
94+
// tree might be of form e.asInstanceOf[x.type] where x becomes a function.
95+
// See pos/t296.scala
96+
applyIfFunction(tree)
97+
case _ => tree
98+
}
99+
100+
override def transformApply(tree: Apply)(using Context): Tree =
101+
trace(s"transforming ${tree.show} at phase ${ctx.phase}", show = true) {
102+
103+
def transformArg(arg: Tree, formal: Type): Tree = formal match
104+
case defn.ByNameFunction(formalResult) =>
105+
def stripTyped(t: Tree): Tree = t match
106+
case Typed(expr, _) => stripTyped(expr)
107+
case _ => t
108+
stripTyped(arg) match
109+
case Apply(Select(qual, nme.apply), Nil)
110+
if isByNameRef(qual) && (isPureExpr(qual) || qual.symbol.isAllOf(InlineParam)) =>
111+
qual
112+
case _ =>
113+
if isByNameRef(arg) || arg.symbol.name.is(SuperArgName)
114+
then arg
115+
else
116+
var argType = arg.tpe.widenIfUnstable
117+
if argType.isBottomType then argType = formalResult
118+
byNameClosure(arg, argType)
119+
case _ =>
120+
arg
121+
122+
val mt @ MethodType(_) = tree.fun.tpe.widen
123+
val args1 = tree.args.zipWithConserve(mt.paramInfos)(transformArg)
124+
cpy.Apply(tree)(tree.fun, args1)
125+
}
126+
127+
override def transformValDef(tree: ValDef)(using Context): Tree =
128+
atPhase(next) {
129+
if exprBecomesFunction(tree.symbol) then
130+
cpy.ValDef(tree)(tpt = tree.tpt.withType(tree.symbol.info))
131+
else tree
132+
}
133+
134+
object ElimByNameParams:
135+
val name: String = "elimByNameParams"

0 commit comments

Comments
 (0)