Skip to content

Commit aa5ed33

Browse files
authored
Merge pull request #12670 from dotty-staging/fix-12661
Always generate a partial function from a lambda
2 parents 92c75ab + 3fad3d3 commit aa5ed33

File tree

3 files changed

+90
-81
lines changed

3 files changed

+90
-81
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+66-69
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
package dotty.tools.dotc
1+
package dotty.tools
2+
package dotc
23
package transform
34

45
import core._
@@ -7,6 +8,7 @@ import MegaPhase._
78
import SymUtils._
89
import NullOpsDecorator._
910
import ast.Trees._
11+
import ast.untpd
1012
import reporting._
1113
import dotty.tools.dotc.util.Spans.Span
1214

@@ -103,78 +105,73 @@ class ExpandSAMs extends MiniPhase:
103105
* ```
104106
*/
105107
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = {
106-
/** An extractor for match, either contained in a block or standalone. */
107-
object PartialFunctionRHS {
108-
def unapply(tree: Tree): Option[Match] = tree match {
109-
case Block(Nil, expr) => unapply(expr)
110-
case m: Match => Some(m)
111-
case _ => None
112-
}
113-
}
114-
115108
val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree
116-
anon.rhs match {
117-
case PartialFunctionRHS(pf) =>
118-
val anonSym = anon.symbol
119-
val anonTpe = anon.tpe.widen
120-
val parents = List(
121-
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
122-
defn.SerializableType)
123-
val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span)
124-
125-
def overrideSym(sym: Symbol) = sym.copy(
126-
owner = pfSym,
127-
flags = Synthetic | Method | Final | Override,
128-
info = tpe.memberInfo(sym),
129-
coord = tree.span).asTerm.entered
130-
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
131-
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
132-
133-
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
134-
val selector = tree.selector
135-
val selectorTpe = selector.tpe.widen
136-
val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe)
137-
val defaultCase =
138-
CaseDef(
139-
Bind(defaultSym, Underscore(selectorTpe)),
140-
EmptyTree,
141-
defaultValue)
142-
val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef)))
143-
cpy.Match(tree)(unchecked, cases :+ defaultCase)
144-
.subst(param.symbol :: Nil, pfParam :: Nil)
145-
// Needed because a partial function can be written as:
146-
// param => param match { case "foo" if foo(param) => param }
147-
// And we need to update all references to 'param'
148-
}
149-
150-
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
151-
val tru = Literal(Constant(true))
152-
def translateCase(cdef: CaseDef) =
153-
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
154-
val paramRef = paramRefss.head.head
155-
val defaultValue = Literal(Constant(false))
156-
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
157-
}
158-
159-
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
160-
val List(paramRef, defaultRef) = paramRefss(1)
161-
def translateCase(cdef: CaseDef) =
162-
cdef.changeOwner(anonSym, applyOrElseFn)
163-
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
164-
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
165-
}
166-
167-
val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered
168-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
169-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
170-
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
171-
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
172109

110+
// The right hand side from which to construct the partial function. This is always a Match.
111+
// If the original rhs is already a Match (possibly in braces), return that.
112+
// Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.
113+
def partialFunRHS(tree: Tree): Match = tree match
114+
case m: Match => m
115+
case Block(Nil, expr) => partialFunRHS(expr)
173116
case _ =>
174-
val found = tpe.baseType(defn.Function1)
175-
report.error(TypeMismatch(found, tpe), tree.srcPos)
176-
tree
117+
Match(ref(param.symbol),
118+
CaseDef(untpd.Ident(nme.WILDCARD).withType(param.symbol.info), EmptyTree, tree) :: Nil)
119+
120+
val pfRHS = partialFunRHS(anon.rhs)
121+
val anonSym = anon.symbol
122+
val anonTpe = anon.tpe.widen
123+
val parents = List(
124+
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
125+
defn.SerializableType)
126+
val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span)
127+
128+
def overrideSym(sym: Symbol) = sym.copy(
129+
owner = pfSym,
130+
flags = Synthetic | Method | Final | Override,
131+
info = tpe.memberInfo(sym),
132+
coord = tree.span).asTerm.entered
133+
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
134+
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
135+
136+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
137+
val selector = tree.selector
138+
val selectorTpe = selector.tpe.widen
139+
val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe)
140+
val defaultCase =
141+
CaseDef(
142+
Bind(defaultSym, Underscore(selectorTpe)),
143+
EmptyTree,
144+
defaultValue)
145+
val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef)))
146+
cpy.Match(tree)(unchecked, cases :+ defaultCase)
147+
.subst(param.symbol :: Nil, pfParam :: Nil)
148+
// Needed because a partial function can be written as:
149+
// param => param match { case "foo" if foo(param) => param }
150+
// And we need to update all references to 'param'
151+
}
152+
153+
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
154+
val tru = Literal(Constant(true))
155+
def translateCase(cdef: CaseDef) =
156+
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
157+
val paramRef = paramRefss.head.head
158+
val defaultValue = Literal(Constant(false))
159+
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
160+
}
161+
162+
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
163+
val List(paramRef, defaultRef) = paramRefss(1)
164+
def translateCase(cdef: CaseDef) =
165+
cdef.changeOwner(anonSym, applyOrElseFn)
166+
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
167+
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177168
}
169+
170+
val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered
171+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
172+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
173+
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
174+
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
178175
}
179176

180177
private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {

tests/neg/i4241.scala

-12
This file was deleted.

tests/run/i4241.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
object Test extends App {
2+
val a: PartialFunction[Int, Int] = { case x => x }
3+
val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case 2 => 2 }
4+
val c: PartialFunction[Int, Int] = x => { x match { case 1 => 1 } }
5+
val d: PartialFunction[Int, Int] = x => { { x match { case 1 => 1 } } }
6+
7+
val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case 1 => 1 } }
8+
val f: PartialFunction[Int, Int] = x => x
9+
val g: PartialFunction[Int, String] = { x => x.toString }
10+
val h: PartialFunction[Int, String] = _.toString
11+
assert(a.isDefinedAt(2))
12+
assert(b.isDefinedAt(2))
13+
assert(!b.isDefinedAt(3))
14+
assert(c.isDefinedAt(1))
15+
assert(!c.isDefinedAt(2))
16+
assert(d.isDefinedAt(1))
17+
assert(!d.isDefinedAt(2))
18+
assert(e.isDefinedAt(2))
19+
assert(f.isDefinedAt(2))
20+
assert(g.isDefinedAt(2))
21+
assert(h.isDefinedAt(2))
22+
}
23+
24+

0 commit comments

Comments
 (0)