Skip to content

Commit 8d85cd7

Browse files
committed
Merge pull request #1091 from dotty-staging/fix1089
FullParametrization: allow to have $this of ThisType.
2 parents 80b1247 + 2fe8ad5 commit 8d85cd7

File tree

4 files changed

+94
-30
lines changed

4 files changed

+94
-30
lines changed

src/dotty/tools/dotc/transform/FullParameterization.scala

+28-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import NameOps._
1212
import ast._
1313
import ast.Trees._
1414

15+
import scala.reflect.internal.util.Collections
16+
1517
/** Provides methods to produce fully parameterized versions of instance methods,
1618
* where the `this` of the enclosing class is abstracted out in an extra leading
1719
* `$this` parameter and type parameters of the class become additional type
@@ -86,9 +88,12 @@ trait FullParameterization {
8688
* }
8789
*
8890
* If a self type is present, $this has this self type as its type.
91+
*
8992
* @param abstractOverClass if true, include the type parameters of the class in the method's list of type parameters.
93+
* @param liftThisType if true, require created $this to be $this: (Foo[A] & Foo,this).
94+
* This is needed if created member stays inside scope of Foo(as in tailrec)
9095
*/
91-
def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true)(implicit ctx: Context): Type = {
96+
def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Type = {
9297
val (mtparamCount, origResult) = info match {
9398
case info @ PolyType(mtnames) => (mtnames.length, info.resultType)
9499
case info: ExprType => (0, info.resultType)
@@ -100,7 +105,8 @@ trait FullParameterization {
100105
/** The method result type */
101106
def resultType(mapClassParams: Type => Type) = {
102107
val thisParamType = mapClassParams(clazz.classInfo.selfType)
103-
MethodType(nme.SELF :: Nil, thisParamType :: Nil)(mt =>
108+
val firstArgType = if (liftThisType) thisParamType & clazz.thisType else thisParamType
109+
MethodType(nme.SELF :: Nil, firstArgType :: Nil)(mt =>
104110
mapClassParams(origResult).substThisUnlessStatic(clazz, MethodParam(mt, 0)))
105111
}
106112

@@ -217,12 +223,26 @@ trait FullParameterization {
217223
* - the `this` of the enclosing class,
218224
* - the value parameters of the original method `originalDef`.
219225
*/
220-
def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
221-
ref(derived.termRef)
222-
.appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef))
223-
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))
224-
.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
225-
.withPos(originalDef.rhs.pos)
226+
def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Tree = {
227+
val fun =
228+
ref(derived.termRef)
229+
.appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef))
230+
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))
231+
232+
(if (!liftThisType)
233+
fun.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
234+
else {
235+
// this type could have changed on forwarding. Need to insert a cast.
236+
val args = Collections.map2(originalDef.vparamss, fun.tpe.paramTypess)((vparams, paramTypes) =>
237+
Collections.map2(vparams, paramTypes)((vparam, paramType) => {
238+
assert(vparam.tpe <:< paramType.widen) // type should still conform to widened type
239+
ref(vparam.symbol).ensureConforms(paramType)
240+
})
241+
)
242+
fun.appliedToArgss(args)
243+
244+
}).withPos(originalDef.rhs.pos)
245+
}
226246
}
227247

228248
object FullParameterization {

src/dotty/tools/dotc/transform/TailRec.scala

+38-20
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package dotty.tools.dotc.transform
22

33
import dotty.tools.dotc.ast.Trees._
4-
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.ast.{TreeTypeMap, tpd}
55
import dotty.tools.dotc.core.Contexts.Context
66
import dotty.tools.dotc.core.Decorators._
77
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
@@ -10,13 +10,12 @@ import dotty.tools.dotc.core.Symbols._
1010
import dotty.tools.dotc.core.Types._
1111
import dotty.tools.dotc.core._
1212
import dotty.tools.dotc.transform.TailRec._
13-
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform}
13+
import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo}
1414

1515
/**
1616
* A Tail Rec Transformer
17-
*
1817
* @author Erik Stenman, Iulian Dragos,
19-
* ported to dotty by Dmitry Petrashko
18+
* ported and heavily modified for dotty by Dmitry Petrashko
2019
* @version 1.1
2120
*
2221
* What it does:
@@ -77,7 +76,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
7776
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
7877
val name = c.freshName(labelPrefix)
7978

80-
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass))
79+
if (method.owner.isClass)
80+
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false))
81+
else c.newSymbol(method, name.toTermName, labelFlags, method.info)
8182
}
8283

8384
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
@@ -103,18 +104,33 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
103104
// and second one will actually apply,
104105
// now this speculatively transforms tree and throws away result in many cases
105106
val rhsSemiTransformed = {
106-
val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
107+
val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
107108
val rhs = atGroupEnd(transformer.transform(dd.rhs)(_))
108109
rewrote = transformer.rewrote
109110
rhs
110111
}
111112

112113
if (rewrote) {
113114
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
114-
val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
115-
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel)
116-
Block(List(res), call)
117-
} else {
115+
if (tree.symbol.owner.isClass) {
116+
val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
117+
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true)
118+
Block(List(labelDef), call)
119+
} else { // inner method. Tail recursion does not change `this`
120+
val labelDef = polyDefDef(label, trefs => vrefss => {
121+
val origMeth = tree.symbol
122+
val origTParams = tree.tparams.map(_.symbol)
123+
val origVParams = tree.vparamss.flatten map (_.symbol)
124+
new TreeTypeMap(
125+
typeMap = identity(_)
126+
.substDealias(origTParams, trefs)
127+
.subst(origVParams, vrefss.flatten.map(_.tpe)),
128+
oldOwners = origMeth :: Nil,
129+
newOwners = label :: Nil
130+
).transform(rhsSemiTransformed)
131+
})
132+
Block(List(labelDef), ref(label).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol)))))
133+
}} else {
118134
if (mandatory)
119135
ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos)
120136
dd.rhs
@@ -132,7 +148,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
132148

133149
}
134150

135-
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
151+
class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
136152

137153
import dotty.tools.dotc.ast.tpd._
138154

@@ -175,8 +191,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
175191
case x => (x, x, accArgs, accT, x.symbol)
176192
}
177193

178-
val (reciever, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
179-
val recv = noTailTransform(reciever)
194+
val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
195+
val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe}
196+
val recv = noTailTransform(prefix)
180197

181198
val targs = typeArguments.map(noTailTransform)
182199
val argumentss = arguments.map(noTailTransforms)
@@ -215,20 +232,21 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
215232
targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
216233
} else targs
217234

218-
val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef),
219-
List(receiver))
235+
val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef)
236+
val thisPassed = if(this.method.owner.isClass) method appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head)) else method
220237

221238
val res =
222-
if (method.tpe.widen.isParameterless) method
223-
else argumentss.foldLeft(method) {
224-
(met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
225-
}
239+
if (thisPassed.tpe.widen.isParameterless) thisPassed
240+
else argumentss.foldLeft(thisPassed) {
241+
(met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
242+
}
226243
res
227244
}
228245

229246
if (isRecursiveCall) {
230247
if (ctx.tailPos) {
231-
if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
248+
if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call")
249+
else if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
232250
else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
233251
else fail("it changes type of 'this' on a polymorphic recursive call")
234252
}

tests/neg/tailcall/t6574.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ class Bad[X, Y](val v: Int) extends AnyVal {
44
println("tail")
55
}
66

7-
@annotation.tailrec final def differentTypeArgs : Unit = {
8-
{(); new Bad[String, Unit](0)}.differentTypeArgs
7+
@annotation.tailrec final def differentTypeArgs : Unit = { // error
8+
{(); new Bad[String, Unit](0)}.differentTypeArgs // error
99
}
1010
}

tests/pos/tailcall/i1089.scala

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package hello
2+
3+
import scala.annotation.tailrec
4+
5+
class Enclosing {
6+
class SomeData(val x: Int)
7+
8+
def localDef(): Unit = {
9+
def foo(data: SomeData): Int = data.x
10+
11+
@tailrec
12+
def test(i: Int, data: SomeData): Unit = {
13+
if (i != 0) {
14+
println(foo(data))
15+
test(i - 1, data)
16+
}
17+
}
18+
19+
test(3, new SomeData(42))
20+
}
21+
}
22+
23+
object world extends App {
24+
println("hello dotty!")
25+
new Enclosing().localDef()
26+
}

0 commit comments

Comments
 (0)