Skip to content

Commit e9971c4

Browse files
smarterretronymlrytz
committed
Fix #4442: Add support for lambda serialization
In Scala, lambdas whose SAM extends Serializable, as well as lambdas whose SAM is simply scala.Function*, should be serializable, but this was not the case in Dotty so far. On the JVM, lambdas instantiated using invokedynamic calls require some special handling to be serializable: 1. We need to use the invokedynamic bootstrap method `LambdaMetafactory#altMetafactory` instead of `LambdaMetafactory#metafactory`; this allows us to pass the FLAG_SERIALIZABLE flag. This is implemented in the backend submodule commit included in this commit (see lampepfl/scala#39). 2. In the enclosing class where the lambda is defined, a $deserializeLambda$ method needs to be generated; this is implemented in this commit. Most of the logic for $deserializeLambda$ is implemented in the Scala 2.12 standard library classes scala.runtime.LambdaDeserialize and scala.runtime.LambdaDeserializer, which can be used here as-is. The only logic we actually need to implement here is: 1. In `collectSerializableLambdas`, we collect all serializable lambdas. Unlike scalac, our backend does not do any inlining currently, so our implementation is more straightforward than theirs. 2. In `addLambdaDeserialize`, we implement the actual $deserializeLambda$ method. The implementation here is directly copied from scalac; it is complex because it needs to work around a limitation of bootstrap methods (they cannot take more than 251 arguments). Since some of this code comes from scalac, this is: Co-Authored-By: Jason Zaugg <[email protected]> Co-Authored-By: Lukas Rytz <[email protected]>
1 parent ddde594 commit e9971c4

File tree

7 files changed

+447
-7
lines changed

7 files changed

+447
-7
lines changed

compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
706706
/** Whether `CollectEntryPoints.isJavaEntryPoint` holds for `sym`. */
def isJavaEntryPoint: Boolean = CollectEntryPoints.isJavaEntryPoint(sym)
707707

708708
/** Whether the denotation of `sym` reports itself as a class constructor. */
def isClassConstructor: Boolean = toDenot(sym).isClassConstructor
709+
/** Whether the denotation of `sym` reports itself as serializable. */
def isSerializable: Boolean = toDenot(sym).isSerializable
709710

710711
/**
711712
* True for module classes of modules that are top-level or owned only by objects. Module classes
@@ -855,6 +856,9 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
855856

856857
/** Returns the symbol of the first abstract term member of `sym`'s type, or, when the type
 *  has no abstract term members, the symbol of its member named `apply`.
 */
def samMethod(): Symbol =
857858
// `headOption`/`getOrElse` implements the fallback to `apply` described above.
toDenot(sym).info.abstractTermMembers.headOption.getOrElse(toDenot(sym).info.member(nme.apply)).symbol
859+
860+
/** Whether `defn.isFunctionClass` holds for `sym`. */
def isFunctionClass: Boolean =
861+
defn.isFunctionClass(sym)
858862
}
859863

860864

compiler/src/dotty/tools/backend/jvm/GenBCode.scala

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import dotty.tools.dotc.ast.tpd
66
import dotty.tools.dotc.core.Phases.Phase
77

88
import scala.collection.mutable
9+
import scala.collection.JavaConverters._
910
import scala.tools.asm.CustomAttr
1011
import scala.tools.nsc.backend.jvm._
1112
import dotty.tools.dotc.transform.SymUtils._
@@ -23,6 +24,7 @@ import java.io.DataOutputStream
2324

2425

2526
import scala.tools.asm
27+
import scala.tools.asm.Handle
2628
import scala.tools.asm.tree._
2729
import tpd._
2830
import StdNames._
@@ -304,10 +306,98 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
304306
class Worker2 {
305307
// lazy val localOpt = new LocalOpt(new Settings())
306308

307-
def localOptimizations(classNode: ClassNode): Unit = {
309+
/** Hook for per-class local optimizations.
 *
 *  Currently a no-op: the only statement in the body is commented out, so `classNode`
 *  is left untouched.
 */
private def localOptimizations(classNode: ClassNode): Unit = {
308310
// BackendStats.timed(BackendStats.methodOptTimer)(localOpt.methodOptimizations(classNode))
309311
}
310312

313+
314+
/** Gathers the implementation-method handles of all serializable lambdas in `classNode`.
 *
 *  Walks the instructions of every method in declaration order and, for each
 *  `invokedynamic` whose bootstrap method is the `altMetafactory` handle, takes the
 *  second bootstrap argument (the handle of the lambda body).
 */
private def collectSerializableLambdas(classNode: ClassNode): Array[Handle] =
  classNode.methods.asScala.iterator
    .flatMap(method => method.instructions.iterator.asScala)
    .collect {
      // No need to check the exact bsmArgs because we only generate
      // altMetafactory indy calls for serializable lambdas.
      case indy: InvokeDynamicInsnNode
        if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
        indy.bsmArgs(1).asInstanceOf[Handle]
    }
    .toArray
334+
335+
/*
336+
* Add:
337+
*
338+
* private static Object $deserializeLambda$(SerializedLambda l) {
339+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$0](l)
340+
* catch {
341+
* case i: IllegalArgumentException =>
342+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$1](l)
343+
* catch {
344+
* case i: IllegalArgumentException =>
345+
* ...
346+
* return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup${NUM_GROUPS-1}](l)
347+
* }
348+
*
349+
* We use invokedynamic here to enable caching within the deserializer without needing to
350+
* host a static field in the enclosing class. This allows us to add this method to interfaces
351+
* that define lambdas in default methods.
352+
*
353+
* SI-10232 we can't pass arbitrary number of method handles to the final varargs parameter of the bootstrap
354+
* method due to a limitation in the JVM. Instead, we emit a separate invokedynamic bytecode for each group of target
355+
* methods.
356+
*/
357+
// Emits the synthetic `$deserializeLambda$` method into `classNode`, using one invokedynamic
// call per group of handles from `implMethodsArray`.
// Precondition: `implMethodsArray` must be non-empty; with zero groups the final
// `groups(numGroups - 1)` lookup below would index out of bounds.
private def addLambdaDeserialize(classNode: ClassNode, implMethodsArray: Array[Handle]): Unit = {
358+
import asm.Opcodes._
359+
import BCodeBodyBuilder._
360+
import bTypes._
361+
import coreBTypes._
362+
363+
val cw = classNode
364+
365+
// Make sure that the ClassBTypes of all types used in the code generated
366+
// here (e.g. java/util/Map) are initialized. Initializing a ClassBType adds it to
367+
// `classBTypeFromInternalNameMap`. When writing the classfile, the asm ClassWriter computes
368+
// stack map frames and invokes the `getCommonSuperClass` method. This method expects all
369+
// ClassBTypes mentioned in the source code to exist in the map.
370+
371+
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectReference).descriptor
372+
373+
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
374+
// Loads the method's single argument and passes it to an invokedynamic call whose
// bootstrap arguments are the given target handles.
def emitLambdaDeserializeIndy(targetMethods: Seq[Handle]): Unit = {
375+
mv.visitVarInsn(ALOAD, 0)
376+
mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, targetMethods: _*)
377+
}
378+
379+
// Evaluates to 251: the maximum number of target handles passed to one bootstrap call.
val targetMethodGroupLimit = 255 - 1 - 3 // JVM limit. See MAX_MH_ARITY in CallSite.java
380+
val groups: Array[Array[Handle]] = implMethodsArray.grouped(targetMethodGroupLimit).toArray
381+
val numGroups = groups.length
382+
383+
import scala.tools.asm.Label
384+
// One label per non-final group; the final group starts at `terminalLabel`.
val initialLabels = Array.fill(numGroups - 1)(new Label())
385+
val terminalLabel = new Label
386+
// Label at which the group following group `i` starts.
def nextLabel(i: Int) = if (i == numGroups - 2) terminalLabel else initialLabels(i + 1)
387+
388+
// For each non-final group, register a try/catch block that covers the group's code and
// whose handler for IllegalArgumentException is the start of the next group.
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
389+
mv.visitTryCatchBlock(label, nextLabel(i), nextLabel(i), jlIllegalArgExceptionRef.internalName)
390+
}
391+
// Emit the code of each non-final group: try its invokedynamic and return the result.
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
392+
mv.visitLabel(label)
393+
emitLambdaDeserializeIndy(groups(i))
394+
mv.visitInsn(ARETURN)
395+
}
396+
// The last group is not covered by any try/catch block registered above.
mv.visitLabel(terminalLabel)
397+
emitLambdaDeserializeIndy(groups(numGroups - 1))
398+
mv.visitInsn(ARETURN)
399+
}
400+
311401
def run(): Unit = {
312402
while (true) {
313403
val item = q2.poll
@@ -317,7 +407,11 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
317407
}
318408
else {
319409
try {
320-
localOptimizations(item.plain.classNode)
410+
val plainNode = item.plain.classNode
411+
localOptimizations(plainNode)
412+
val serializableLambdas = collectSerializableLambdas(plainNode)
413+
if (serializableLambdas.nonEmpty)
414+
addLambdaDeserialize(plainNode, serializableLambdas)
321415
addToQ3(item)
322416
} catch {
323417
case ex: Throwable =>

tests/run/lambda-serialization.scala

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, PrintWriter, StringWriter}
2+
import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
3+
4+
/** Serializable fixture class holding a long list of distinct lambda literals.
 *
 *  NOTE(review): presumably the large number of lambdas is meant to stress the generated
 *  `$deserializeLambda$` method for a class with many lambda bodies — confirm.
 */
class C extends java.io.Serializable {
5+
// Every `() => ()` below is written out as its own lambda literal.
val fs = List(
6+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
7+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
8+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
9+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
10+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
11+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
12+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
13+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
14+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
15+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
16+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
17+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
18+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
19+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
20+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => ()
21+
)
22+
// Not called from within this class. `Test.fakeLambdaFailsToDeserialize` names "foo" as the
// implementation method of a hand-built SerializedLambda, so it fails loudly if ever invoked.
private def foo(): Unit = {
23+
assert(false, "should not be called!!!")
24+
}
25+
}
26+
27+
/** Trait with a single abstract `apply`; `Test.fakeLambdaFailsToDeserialize` passes its name
 *  as the functional interface of a hand-built SerializedLambda.
 */
trait FakeSam { def apply(): Unit }
28+
29+
/** Run test for lambda serialization: real lambdas must survive a Java serialization
 *  round trip, and a hand-built `SerializedLambda` that targets a non-lambda method must be
 *  rejected on deserialization.
 */
object Test {
  def main(args: Array[String]): Unit = {
    allRealLambdasRoundTrip()
    fakeLambdaFailsToDeserialize()
  }

  /** Serializes and deserializes every lambda in `C.fs`, then invokes the result. */
  def allRealLambdasRoundTrip(): Unit =
    // `foreach` rather than `map`: only the side effect of invoking each lambda matters.
    new C().fs.foreach(f => serializeDeserialize(f).apply())

  /** Builds a `SerializedLambda` that claims `C.foo` as its implementation method and checks
   *  that deserializing it fails with an "Illegal lambda deserialization" error.
   */
  def fakeLambdaFailsToDeserialize(): Unit = {
    val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V",
      MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C))
    try {
      serializeDeserialize(fake).asInstanceOf[FakeSam].apply()
      // An AssertionError is not an Exception, so it is not swallowed by the handler below.
      assert(false, "deserializing a forged SerializedLambda should have failed")
    } catch {
      case ex: Exception =>
        val stackTrace = stackTraceString(ex)
        assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace)
    }
  }

  /** Writes `obj` with Java serialization into an in-memory buffer and reads it back.
   *
   *  @param obj the object to round-trip
   *  @return the deserialized object, cast (unchecked) to `T`
   */
  def serializeDeserialize[T <: AnyRef](obj: T): T = {
    val buffer = new ByteArrayOutputStream
    val out = new ObjectOutputStream(buffer)
    // Close (and thereby flush) the stream before the buffer's bytes are read.
    try out.writeObject(obj) finally out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    try in.readObject.asInstanceOf[T] finally in.close()
  }

  /** Renders the full stack trace of `ex`, as printed by `printStackTrace`, to a string. */
  def stackTraceString(ex: Throwable): String = {
    val writer = new StringWriter
    val printer = new PrintWriter(writer)
    ex.printStackTrace(printer)
    printer.flush()
    writer.toString
  }
}
66+

0 commit comments

Comments
 (0)