Skip to content

Commit baa05cd

Browse files
smarter, retronym, and lrytz
committed
Fix #4442: Add support for lambda serialization
In Scala, lambdas whose SAM extends Serializable, as well as lambdas whose SAM is simply scala.Function*, should be serializable, but this was not the case in Dotty so far. On the JVM, lambdas instantiated using invokedynamic calls require some special handling to be serializable: 1. We need to use the invokedynamic bootstrap method `LambdaMetafactory#altMetafactory` instead of `LambdaMetafactory#metafactory`; this allows us to pass the FLAG_SERIALIZABLE flag. This is implemented in the backend submodule commit included in this commit (see lampepfl/scala#39). 2. In the enclosing class where the lambda is defined, a $deserializeLambda$ method needs to be generated; this is implemented in this commit. Most of the logic for $deserializeLambda$ is implemented in the Scala 2.12 standard library classes scala.runtime.LambdaDeserialize and scala.runtime.LambdaDeserializer, which can be used here as-is. The only logic we actually need to implement here is: 1. In `collectSerializableLambdas`, we collect all serializable lambdas. Unlike scalac, our backend does not currently do any inlining, so our implementation is more straightforward than theirs. 2. In `addLambdaDeserialize`, we implement the actual $deserializeLambda$ method. The implementation here is directly copied from scalac; it's complex because it needs to work around a limitation of bootstrap methods (they cannot take more than 251 arguments). Since some of this code comes from scalac, this is: Co-Authored-By: Jason Zaugg <[email protected]> Co-Authored-By: Lukas Rytz <[email protected]>
1 parent bf24325 commit baa05cd

13 files changed

+1421
-7
lines changed

compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
706706
def isJavaEntryPoint: Boolean = CollectEntryPoints.isJavaEntryPoint(sym)
707707

708708
def isClassConstructor: Boolean = toDenot(sym).isClassConstructor
709+
def isSerializable: Boolean = toDenot(sym).isSerializable
709710

710711
/**
711712
* True for module classes of modules that are top-level or owned only by objects. Module classes
@@ -855,6 +856,9 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
855856

856857
def samMethod(): Symbol =
857858
toDenot(sym).info.abstractTermMembers.headOption.getOrElse(toDenot(sym).info.member(nme.apply)).symbol
859+
860+
def isFunctionClass: Boolean =
861+
defn.isFunctionClass(sym)
858862
}
859863

860864

compiler/src/dotty/tools/backend/jvm/GenBCode.scala

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import dotty.tools.dotc.ast.tpd
66
import dotty.tools.dotc.core.Phases.Phase
77

88
import scala.collection.mutable
9+
import scala.collection.JavaConverters._
910
import scala.tools.asm.CustomAttr
1011
import scala.tools.nsc.backend.jvm._
1112
import dotty.tools.dotc.transform.SymUtils._
@@ -23,6 +24,7 @@ import java.io.DataOutputStream
2324

2425

2526
import scala.tools.asm
27+
import scala.tools.asm.Handle
2628
import scala.tools.asm.tree._
2729
import tpd._
2830
import StdNames._
@@ -304,10 +306,98 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
304306
class Worker2 {
305307
// lazy val localOpt = new LocalOpt(new Settings())
306308

307-
def localOptimizations(classNode: ClassNode): Unit = {
309+
private def localOptimizations(classNode: ClassNode): Unit = {
308310
// BackendStats.timed(BackendStats.methodOptTimer)(localOpt.methodOptimizations(classNode))
309311
}
310312

313+
314+
/* Return an array of all serializable lambdas in this class */
315+
private def collectSerializableLambdas(classNode: ClassNode): Array[Handle] = {
316+
val indyLambdaBodyMethods = new mutable.ArrayBuffer[Handle]
317+
for (m <- classNode.methods.asScala) {
318+
val iter = m.instructions.iterator
319+
while (iter.hasNext) {
320+
val insn = iter.next()
321+
insn match {
322+
case indy: InvokeDynamicInsnNode
323+
// No need to check the exact bsmArgs because we only generate
324+
// altMetafactory indy calls for serializable lambdas.
325+
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
326+
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
327+
indyLambdaBodyMethods += implMethod
328+
case _ =>
329+
}
330+
}
331+
}
332+
indyLambdaBodyMethods.toArray
333+
}
334+
335+
/*
336+
* Add:
337+
*
338+
* private static Object $deserializeLambda$(SerializedLambda l) {
339+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$0](l)
340+
* catch {
341+
* case i: IllegalArgumentException =>
342+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$1](l)
343+
* catch {
344+
* case i: IllegalArgumentException =>
345+
* ...
346+
* return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup${NUM_GROUPS-1}](l)
347+
* }
348+
*
349+
* We use invokedynamic here to enable caching within the deserializer without needing to
350+
* host a static field in the enclosing class. This allows us to add this method to interfaces
351+
* that define lambdas in default methods.
352+
*
353+
* SI-10232: we can't pass an arbitrary number of method handles to the final varargs parameter of the bootstrap
354+
* method due to a limitation in the JVM. Instead, we emit a separate invokedynamic bytecode for each group of target
355+
* methods.
356+
*/
357+
private def addLambdaDeserialize(classNode: ClassNode, implMethodsArray: Array[Handle]): Unit = {
358+
import asm.Opcodes._
359+
import BCodeBodyBuilder._
360+
import bTypes._
361+
import coreBTypes._
362+
363+
val cw = classNode
364+
365+
// Make sure that the ClassBTypes of all types that are used in the code generated
366+
// here (e.g. java/util/Map) are initialized. Initializing a ClassBType adds it to
367+
// `classBTypeFromInternalNameMap`. When writing the classfile, the asm ClassWriter computes
368+
// stack map frames and invokes the `getCommonSuperClass` method. This method expects all
369+
// ClassBTypes mentioned in the source code to exist in the map.
370+
371+
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectReference).descriptor
372+
373+
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
374+
def emitLambdaDeserializeIndy(targetMethods: Seq[Handle]): Unit = {
375+
mv.visitVarInsn(ALOAD, 0)
376+
mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, targetMethods: _*)
377+
}
378+
379+
val targetMethodGroupLimit = 255 - 1 - 3 // JVM limit. See MAX_MH_ARITY in CallSite.java
380+
val groups: Array[Array[Handle]] = implMethodsArray.grouped(targetMethodGroupLimit).toArray
381+
val numGroups = groups.length
382+
383+
import scala.tools.asm.Label
384+
val initialLabels = Array.fill(numGroups - 1)(new Label())
385+
val terminalLabel = new Label
386+
def nextLabel(i: Int) = if (i == numGroups - 2) terminalLabel else initialLabels(i + 1)
387+
388+
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
389+
mv.visitTryCatchBlock(label, nextLabel(i), nextLabel(i), jlIllegalArgExceptionRef.internalName)
390+
}
391+
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
392+
mv.visitLabel(label)
393+
emitLambdaDeserializeIndy(groups(i))
394+
mv.visitInsn(ARETURN)
395+
}
396+
mv.visitLabel(terminalLabel)
397+
emitLambdaDeserializeIndy(groups(numGroups - 1))
398+
mv.visitInsn(ARETURN)
399+
}
400+
311401
def run(): Unit = {
312402
while (true) {
313403
val item = q2.poll
@@ -317,7 +407,11 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
317407
}
318408
else {
319409
try {
320-
localOptimizations(item.plain.classNode)
410+
val plainNode = item.plain.classNode
411+
localOptimizations(plainNode)
412+
val serializableLambdas = collectSerializableLambdas(plainNode)
413+
if (serializableLambdas.nonEmpty)
414+
addLambdaDeserialize(plainNode, serializableLambdas)
321415
addToQ3(item)
322416
} catch {
323417
case ex: Throwable =>
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
class C { inline final def f: Int => Int = (x: Int) => x + 1 }
2+
3+
object Test extends App {
4+
import java.io._
5+
6+
def serialize(obj: AnyRef): Array[Byte] = {
7+
val buffer = new ByteArrayOutputStream
8+
val out = new ObjectOutputStream(buffer)
9+
out.writeObject(obj)
10+
buffer.toByteArray
11+
}
12+
def deserialize(a: Array[Byte]): AnyRef = {
13+
val in = new ObjectInputStream(new ByteArrayInputStream(a))
14+
in.readObject
15+
}
16+
17+
def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]
18+
19+
assert(serializeDeserialize((new C).f).isInstanceOf[Function1[_, _]])
20+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import java.io._
2+
3+
import java.net.URLClassLoader
4+
5+
class C {
6+
def serializeDeserialize[T <: AnyRef](obj: T) = {
7+
val buffer = new ByteArrayOutputStream
8+
val out = new ObjectOutputStream(buffer)
9+
out.writeObject(obj)
10+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
11+
in.readObject.asInstanceOf[T]
12+
}
13+
14+
serializeDeserialize((c: String) => c.length)
15+
}
16+
17+
object Test {
18+
def main(args: Array[String]): Unit = {
19+
test()
20+
}
21+
22+
def test(): Unit = {
23+
val loader = getClass.getClassLoader.asInstanceOf[URLClassLoader]
24+
val loaderCClass = classOf[C]
25+
def deserializedInThrowawayClassloader = {
26+
val throwawayLoader: java.net.URLClassLoader = new java.net.URLClassLoader(loader.getURLs, ClassLoader.getSystemClassLoader) {
27+
val maxMemory = Runtime.getRuntime.maxMemory()
28+
val junk = new Array[Long]((maxMemory / 8 / 4).toInt)
29+
}
30+
val clazz = throwawayLoader.loadClass("C")
31+
assert(clazz != loaderCClass)
32+
clazz.newInstance()
33+
}
34+
(1 to 5) foreach { i =>
35+
// This would OOM by the fifth iteration if we leaked `throwawayLoader` during
36+
// deserialization.
37+
deserializedInThrowawayClassloader
38+
}
39+
}
40+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
2+
3+
trait IntToString extends java.io.Serializable { def apply(i: Int): String }
4+
5+
object Test {
6+
def main(args: Array[String]): Unit = {
7+
roundTrip()
8+
roundTripIndySam()
9+
}
10+
11+
def roundTrip(): Unit = {
12+
val c = new Capture("Capture")
13+
val lambda = (p: Param) => ("a", p, c)
14+
val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
15+
val p = new Param
16+
assert(reconstituted1.apply(p) == ("a", p, c))
17+
val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
18+
assert(reconstituted1.getClass == reconstituted2.getClass)
19+
20+
val reconstituted3 = serializeDeserialize(reconstituted1)
21+
assert(reconstituted3.apply(p) == ("a", p, c))
22+
23+
val specializedLambda = (p: Int) => List(p, c).length
24+
assert(serializeDeserialize(specializedLambda).apply(42) == 2)
25+
assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
26+
}
27+
28+
// lambda targeting a SAM, not a FunctionN (should behave the same way)
29+
def roundTripIndySam(): Unit = {
30+
val lambda: IntToString = (x: Int) => "yo!" * x
31+
val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
32+
val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
33+
assert(reconstituted1.apply(2) == "yo!yo!")
34+
assert(reconstituted1.getClass == reconstituted2.getClass)
35+
}
36+
37+
def serializeDeserialize[T <: AnyRef](obj: T) = {
38+
val buffer = new ByteArrayOutputStream
39+
val out = new ObjectOutputStream(buffer)
40+
out.writeObject(obj)
41+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
42+
in.readObject.asInstanceOf[T]
43+
}
44+
}
45+
46+
case class Capture(s: String) extends Serializable
47+
class Param
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, PrintWriter, StringWriter}
2+
import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
3+
4+
class C {
5+
val f1 = ((x1: Int,
6+
x2: String,
7+
x3: Int,
8+
x4: Int,
9+
x5: Int,
10+
x6: Int,
11+
x7: Int,
12+
x8: Int,
13+
x9: Int,
14+
x10: Int,
15+
x11: Int,
16+
x12: Int,
17+
x13: Int,
18+
x14: Int,
19+
x15: Int,
20+
x16: Int,
21+
x17: Int,
22+
x18: Int,
23+
x19: Int,
24+
x20: Int,
25+
x21: Int,
26+
x22: Int,
27+
x23: Int,
28+
x24: Int,
29+
x25: Int,
30+
x26: Int) => x2 + x1)
31+
}
32+
33+
object Test {
34+
def main(args: Array[String]): Unit = {
35+
val c = new C
36+
serializeDeserialize(c.f1)
37+
}
38+
39+
def serializeDeserialize[T <: AnyRef](obj: T) = {
40+
val buffer = new ByteArrayOutputStream
41+
val out = new ObjectOutputStream(buffer)
42+
out.writeObject(obj)
43+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
44+
in.readObject.asInstanceOf[T]
45+
}
46+
}
47+

tests/run/lambda-serialization.scala

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, PrintWriter, StringWriter}
2+
import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
3+
4+
class C extends java.io.Serializable {
5+
val fs = List(
6+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
7+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
8+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
9+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
10+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
11+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
12+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
13+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
14+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
15+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
16+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
17+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
18+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
19+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
20+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => ()
21+
)
22+
private def foo(): Unit = {
23+
assert(false, "should not be called!!!")
24+
}
25+
}
26+
27+
trait FakeSam { def apply(): Unit }
28+
29+
object Test {
30+
def main(args: Array[String]): Unit = {
31+
allRealLambdasRoundTrip()
32+
fakeLambdaFailsToDeserialize()
33+
}
34+
35+
def allRealLambdasRoundTrip(): Unit = {
36+
new C().fs.map(x => serializeDeserialize(x).apply())
37+
}
38+
39+
def fakeLambdaFailsToDeserialize(): Unit = {
40+
val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V",
41+
MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C))
42+
try {
43+
serializeDeserialize(fake).asInstanceOf[FakeSam].apply()
44+
assert(false)
45+
} catch {
46+
case ex: Exception =>
47+
val stackTrace = stackTraceString(ex)
48+
assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace)
49+
}
50+
}
51+
52+
def serializeDeserialize[T <: AnyRef](obj: T) = {
53+
val buffer = new ByteArrayOutputStream
54+
val out = new ObjectOutputStream(buffer)
55+
out.writeObject(obj)
56+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
57+
in.readObject.asInstanceOf[T]
58+
}
59+
60+
def stackTraceString(ex: Throwable): String = {
61+
val writer = new StringWriter
62+
ex.printStackTrace(new PrintWriter(writer))
63+
writer.toString
64+
}
65+
}
66+

0 commit comments

Comments
 (0)