Skip to content

Fix #854: Optimize matches on primitive constants as switches. #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 31, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/dotty/tools/dotc/transform/FullParameterization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import NameOps._
import ast._
import ast.Trees._

import scala.reflect.internal.util.Collections

/** Provides methods to produce fully parameterized versions of instance methods,
* where the `this` of the enclosing class is abstracted out in an extra leading
* `$this` parameter and type parameters of the class become additional type
Expand Down Expand Up @@ -86,9 +88,12 @@ trait FullParameterization {
* }
*
* If a self type is present, $this has this self type as its type.
*
* @param abstractOverClass if true, include the type parameters of the class in the method's list of type parameters.
* @param liftThisType if true, require created $this to be $this: (Foo[A] & Foo,this).
* This is needed if created member stays inside scope of Foo(as in tailrec)
*/
def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true)(implicit ctx: Context): Type = {
def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Type = {
val (mtparamCount, origResult) = info match {
case info @ PolyType(mtnames) => (mtnames.length, info.resultType)
case info: ExprType => (0, info.resultType)
Expand All @@ -100,7 +105,8 @@ trait FullParameterization {
/** The method result type */
def resultType(mapClassParams: Type => Type) = {
val thisParamType = mapClassParams(clazz.classInfo.selfType)
MethodType(nme.SELF :: Nil, thisParamType :: Nil)(mt =>
val firstArgType = if (liftThisType) thisParamType & clazz.thisType else thisParamType
MethodType(nme.SELF :: Nil, firstArgType :: Nil)(mt =>
mapClassParams(origResult).substThisUnlessStatic(clazz, MethodParam(mt, 0)))
}

Expand Down Expand Up @@ -217,12 +223,26 @@ trait FullParameterization {
* - the `this` of the enclosing class,
* - the value parameters of the original method `originalDef`.
*/
def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
ref(derived.termRef)
.appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef))
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))
.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
.withPos(originalDef.rhs.pos)
def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Tree = {
val fun =
ref(derived.termRef)
.appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef))
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))

(if (!liftThisType)
fun.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
else {
// this type could have changed on forwarding. Need to insert a cast.
val args = Collections.map2(originalDef.vparamss, fun.tpe.paramTypess)((vparams, paramTypes) =>
Collections.map2(vparams, paramTypes)((vparam, paramType) => {
assert(vparam.tpe <:< paramType.widen) // type should still conform to widened type
ref(vparam.symbol).ensureConforms(paramType)
})
)
fun.appliedToArgss(args)

}).withPos(originalDef.rhs.pos)
}
}

object FullParameterization {
Expand Down
135 changes: 133 additions & 2 deletions src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,139 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans
def optimizeCases(prevBinder: Symbol, cases: List[List[TreeMaker]], pt: Type): (List[List[TreeMaker]], List[Tree])
def analyzeCases(prevBinder: Symbol, cases: List[List[TreeMaker]], pt: Type, suppression: Suppression): Unit = {}

def emitSwitch(scrut: Tree, scrutSym: Symbol, cases: List[List[TreeMaker]], pt: Type, matchFailGenOverride: Option[Symbol => Tree], unchecked: Boolean): Option[Tree] =
None // todo
def emitSwitch(scrut: Tree, scrutSym: Symbol, cases: List[List[TreeMaker]], pt: Type, matchFailGenOverride: Option[Symbol => Tree], unchecked: Boolean): Option[Tree] = {
// TODO Deal with guards?

def isSwitchableType(tpe: Type): Boolean = {
(tpe isRef defn.IntClass) ||
(tpe isRef defn.ByteClass) ||
(tpe isRef defn.ShortClass) ||
(tpe isRef defn.CharClass)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

|| (ValueClasses.isDerivedValueClass(tpe.classSymbol) && 
     isSwitchableType(ValueClasses.underlyingOfValueClass(tpe.classSymbol)))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I agree it would be nice, this won't happen in practice, because you cannot express a constant value of a user-defined value class in the cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a bit of discussion: let's leave this possibility for future as it's not clear what is a constant for a value class.

}

object IntEqualityTestTreeMaker {
def unapply(treeMaker: EqualityTestTreeMaker): Option[Int] = treeMaker match {
case EqualityTestTreeMaker(`scrutSym`, _, Literal(const), _) =>
if (const.isIntRange) Some(const.intValue)
else None
case _ =>
None
}
}

def isSwitchCase(treeMakers: List[TreeMaker]): Boolean = treeMakers match {
// case 5 =>
case List(IntEqualityTestTreeMaker(_), _: BodyTreeMaker) =>
true

// case 5 | 6 =>
case List(AlternativesTreeMaker(`scrutSym`, alts, _), _: BodyTreeMaker) =>
alts.forall {
case List(IntEqualityTestTreeMaker(_)) => true
case _ => false
}

// case _ =>
case List(_: BodyTreeMaker) =>
true

/* case x @ pat =>
* This includes:
* case x =>
* case x @ 5 =>
* case x @ (5 | 6) =>
*/
case (_: SubstOnlyTreeMaker) :: rest =>
isSwitchCase(rest)

case _ =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1: Int) match {case x @ 1 => x}

seems not covered by those cases

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not, indeed. I have actually deliberately omitted it, at the moment. It looked useless, as even as a human it would be obvious that x should be 1. I would have implemented it anyway, but for the fact that the sequence of treeMakers was SubstOnly, EqualityTest, Body; and it seemed "dangerous" to move the SubstOnly after the EqualityTest.

I can reconsider this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. It now also support case x @ (5 | 6) =>, which can actually be useful.

false
}

/* (Nil, body) means that `body` is the default case
* It's a bit hacky but it simplifies manipulations.
*/
def extractSwitchCase(treeMakers: List[TreeMaker]): (List[Int], BodyTreeMaker) = treeMakers match {
// case 5 =>
case List(IntEqualityTestTreeMaker(intValue), body: BodyTreeMaker) =>
(List(intValue), body)

// case 5 | 6 =>
case List(AlternativesTreeMaker(_, alts, _), body: BodyTreeMaker) =>
val intValues = alts.map {
case List(IntEqualityTestTreeMaker(intValue)) => intValue
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line causes a bootstrapping test (I think) to fail with

[error] java.lang.AssertionError: assertion failed: error at /home/jenkins/workspace/dotty-master-validate-partest/tests/partest-generated/pos/transform/PatternMatcher.scala:340
[error] type mismatch:
[error]  found   : $this.EqualityTestTreeMaker
[error]  required: TreeMakers.this.EqualityTestTreeMaker
[error] tree = TypeApply(Select(Ident(p38),asInstanceOf),List(TypeTree[TypeRef(TermRef(NoPrefix,$this),EqualityTestTreeMaker)]))
[error]     at scala.Predef$.assert(Predef.scala:165)
[error]     at dotty.tools.dotc.transform.TreeChecker$Checker.adapt(TreeChecker.scala:333)
[error]     at dotty.tools.dotc.typer.ProtoTypes$FunProto.typedArg(ProtoTypes.scala:205)

It seems to be a Ycheck error. Apparently it cannot relate the synthetic $this thing (a capture?) to the enclosing TreeMakers.this.

Not sure what to make of this, atm.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK only two transformations introduce $this: ExtensionMethods and TailRec, and extractSwitchCases looks tail-recursive. So my guess is that this method makes pattern-matching and tail-recursiveness interact in an interesting way we haven't tested before (probably because IntEqualityTestTreeMaker is a local class). It'd be great if you could try to reduce this to a simple test case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll to minimize it, yes.

}
(intValues, body)

// case _ =>
case List(body: BodyTreeMaker) =>
(Nil, body)

// case x @ pat =>
case (_: SubstOnlyTreeMaker) :: rest =>
/* Rebindings have been propagated, so the eventual body in `rest`
* contains all the necessary information. The substitution can be
* dropped at this point.
*/
extractSwitchCase(rest)
}

def doOverlap(a: List[Int], b: List[Int]): Boolean =
a.exists(b.contains _)

def makeSwitch(valuesToCases: List[(List[Int], BodyTreeMaker)]): Tree = {
def genBody(body: BodyTreeMaker): Tree = {
val valDefs = body.rebindings.emitValDefs
if (valDefs.isEmpty) body.body
else Block(valDefs, body.body)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the right way to avoid useless Blocks over single expressions? I did not find any shortcut. Or do you not bother, in general, and just always create a Block in similar cases?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment we do not bother in most cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I'll remove the test. I's a bit ugly because most cases don't have any ValDef, but let's follow the conventions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to leave it as is.
What I meant is that currently different parts of dotty already create a lot of Blocks with empty stats and there's no assumption that stats.nonEmpty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh OK.

}

val intScrut =
if (pt isRef defn.IntClass) ref(scrutSym)
else Select(ref(scrutSym), nme.toInt)

val (normalCases, defaultCaseAndRest) = valuesToCases.span(_._1.nonEmpty)

val newCases = for {
(values, body) <- normalCases
} yield {
val literals = values.map(v => Literal(Constant(v)))
val pat =
if (literals.size == 1) literals.head
else Alternative(literals)
CaseDef(pat, EmptyTree, genBody(body))
}

val catchAllDef = {
if (defaultCaseAndRest.isEmpty) {
matchFailGenOverride.fold[Tree](
Throw(New(defn.MatchErrorType, List(ref(scrutSym)))))(
_(scrutSym))
} else {
/* After the default case, assuming the IR even allows anything,
* things are unreachable anyway and can be removed.
*/
genBody(defaultCaseAndRest.head._2)
}
}
val defaultCase = CaseDef(Underscore(defn.IntType), EmptyTree, catchAllDef)

Match(intScrut, newCases :+ defaultCase)
}

if (isSwitchableType(scrut.tpe.widenDealias) && cases.forall(isSwitchCase)) {
val valuesToCases = cases.map(extractSwitchCase)
val values = valuesToCases.map(_._1)
if (values.tails.exists { tail => tail.nonEmpty && tail.tail.exists(doOverlap(_, tail.head)) }) {
// TODO Deal with overlapping cases (mostly useless without guards)
None
} else {
Some(makeSwitch(valuesToCases))
}
} else {
None
}
}

// for catch (no need to customize match failure)
def emitTypeSwitch(bindersAndCases: List[(Symbol, List[TreeMaker])], pt: Type): Option[List[CaseDef]] =
Expand Down
58 changes: 38 additions & 20 deletions src/dotty/tools/dotc/transform/TailRec.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dotty.tools.dotc.transform

import dotty.tools.dotc.ast.Trees._
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.{TreeTypeMap, tpd}
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
Expand All @@ -10,13 +10,12 @@ import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Types._
import dotty.tools.dotc.core._
import dotty.tools.dotc.transform.TailRec._
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform}
import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo}

/**
* A Tail Rec Transformer
*
* @author Erik Stenman, Iulian Dragos,
* ported to dotty by Dmitry Petrashko
* ported and heavily modified for dotty by Dmitry Petrashko
* @version 1.1
*
* What it does:
Expand Down Expand Up @@ -77,7 +76,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
val name = c.freshName(labelPrefix)

c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass))
if (method.owner.isClass)
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false))
else c.newSymbol(method, name.toTermName, labelFlags, method.info)
}

override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
Expand All @@ -103,18 +104,33 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
// and second one will actually apply,
// now this speculatively transforms tree and throws away result in many cases
val rhsSemiTransformed = {
val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
val rhs = atGroupEnd(transformer.transform(dd.rhs)(_))
rewrote = transformer.rewrote
rhs
}

if (rewrote) {
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel)
Block(List(res), call)
} else {
if (tree.symbol.owner.isClass) {
val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true)
Block(List(labelDef), call)
} else { // inner method. Tail recursion does not change `this`
val labelDef = polyDefDef(label, trefs => vrefss => {
val origMeth = tree.symbol
val origTParams = tree.tparams.map(_.symbol)
val origVParams = tree.vparamss.flatten map (_.symbol)
new TreeTypeMap(
typeMap = identity(_)
.substDealias(origTParams, trefs)
.subst(origVParams, vrefss.flatten.map(_.tpe)),
oldOwners = origMeth :: Nil,
newOwners = label :: Nil
).transform(rhsSemiTransformed)
})
Block(List(labelDef), ref(label).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol)))))
}} else {
if (mandatory)
ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos)
dd.rhs
Expand All @@ -132,7 +148,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete

}

class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {

import dotty.tools.dotc.ast.tpd._

Expand Down Expand Up @@ -175,8 +191,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
case x => (x, x, accArgs, accT, x.symbol)
}

val (reciever, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
val recv = noTailTransform(reciever)
val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe}
val recv = noTailTransform(prefix)

val targs = typeArguments.map(noTailTransform)
val argumentss = arguments.map(noTailTransforms)
Expand Down Expand Up @@ -215,20 +232,21 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
} else targs

val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef),
List(receiver))
val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef)
val thisPassed = if(this.method.owner.isClass) method appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head)) else method

val res =
if (method.tpe.widen.isParameterless) method
else argumentss.foldLeft(method) {
(met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
}
if (thisPassed.tpe.widen.isParameterless) thisPassed
else argumentss.foldLeft(thisPassed) {
(met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
}
res
}

if (isRecursiveCall) {
if (ctx.tailPos) {
if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call")
else if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
else fail("it changes type of 'this' on a polymorphic recursive call")
}
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/tailcall/t6574.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ class Bad[X, Y](val v: Int) extends AnyVal {
println("tail")
}

@annotation.tailrec final def differentTypeArgs : Unit = {
{(); new Bad[String, Unit](0)}.differentTypeArgs
@annotation.tailrec final def differentTypeArgs : Unit = { // error
{(); new Bad[String, Unit](0)}.differentTypeArgs // error
}
}
26 changes: 26 additions & 0 deletions tests/pos/tailcall/i1089.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package hello

import scala.annotation.tailrec

class Enclosing {
class SomeData(val x: Int)

def localDef(): Unit = {
def foo(data: SomeData): Int = data.x

@tailrec
def test(i: Int, data: SomeData): Unit = {
if (i != 0) {
println(foo(data))
test(i - 1, data)
}
}

test(3, new SomeData(42))
}
}

object world extends App {
println("hello dotty!")
new Enclosing().localDef()
}