Skip to content

Commit 56ca9ab

Browse files
authored
Merge pull request #5837 from dotty-staging/serialize-lambdas
Add support for lambda serialization
2 parents b44adf7 + 5fc2119 commit 56ca9ab

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1707
-138
lines changed

compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
706706
def isJavaEntryPoint: Boolean = CollectEntryPoints.isJavaEntryPoint(sym)
707707

708708
def isClassConstructor: Boolean = toDenot(sym).isClassConstructor
709+
def isSerializable: Boolean = toDenot(sym).isSerializable
709710

710711
/**
711712
* True for module classes of modules that are top-level or owned only by objects. Module classes
@@ -855,6 +856,9 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
855856

856857
def samMethod(): Symbol =
857858
toDenot(sym).info.abstractTermMembers.headOption.getOrElse(toDenot(sym).info.member(nme.apply)).symbol
859+
860+
def isFunctionClass: Boolean =
861+
defn.isFunctionClass(sym)
858862
}
859863

860864

compiler/src/dotty/tools/backend/jvm/GenBCode.scala

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import dotty.tools.dotc.ast.tpd
66
import dotty.tools.dotc.core.Phases.Phase
77

88
import scala.collection.mutable
9+
import scala.collection.JavaConverters._
910
import scala.tools.asm.CustomAttr
1011
import scala.tools.nsc.backend.jvm._
1112
import dotty.tools.dotc.transform.SymUtils._
@@ -23,6 +24,7 @@ import java.io.DataOutputStream
2324

2425

2526
import scala.tools.asm
27+
import scala.tools.asm.Handle
2628
import scala.tools.asm.tree._
2729
import tpd._
2830
import StdNames._
@@ -304,10 +306,98 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
304306
class Worker2 {
305307
// lazy val localOpt = new LocalOpt(new Settings())
306308

307-
def localOptimizations(classNode: ClassNode): Unit = {
309+
private def localOptimizations(classNode: ClassNode): Unit = {
308310
// BackendStats.timed(BackendStats.methodOptTimer)(localOpt.methodOptimizations(classNode))
309311
}
310312

313+
314+
/* Return an array of all serializable lambdas in this class */
315+
private def collectSerializableLambdas(classNode: ClassNode): Array[Handle] = {
316+
val indyLambdaBodyMethods = new mutable.ArrayBuffer[Handle]
317+
for (m <- classNode.methods.asScala) {
318+
val iter = m.instructions.iterator
319+
while (iter.hasNext) {
320+
val insn = iter.next()
321+
insn match {
322+
case indy: InvokeDynamicInsnNode
323+
// No need to check the exact bsmArgs because we only generate
324+
// altMetafactory indy calls for serializable lambdas.
325+
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
326+
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
327+
indyLambdaBodyMethods += implMethod
328+
case _ =>
329+
}
330+
}
331+
}
332+
indyLambdaBodyMethods.toArray
333+
}
334+
335+
/*
336+
* Add:
337+
*
338+
* private static Object $deserializeLambda$(SerializedLambda l) {
339+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$0](l)
340+
* catch {
341+
* case i: IllegalArgumentException =>
342+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$1](l)
343+
* catch {
344+
* case i: IllegalArgumentException =>
345+
* ...
346+
* return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup${NUM_GROUPS-1}](l)
347+
* }
348+
*
349+
* We use invokedynamic here to enable caching within the deserializer without needing to
350+
* host a static field in the enclosing class. This allows us to add this method to interfaces
351+
* that define lambdas in default methods.
352+
*
353+
* SI-10232 we can't pass arbitrary number of method handles to the final varargs parameter of the bootstrap
354+
* method due to a limitation in the JVM. Instead, we emit a separate invokedynamic bytecode for each group of target
355+
* methods.
356+
*/
357+
private def addLambdaDeserialize(classNode: ClassNode, implMethodsArray: Array[Handle]): Unit = {
358+
import asm.Opcodes._
359+
import BCodeBodyBuilder._
360+
import bTypes._
361+
import coreBTypes._
362+
363+
val cw = classNode
364+
365+
// Make sure to reference the ClassBTypes of all types that are used in the code generated
366+
// here (e.g. java/util/Map) are initialized. Initializing a ClassBType adds it to
367+
// `classBTypeFromInternalNameMap`. When writing the classfile, the asm ClassWriter computes
368+
// stack map frames and invokes the `getCommonSuperClass` method. This method expects all
369+
// ClassBTypes mentioned in the source code to exist in the map.
370+
371+
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectReference).descriptor
372+
373+
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
374+
def emitLambdaDeserializeIndy(targetMethods: Seq[Handle]): Unit = {
375+
mv.visitVarInsn(ALOAD, 0)
376+
mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, targetMethods: _*)
377+
}
378+
379+
val targetMethodGroupLimit = 255 - 1 - 3 // JVM limit. See See MAX_MH_ARITY in CallSite.java
380+
val groups: Array[Array[Handle]] = implMethodsArray.grouped(targetMethodGroupLimit).toArray
381+
val numGroups = groups.length
382+
383+
import scala.tools.asm.Label
384+
val initialLabels = Array.fill(numGroups - 1)(new Label())
385+
val terminalLabel = new Label
386+
def nextLabel(i: Int) = if (i == numGroups - 2) terminalLabel else initialLabels(i + 1)
387+
388+
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
389+
mv.visitTryCatchBlock(label, nextLabel(i), nextLabel(i), jlIllegalArgExceptionRef.internalName)
390+
}
391+
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
392+
mv.visitLabel(label)
393+
emitLambdaDeserializeIndy(groups(i))
394+
mv.visitInsn(ARETURN)
395+
}
396+
mv.visitLabel(terminalLabel)
397+
emitLambdaDeserializeIndy(groups(numGroups - 1))
398+
mv.visitInsn(ARETURN)
399+
}
400+
311401
def run(): Unit = {
312402
while (true) {
313403
val item = q2.poll
@@ -317,7 +407,11 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
317407
}
318408
else {
319409
try {
320-
localOptimizations(item.plain.classNode)
410+
val plainNode = item.plain.classNode
411+
localOptimizations(plainNode)
412+
val serializableLambdas = collectSerializableLambdas(plainNode)
413+
if (serializableLambdas.nonEmpty)
414+
addLambdaDeserialize(plainNode, serializableLambdas)
321415
addToQ3(item)
322416
} catch {
323417
case ex: Throwable =>

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ object desugar {
341341
case _ => false
342342
}
343343

344-
val isCaseClass = mods.is(Case) && !mods.is(Module)
345-
val isCaseObject = mods.is(Case) && mods.is(Module)
344+
val isObject = mods.is(Module)
345+
val isCaseClass = mods.is(Case) && !isObject
346+
val isCaseObject = mods.is(Case) && isObject
346347
val isImplicit = mods.is(Implicit)
347348
val isInstance = isImplicit && mods.mods.exists(_.isInstanceOf[Mod.Instance])
348349
val isEnum = mods.isEnumClass && !mods.is(Module)
@@ -527,13 +528,13 @@ object desugar {
527528
else Nil
528529
}
529530

530-
// Case classes and case objects get Product parents
531-
// Enum cases get an inferred parent if no parents are given
532531
var parents1 = parents
533532
if (isEnumCase && parents.isEmpty)
534533
parents1 = enumClassTypeRef :: Nil
535534
if (isCaseClass | isCaseObject)
536535
parents1 = parents1 :+ scalaDot(str.Product.toTypeName) :+ scalaDot(nme.Serializable.toTypeName)
536+
else if (isObject)
537+
parents1 = parents1 :+ scalaDot(nme.Serializable.toTypeName)
537538
if (isEnum)
538539
parents1 = parents1 :+ ref(defn.EnumType)
539540

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@ class Definitions {
10011001
tp.derivesFrom(NothingClass) || tp.derivesFrom(NullClass)
10021002

10031003
/** Is a function class.
1004+
* - FunctionXXL
10041005
* - FunctionN for N >= 0
10051006
* - ImplicitFunctionN for N >= 0
10061007
* - ErasedFunctionN for N > 0
@@ -1020,15 +1021,21 @@ class Definitions {
10201021
*/
10211022
def isErasedFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isErasedFunction
10221023

1023-
/** Is a class that will be erased to FunctionXXL
1024+
/** Is either FunctionXXL or a class that will be erased to FunctionXXL
1025+
* - FunctionXXL
10241026
* - FunctionN for N >= 22
10251027
* - ImplicitFunctionN for N >= 22
10261028
*/
1027-
def isXXLFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).functionArity > MaxImplementedFunctionArity
1029+
def isXXLFunctionClass(cls: Symbol): Boolean = {
1030+
val name = scalaClassName(cls)
1031+
(name eq tpnme.FunctionXXL) || name.functionArity > MaxImplementedFunctionArity
1032+
}
10281033

10291034
/** Is a synthetic function class
10301035
* - FunctionN for N > 22
1031-
* - ImplicitFunctionN for N > 0
1036+
* - ImplicitFunctionN for N >= 0
1037+
* - ErasedFunctionN for N > 0
1038+
* - ErasedImplicitFunctionN for N > 0
10321039
*/
10331040
def isSyntheticFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isSyntheticFunction
10341041

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ object NameOps {
183183
if (n == 0) -1 else n
184184
}
185185

186-
/** Is a function name, i.e one of FunctionN, ImplicitFunctionN for N >= 0 or ErasedFunctionN, ErasedImplicitFunctionN for N > 0
186+
/** Is a function name, i.e one of FunctionXXL, FunctionN, ImplicitFunctionN for N >= 0 or ErasedFunctionN, ErasedImplicitFunctionN for N > 0
187187
*/
188-
def isFunction: Boolean = functionArity >= 0
188+
def isFunction: Boolean = (name eq tpnme.FunctionXXL) || functionArity >= 0
189189

190190
/** Is an implicit function name, i.e one of ImplicitFunctionN for N >= 0 or ErasedImplicitFunctionN for N > 0
191191
*/

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ object StdNames {
205205
final val Singleton: N = "Singleton"
206206
final val Throwable: N = "Throwable"
207207
final val IOOBException: N = "IndexOutOfBoundsException"
208+
final val FunctionXXL: N = "FunctionXXL"
208209

209210
final val ClassfileAnnotation: N = "ClassfileAnnotation"
210211
final val ClassManifest: N = "ClassManifest"

compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,60 @@ object GenericSignatures {
123123
builder.append('L').append(name)
124124
}
125125

126+
def classSig(sym: Symbol, pre: Type = NoType, args: List[Type] = Nil): Unit = {
127+
def argSig(tp: Type): Unit =
128+
tp match {
129+
case bounds: TypeBounds =>
130+
if (!(defn.AnyType <:< bounds.hi)) {
131+
builder.append('+')
132+
boxedSig(bounds.hi)
133+
}
134+
else if (!(bounds.lo <:< defn.NothingType)) {
135+
builder.append('-')
136+
boxedSig(bounds.lo)
137+
}
138+
else builder.append('*')
139+
case PolyType(_, res) =>
140+
builder.append('*') // scala/bug#7932
141+
case _: HKTypeLambda =>
142+
fullNameInSig(tp.typeSymbol)
143+
builder.append(';')
144+
case _ =>
145+
boxedSig(tp)
146+
}
147+
148+
if (pre.exists) {
149+
val preRebound = pre.baseType(sym.owner) // #2585
150+
if (needsJavaSig(preRebound, Nil)) {
151+
val i = builder.length()
152+
jsig(preRebound)
153+
if (builder.charAt(i) == 'L') {
154+
builder.delete(builder.length() - 1, builder.length())// delete ';'
155+
// If the prefix is a module, drop the '$'. Classes (or modules) nested in modules
156+
// are separated by a single '$' in the filename: `object o { object i }` is o$i$.
157+
if (preRebound.typeSymbol.is(ModuleClass))
158+
builder.delete(builder.length() - 1, builder.length())
159+
160+
// Ensure every '.' in the generated signature immediately follows
161+
// a close angle bracket '>'. Any which do not are replaced with '$'.
162+
// This arises due to multiply nested classes in the face of the
163+
// rewriting explained at rebindInnerClass.
164+
165+
// TODO revisit this. Does it align with javac for code that can be expressed in both languages?
166+
val delimiter = if (builder.charAt(builder.length() - 1) == '>') '.' else '$'
167+
builder.append(delimiter).append(sanitizeName(sym.name.asSimpleName))
168+
} else fullNameInSig(sym)
169+
} else fullNameInSig(sym)
170+
} else fullNameInSig(sym)
171+
172+
if (args.nonEmpty) {
173+
builder.append('<')
174+
args foreach argSig
175+
builder.append('>')
176+
}
177+
builder.append(';')
178+
}
179+
126180
@noinline
127181
def jsig(tp0: Type, toplevel: Boolean = false, primitiveOK: Boolean = true): Unit = {
128182

@@ -133,57 +187,6 @@ object GenericSignatures {
133187
typeParamSig(ref.paramName.lastPart)
134188

135189
case RefOrAppliedType(sym, pre, args) =>
136-
def argSig(tp: Type): Unit =
137-
tp match {
138-
case bounds: TypeBounds =>
139-
if (!(defn.AnyType <:< bounds.hi)) {
140-
builder.append('+')
141-
boxedSig(bounds.hi)
142-
}
143-
else if (!(bounds.lo <:< defn.NothingType)) {
144-
builder.append('-')
145-
boxedSig(bounds.lo)
146-
}
147-
else builder.append('*')
148-
case PolyType(_, res) =>
149-
builder.append('*') // scala/bug#7932
150-
case _: HKTypeLambda =>
151-
fullNameInSig(tp.typeSymbol)
152-
builder.append(';')
153-
case _ =>
154-
boxedSig(tp)
155-
}
156-
def classSig: Unit = {
157-
val preRebound = pre.baseType(sym.owner) // #2585
158-
if (needsJavaSig(preRebound, Nil)) {
159-
val i = builder.length()
160-
jsig(preRebound)
161-
if (builder.charAt(i) == 'L') {
162-
builder.delete(builder.length() - 1, builder.length())// delete ';'
163-
// If the prefix is a module, drop the '$'. Classes (or modules) nested in modules
164-
// are separated by a single '$' in the filename: `object o { object i }` is o$i$.
165-
if (preRebound.typeSymbol.is(ModuleClass))
166-
builder.delete(builder.length() - 1, builder.length())
167-
168-
// Ensure every '.' in the generated signature immediately follows
169-
// a close angle bracket '>'. Any which do not are replaced with '$'.
170-
// This arises due to multiply nested classes in the face of the
171-
// rewriting explained at rebindInnerClass.
172-
173-
// TODO revisit this. Does it align with javac for code that can be expressed in both languages?
174-
val delimiter = if (builder.charAt(builder.length() - 1) == '>') '.' else '$'
175-
builder.append(delimiter).append(sanitizeName(sym.name.asSimpleName))
176-
} else fullNameInSig(sym)
177-
} else fullNameInSig(sym)
178-
179-
if (args.nonEmpty) {
180-
builder.append('<')
181-
args foreach argSig
182-
builder.append('>')
183-
}
184-
builder.append(';')
185-
}
186-
187190
// If args isEmpty, Array is being used as a type constructor
188191
if (sym == defn.ArrayClass && args.nonEmpty) {
189192
if (unboundedGenericArrayLevel(tp) == 1) jsig(defn.ObjectType)
@@ -215,14 +218,14 @@ object GenericSignatures {
215218
val unboxed = ValueClasses.valueClassUnbox(sym.asClass).info.finalResultType
216219
val unboxedSeen = tp.memberInfo(ValueClasses.valueClassUnbox(sym.asClass)).finalResultType
217220
if (unboxedSeen.isPrimitiveValueType && !primitiveOK)
218-
classSig
221+
classSig(sym, pre, args)
219222
else
220223
jsig(unboxedSeen, toplevel, primitiveOK)
221224
}
222225
else if (defn.isXXLFunctionClass(sym))
223-
jsig(defn.FunctionXXLType, toplevel, primitiveOK)
226+
classSig(defn.FunctionXXLClass)
224227
else if (sym.isClass)
225-
classSig
228+
classSig(sym, pre, args)
226229
else
227230
jsig(erasure(tp), toplevel, primitiveOK)
228231

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,11 @@ class Typer extends Namer
361361
if (untpd.isVarPattern(tree) && name.isTermName)
362362
return typed(desugar.patternVar(tree), pt)
363363
}
364+
// Shortcut for the root package, this is not just a performance
365+
// optimization, it also avoids forcing imports thus potentially avoiding
366+
// cyclic references.
367+
if (name == nme.ROOTPKG)
368+
return tree.withType(defn.RootPackage.termRef)
364369

365370
val rawType = {
366371
val saved1 = unimported
@@ -1606,6 +1611,11 @@ class Typer extends Namer
16061611
var result = if (isTreeType(tree)) typedType(tree)(superCtx) else typedExpr(tree)(superCtx)
16071612
val psym = result.tpe.dealias.typeSymbol
16081613
if (seenParents.contains(psym) && !cls.isRefinementClass) {
1614+
// Desugaring can adds parents to classes, but we don't want to emit an
1615+
// error if the same parent was explicitly added in user code.
1616+
if (!tree.span.isSourceDerived)
1617+
return EmptyTree
1618+
16091619
if (!ctx.isAfterTyper) ctx.error(i"$psym is extended twice", tree.sourcePos)
16101620
}
16111621
else seenParents += psym
@@ -1640,7 +1650,7 @@ class Typer extends Namer
16401650

16411651
completeAnnotations(cdef, cls)
16421652
val constr1 = typed(constr).asInstanceOf[DefDef]
1643-
val parentsWithClass = ensureFirstTreeIsClass(parents mapconserve typedParent, cdef.nameSpan)
1653+
val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan)
16441654
val parents1 = ensureConstrCall(cls, parentsWithClass)(superCtx)
16451655

16461656
var self1 = typed(self)(ctx.outer).asInstanceOf[ValDef] // outer context where class members are not visible

0 commit comments

Comments
 (0)