diff --git a/src/dotty/tools/dotc/transform/FullParameterization.scala b/src/dotty/tools/dotc/transform/FullParameterization.scala index e9057e885c47..be64df384063 100644 --- a/src/dotty/tools/dotc/transform/FullParameterization.scala +++ b/src/dotty/tools/dotc/transform/FullParameterization.scala @@ -12,6 +12,8 @@ import NameOps._ import ast._ import ast.Trees._ +import scala.reflect.internal.util.Collections + /** Provides methods to produce fully parameterized versions of instance methods, * where the `this` of the enclosing class is abstracted out in an extra leading * `$this` parameter and type parameters of the class become additional type @@ -86,9 +88,12 @@ trait FullParameterization { * } * * If a self type is present, $this has this self type as its type. + * * @param abstractOverClass if true, include the type parameters of the class in the method's list of type parameters. + * @param liftThisType if true, require created $this to be $this: (Foo[A] & Foo,this). + * This is needed if created member stays inside scope of Foo(as in tailrec) */ - def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true)(implicit ctx: Context): Type = { + def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Type = { val (mtparamCount, origResult) = info match { case info @ PolyType(mtnames) => (mtnames.length, info.resultType) case info: ExprType => (0, info.resultType) @@ -100,7 +105,8 @@ trait FullParameterization { /** The method result type */ def resultType(mapClassParams: Type => Type) = { val thisParamType = mapClassParams(clazz.classInfo.selfType) - MethodType(nme.SELF :: Nil, thisParamType :: Nil)(mt => + val firstArgType = if (liftThisType) thisParamType & clazz.thisType else thisParamType + MethodType(nme.SELF :: Nil, firstArgType :: Nil)(mt => mapClassParams(origResult).substThisUnlessStatic(clazz, MethodParam(mt, 0))) } @@ -217,12 +223,26 @@ trait FullParameterization { * - the `this` of the enclosing class, * - the value parameters of the original method `originalDef`. */ - def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree = - ref(derived.termRef) - .appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef)) - .appliedTo(This(originalDef.symbol.enclosingClass.asClass)) - .appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol))) - .withPos(originalDef.rhs.pos) + def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Tree = { + val fun = + ref(derived.termRef) + .appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef)) + .appliedTo(This(originalDef.symbol.enclosingClass.asClass)) + + (if (!liftThisType) + fun.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol))) + else { + // this type could have changed on forwarding. Need to insert a cast. + val args = Collections.map2(originalDef.vparamss, fun.tpe.paramTypess)((vparams, paramTypes) => + Collections.map2(vparams, paramTypes)((vparam, paramType) => { + assert(vparam.tpe <:< paramType.widen) // type should still conform to widened type + ref(vparam.symbol).ensureConforms(paramType) + }) + ) + fun.appliedToArgss(args) + + }).withPos(originalDef.rhs.pos) + } } object FullParameterization { diff --git a/src/dotty/tools/dotc/transform/PatternMatcher.scala b/src/dotty/tools/dotc/transform/PatternMatcher.scala index b4e32fa66524..a7f654780053 100644 --- a/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -303,8 +303,139 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans def optimizeCases(prevBinder: Symbol, cases: List[List[TreeMaker]], pt: Type): (List[List[TreeMaker]], List[Tree]) def analyzeCases(prevBinder: Symbol, cases: List[List[TreeMaker]], pt: Type, suppression: Suppression): Unit = {} - def emitSwitch(scrut: Tree, scrutSym: Symbol, cases: List[List[TreeMaker]], pt: Type, matchFailGenOverride: Option[Symbol => Tree], unchecked: Boolean): Option[Tree] = - None // todo + def emitSwitch(scrut: Tree, scrutSym: Symbol, cases: List[List[TreeMaker]], pt: Type, matchFailGenOverride: Option[Symbol => Tree], unchecked: Boolean): Option[Tree] = { + // TODO Deal with guards? + + def isSwitchableType(tpe: Type): Boolean = { + (tpe isRef defn.IntClass) || + (tpe isRef defn.ByteClass) || + (tpe isRef defn.ShortClass) || + (tpe isRef defn.CharClass) + } + + object IntEqualityTestTreeMaker { + def unapply(treeMaker: EqualityTestTreeMaker): Option[Int] = treeMaker match { + case EqualityTestTreeMaker(`scrutSym`, _, Literal(const), _) => + if (const.isIntRange) Some(const.intValue) + else None + case _ => + None + } + } + + def isSwitchCase(treeMakers: List[TreeMaker]): Boolean = treeMakers match { + // case 5 => + case List(IntEqualityTestTreeMaker(_), _: BodyTreeMaker) => + true + + // case 5 | 6 => + case List(AlternativesTreeMaker(`scrutSym`, alts, _), _: BodyTreeMaker) => + alts.forall { + case List(IntEqualityTestTreeMaker(_)) => true + case _ => false + } + + // case _ => + case List(_: BodyTreeMaker) => + true + + /* case x @ pat => + * This includes: + * case x => + * case x @ 5 => + * case x @ (5 | 6) => + */ + case (_: SubstOnlyTreeMaker) :: rest => + isSwitchCase(rest) + + case _ => + false + } + + /* (Nil, body) means that `body` is the default case + * It's a bit hacky but it simplifies manipulations. + */ + def extractSwitchCase(treeMakers: List[TreeMaker]): (List[Int], BodyTreeMaker) = treeMakers match { + // case 5 => + case List(IntEqualityTestTreeMaker(intValue), body: BodyTreeMaker) => + (List(intValue), body) + + // case 5 | 6 => + case List(AlternativesTreeMaker(_, alts, _), body: BodyTreeMaker) => + val intValues = alts.map { + case List(IntEqualityTestTreeMaker(intValue)) => intValue + } + (intValues, body) + + // case _ => + case List(body: BodyTreeMaker) => + (Nil, body) + + // case x @ pat => + case (_: SubstOnlyTreeMaker) :: rest => + /* Rebindings have been propagated, so the eventual body in `rest` + * contains all the necessary information. The substitution can be + * dropped at this point. + */ + extractSwitchCase(rest) + } + + def doOverlap(a: List[Int], b: List[Int]): Boolean = + a.exists(b.contains _) + + def makeSwitch(valuesToCases: List[(List[Int], BodyTreeMaker)]): Tree = { + def genBody(body: BodyTreeMaker): Tree = { + val valDefs = body.rebindings.emitValDefs + if (valDefs.isEmpty) body.body + else Block(valDefs, body.body) + } + + val intScrut = + if (pt isRef defn.IntClass) ref(scrutSym) + else Select(ref(scrutSym), nme.toInt) + + val (normalCases, defaultCaseAndRest) = valuesToCases.span(_._1.nonEmpty) + + val newCases = for { + (values, body) <- normalCases + } yield { + val literals = values.map(v => Literal(Constant(v))) + val pat = + if (literals.size == 1) literals.head + else Alternative(literals) + CaseDef(pat, EmptyTree, genBody(body)) + } + + val catchAllDef = { + if (defaultCaseAndRest.isEmpty) { + matchFailGenOverride.fold[Tree]( + Throw(New(defn.MatchErrorType, List(ref(scrutSym)))))( + _(scrutSym)) + } else { + /* After the default case, assuming the IR even allows anything, + * things are unreachable anyway and can be removed. + */ + genBody(defaultCaseAndRest.head._2) + } + } + val defaultCase = CaseDef(Underscore(defn.IntType), EmptyTree, catchAllDef) + + Match(intScrut, newCases :+ defaultCase) + } + + if (isSwitchableType(scrut.tpe.widenDealias) && cases.forall(isSwitchCase)) { + val valuesToCases = cases.map(extractSwitchCase) + val values = valuesToCases.map(_._1) + if (values.tails.exists { tail => tail.nonEmpty && tail.tail.exists(doOverlap(_, tail.head)) }) { + // TODO Deal with overlapping cases (mostly useless without guards) + None + } else { + Some(makeSwitch(valuesToCases)) + } + } else { + None + } + } // for catch (no need to customize match failure) def emitTypeSwitch(bindersAndCases: List[(Symbol, List[TreeMaker])], pt: Type): Option[List[CaseDef]] = diff --git a/src/dotty/tools/dotc/transform/TailRec.scala b/src/dotty/tools/dotc/transform/TailRec.scala index 58fe7a6c909c..23686b522be3 100644 --- a/src/dotty/tools/dotc/transform/TailRec.scala +++ b/src/dotty/tools/dotc/transform/TailRec.scala @@ -1,7 +1,7 @@ package dotty.tools.dotc.transform import dotty.tools.dotc.ast.Trees._ -import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.{TreeTypeMap, tpd} import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Decorators._ import dotty.tools.dotc.core.DenotTransformers.DenotTransformer @@ -10,13 +10,12 @@ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.core._ import dotty.tools.dotc.transform.TailRec._ -import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform} +import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo} /** * A Tail Rec Transformer - * * @author Erik Stenman, Iulian Dragos, - * ported to dotty by Dmitry Petrashko + * ported and heavily modified for dotty by Dmitry Petrashko * @version 1.1 * * What it does: @@ -77,7 +76,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = { val name = c.freshName(labelPrefix) - c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass)) + if (method.owner.isClass) + c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false)) + else c.newSymbol(method, name.toTermName, labelFlags, method.info) } override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { @@ -103,7 +104,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete // and second one will actually apply, // now this speculatively transforms tree and throws away result in many cases val rhsSemiTransformed = { - val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel) + val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel) val rhs = atGroupEnd(transformer.transform(dd.rhs)(_)) rewrote = transformer.rewrote rhs @@ -111,10 +112,25 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete if (rewrote) { val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed) - val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel) - val call = forwarder(label, dd, abstractOverClass = defIsTopLevel) - Block(List(res), call) - } else { + if (tree.symbol.owner.isClass) { + val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel) + val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true) + Block(List(labelDef), call) + } else { // inner method. Tail recursion does not change `this` + val labelDef = polyDefDef(label, trefs => vrefss => { + val origMeth = tree.symbol + val origTParams = tree.tparams.map(_.symbol) + val origVParams = tree.vparamss.flatten map (_.symbol) + new TreeTypeMap( + typeMap = identity(_) + .substDealias(origTParams, trefs) + .subst(origVParams, vrefss.flatten.map(_.tpe)), + oldOwners = origMeth :: Nil, + newOwners = label :: Nil + ).transform(rhsSemiTransformed) + }) + Block(List(labelDef), ref(label).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))) + }} else { if (mandatory) ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos) dd.rhs @@ -132,7 +148,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete } - class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap { + class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap { import dotty.tools.dotc.ast.tpd._ @@ -175,8 +191,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete case x => (x, x, accArgs, accT, x.symbol) } - val (reciever, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) - val recv = noTailTransform(reciever) + val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) + val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe} + val recv = noTailTransform(prefix) val targs = typeArguments.map(noTailTransform) val argumentss = arguments.map(noTailTransforms) @@ -215,20 +232,21 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete targs ::: classTypeArgs.map(x => ref(x.typeSymbol)) } else targs - val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef), - List(receiver)) + val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef) + val thisPassed = if(this.method.owner.isClass) method appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head)) else method val res = - if (method.tpe.widen.isParameterless) method - else argumentss.foldLeft(method) { - (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. - } + if (thisPassed.tpe.widen.isParameterless) thisPassed + else argumentss.foldLeft(thisPassed) { + (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. + } res } if (isRecursiveCall) { if (ctx.tailPos) { - if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) + if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call") + else if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv) else fail("it changes type of 'this' on a polymorphic recursive call") } diff --git a/tests/neg/tailcall/t6574.scala b/tests/neg/tailcall/t6574.scala index 7030b3b4ad05..d9ba2882ddab 100644 --- a/tests/neg/tailcall/t6574.scala +++ b/tests/neg/tailcall/t6574.scala @@ -4,7 +4,7 @@ class Bad[X, Y](val v: Int) extends AnyVal { println("tail") } - @annotation.tailrec final def differentTypeArgs : Unit = { - {(); new Bad[String, Unit](0)}.differentTypeArgs + @annotation.tailrec final def differentTypeArgs : Unit = { // error + {(); new Bad[String, Unit](0)}.differentTypeArgs // error } } diff --git a/tests/pos/tailcall/i1089.scala b/tests/pos/tailcall/i1089.scala new file mode 100644 index 000000000000..8eb69cb9bb75 --- /dev/null +++ b/tests/pos/tailcall/i1089.scala @@ -0,0 +1,26 @@ +package hello + +import scala.annotation.tailrec + +class Enclosing { + class SomeData(val x: Int) + + def localDef(): Unit = { + def foo(data: SomeData): Int = data.x + + @tailrec + def test(i: Int, data: SomeData): Unit = { + if (i != 0) { + println(foo(data)) + test(i - 1, data) + } + } + + test(3, new SomeData(42)) + } +} + +object world extends App { + println("hello dotty!") + new Enclosing().localDef() +}