Skip to content

Commit ac489db

Browse files
committed
Fix #4177: Generate optimised applyOrElse implementation for partial function literals
1 parent e77604d commit ac489db

File tree

5 files changed

+99
-26
lines changed

5 files changed

+99
-26
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
286286
coord = fns.map(_.pos).reduceLeft(_ union _))
287287
val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered
288288
def forwarder(fn: TermSymbol, name: TermName) = {
289-
val fwdMeth = fn.copy(cls, name, Synthetic | Method).entered.asTerm
290-
DefDef(fwdMeth, prefss => ref(fn).appliedToArgss(prefss))
289+
var flags = Synthetic | Method
290+
def isOverriden(denot: SingleDenotation) = fn.info.overrides(denot.info, matchLoosely = true)
291+
val isOverride = parents.exists(_.member(name).hasAltWith(isOverriden))
292+
if (isOverride) flags = flags | Override
293+
val fwdMeth = fn.copy(cls, name, flags).entered.asTerm
294+
polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypes(tprefs).appliedToArgss(prefss))
291295
}
292296
val forwarders = (fns, methNames).zipped.map(forwarder)
293297
val cdef = ClassDef(cls, DefDef(constr), forwarders)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,10 @@ class Definitions {
585585

586586
lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction")
587587
def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass
588+
589+
lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse)
590+
def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol
591+
588592
lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction")
589593
def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass
590594
lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL")

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import dotty.tools.dotc.util.Positions.Position
1313
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
1414
* These fall into five categories
1515
*
16-
* 1. Partial function closures, we need to generate a isDefinedAt method for these.
16+
* 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
1717
* 2. Closures implementing non-trait classes.
1818
* 3. Closures implementing classes that inherit from a class other than Object
1919
* (a lambda cannot not be a run-time subtype of such a class)
@@ -54,38 +54,70 @@ class ExpandSAMs extends MiniPhase {
5454
val Block(
5555
(applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil,
5656
Closure(_, _, tpt)) = tree
57-
val applyRhs: Tree = applyDef.rhs
57+
58+
def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = {
59+
assert(tree.selector.symbol == param.symbol)
60+
val selectorTpe = selector.tpe.widen
61+
val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe)
62+
val defaultCase =
63+
CaseDef(
64+
Bind(defaultSym, Underscore(selectorTpe)),
65+
EmptyTree,
66+
defaultValue)
67+
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
68+
cpy.Match(tree)(unchecked, cases :+ defaultCase)
69+
.subst(param.symbol :: Nil, selector.symbol :: Nil)
70+
// Needed because a partial function can be written as:
71+
// param => param match { case "foo" if foo(param) => param }
72+
// And we need to update all references to 'param'
73+
}
74+
75+
val applyRhs = applyDef.rhs
5876
val applyFn = applyDef.symbol.asTerm
5977

6078
val MethodTpe(paramNames, paramTypes, _) = applyFn.info
6179
val isDefinedAtFn = applyFn.copy(
6280
name = nme.isDefinedAt,
6381
flags = Synthetic | Method,
6482
info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm
65-
val tru = Literal(Constant(true))
66-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = applyRhs match {
67-
case Match(selector, cases) =>
68-
assert(selector.symbol == param.symbol)
69-
val paramRef = paramRefss.head.head
70-
// Again, the alternative
71-
// val List(List(paramRef)) = paramRefs
72-
// fails with a similar self instantiation error
73-
def translateCase(cdef: CaseDef): CaseDef =
74-
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
75-
val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen)
76-
val defaultCase =
77-
CaseDef(
78-
Bind(defaultSym, Underscore(selector.tpe.widen)),
79-
EmptyTree,
80-
Literal(Constant(false)))
81-
val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType)))
82-
cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
83-
case _ =>
84-
tru
83+
84+
val applyOrElseFn = applyFn.copy(
85+
name = nme.applyOrElse,
86+
flags = Synthetic | Method,
87+
info = tpt.tpe.memberInfo(defn.PartialFunction_applyOrElse)).asTerm
88+
89+
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
90+
val tru = Literal(Constant(true))
91+
applyRhs match {
92+
case tree @ Match(_, cases) =>
93+
def translateCase(cdef: CaseDef)=
94+
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
95+
val paramRef = paramRefss.head.head
96+
val defaultValue = Literal(Constant(false))
97+
translateMatch(tree, paramRef, cases.map(translateCase), defaultValue)
98+
case _ =>
99+
tru
100+
}
101+
}
102+
103+
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
104+
val List(paramRef, defaultRef) = paramRefss.head
105+
applyRhs match {
106+
case tree @ Match(_, cases) =>
107+
def translateCase(cdef: CaseDef) =
108+
cdef.changeOwner(applyFn, applyOrElseFn)
109+
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
110+
translateMatch(tree, paramRef, cases.map(translateCase), defaultValue)
111+
case _ =>
112+
ref(applyFn).appliedTo(paramRef)
113+
}
85114
}
115+
86116
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
87-
val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn), List(nme.apply, nme.isDefinedAt))
88-
cpy.Block(tree)(List(applyDef, isDefinedAtDef), anonCls)
117+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
118+
119+
val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn, applyOrElseFn), List(nme.apply, nme.isDefinedAt, nme.applyOrElse))
120+
cpy.Block(tree)(List(applyDef, isDefinedAtDef, applyOrElseDef), anonCls)
89121
}
90122

91123
private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match {

tests/pos/i4177.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
class Test {
2+
3+
object Foo { def unapply(x: Int) = if (x == 2) Some(x.toString) else None }
4+
5+
def test: Unit = {
6+
val a: PartialFunction[Int, String] = { case Foo(x) => x }
7+
val b: PartialFunction[Int, String] = { case x => x.toString }
8+
val c: PartialFunction[Int, String] = { x => x.toString }
9+
val d: PartialFunction[Int, String] = x => x.toString
10+
11+
val e: PartialFunction[String, String] = { case x @ "abc" => x }
12+
val f: PartialFunction[String, String] = x => x match { case "abc" => x }
13+
val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x }
14+
}
15+
}

tests/run/i4177.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
object Test {
2+
private[this] var count = 0
3+
4+
def test(x: Int) = { count += 1; true }
5+
6+
object Foo {
7+
def unapply(x: Int): Option[Int] = { count += 1; Some(x) }
8+
}
9+
10+
def main(args: Array[String]): Unit = {
11+
val res = List(1, 2).collect { case x if test(x) => x }
12+
assert(count == 2)
13+
14+
count = 0
15+
val res2 = List(1, 2).collect { case Foo(x) => x }
16+
assert(count == 2)
17+
}
18+
}

0 commit comments

Comments
 (0)