Skip to content

Commit 5dbd4ed

Browse files
authored
Merge pull request #9612 from dotty-staging/topic/enum-serialization-alt
implement readResolve in terms of fromOrdinalDollar method
2 parents 5d0eea3 + 89d107c commit 5dbd4ed

File tree

7 files changed

+137
-42
lines changed

7 files changed

+137
-42
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+27-12
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ object desugar {
476476
val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase)
477477
if (enumCases.isEmpty)
478478
report.error(EnumerationsShouldNotBeEmpty(cdef), namePos)
479+
else
480+
enumCases.last.pushAttachment(DesugarEnums.DefinesEnumLookupMethods, ())
479481
val enumCompanionRef = TermRefTree()
480482
val enumImport =
481483
Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_)))
@@ -568,7 +570,7 @@ object desugar {
568570
// Note: copy default parameters need @uncheckedVariance; see
569571
// neg/t1843-variances.scala for a test case. The test would give
570572
// two errors without @uncheckedVariance, one of them spurious.
571-
val caseClassMeths = {
573+
val (caseClassMeths, enumScaffolding) = {
572574
def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) =
573575
DefDef(name, Nil, Nil, tpt, rhs).withMods(synthetic)
574576

@@ -586,9 +588,11 @@ object desugar {
586588
yield syntheticProperty(selName, caseParams(i).tpt,
587589
Select(This(EmptyTypeIdent), caseParams(i).name))
588590

589-
def enumMeths =
590-
if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil
591-
else Nil
591+
def enumCaseMeths =
592+
if isEnumCase then
593+
val (ordinal, scaffolding) = nextOrdinal(className, CaseKind.Class, definesEnumLookupMethods(cdef))
594+
(ordinalMethLit(ordinal) :: enumLabelLit(className.toString) :: Nil, scaffolding)
595+
else (Nil, Nil)
592596
def copyMeths = {
593597
val hasRepeatedParam = constrVparamss.exists(_.exists {
594598
case ValDef(_, tpt, _) => isRepeated(tpt)
@@ -607,8 +611,9 @@ object desugar {
607611
}
608612

609613
if (isCaseClass)
610-
copyMeths ::: enumMeths ::: productElemMeths
611-
else Nil
614+
val (enumMeths, enumScaffolding) = enumCaseMeths
615+
(copyMeths ::: enumMeths ::: productElemMeths, enumScaffolding)
616+
else (Nil, Nil)
612617
}
613618

614619
var parents1 = parents
@@ -809,7 +814,7 @@ object desugar {
809814
case _ =>
810815
}
811816

812-
flatTree(cdef1 :: companions ::: implicitWrappers)
817+
flatTree(cdef1 :: companions ::: implicitWrappers ::: enumScaffolding)
813818
}.reporting(i"desugared: $result", Printers.desugar)
814819

815820
/** Expand
@@ -862,7 +867,7 @@ object desugar {
862867
else if (isEnumCase) {
863868
typeParamIsReferenced(enumClass.typeParams, Nil, Nil, impl.parents)
864869
// used to check there are no illegal references to enum's type parameters in parents
865-
expandEnumModule(moduleName, impl, mods, mdef.span)
870+
expandEnumModule(moduleName, impl, mods, definesEnumLookupMethods(mdef), mdef.span)
866871
}
867872
else {
868873
val clsName = moduleName.moduleClassName
@@ -990,6 +995,12 @@ object desugar {
990995

991996
private def inventTypeName(tree: Tree)(using Context): String = typeNameExtractor("", tree)
992997

998+
/**This will check if this def tree is marked to define enum lookup methods,
999+
* this is not recommended to call more than once per tree
1000+
*/
1001+
private def definesEnumLookupMethods(ddef: DefTree): Boolean =
1002+
ddef.removeAttachment(DefinesEnumLookupMethods).isDefined
1003+
9931004
/** val p1, ..., pN: T = E
9941005
* ==>
9951006
* makePatDef[[val p1: T1 = E]]; ...; makePatDef[[val pN: TN = E]]
@@ -1001,11 +1012,15 @@ object desugar {
10011012
def patDef(pdef: PatDef)(using Context): Tree = flatTree {
10021013
val PatDef(mods, pats, tpt, rhs) = pdef
10031014
if (mods.isEnumCase)
1004-
pats map {
1005-
case id: Ident =>
1006-
expandSimpleEnumCase(id.name.asTermName, mods,
1015+
def expand(id: Ident, definesLookups: Boolean) =
1016+
expandSimpleEnumCase(id.name.asTermName, mods, definesLookups,
10071017
Span(id.span.start, id.span.end, id.span.start))
1008-
}
1018+
1019+
val ids = pats.asInstanceOf[List[Ident]]
1020+
if definesEnumLookupMethods(pdef) then
1021+
ids.init.map(expand(_, false)) ::: expand(ids.last, true) :: Nil
1022+
else
1023+
ids.map(expand(_, false))
10091024
else {
10101025
val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt))
10111026
pats1 map (makePatDef(pdef, mods, _, rhs))

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

+42-14
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ object DesugarEnums {
2020
val Simple, Object, Class: Value = Value
2121
}
2222

23-
/** Attachment containing the number of enum cases and the smallest kind that was seen so far. */
24-
val EnumCaseCount: Property.Key[(Int, DesugarEnums.CaseKind.Value)] = Property.Key()
23+
/** Attachment containing the number of enum cases, the smallest kind that was seen so far,
24+
* and a list of all the value cases with their ordinals.
25+
*/
26+
val EnumCaseCount: Property.Key[(Int, CaseKind.Value, List[(Int, TermName)])] = Property.Key()
27+
28+
/** Attachment signalling that when this definition is desugared, it should add any additional
29+
* lookup methods for enums.
30+
*/
31+
val DefinesEnumLookupMethods: Property.Key[Unit] = Property.Key()
2532

2633
/** The enumeration class that belongs to an enum case. This works no matter
2734
* whether the case is still in the enum class or it has been transferred to the
@@ -122,6 +129,21 @@ object DesugarEnums {
122129
valueOfDef :: Nil
123130
}
124131

132+
private def enumLookupMethods(cases: List[(Int, TermName)])(using Context): List[Tree] =
133+
if isJavaEnum || cases.isEmpty then Nil
134+
else
135+
val defaultCase =
136+
val ord = Ident(nme.ordinal)
137+
val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil)))
138+
CaseDef(ord, EmptyTree, err)
139+
val valueCases = cases.map((i, name) =>
140+
CaseDef(Literal(Constant(i)), EmptyTree, Ident(name))
141+
) ::: defaultCase :: Nil
142+
val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil),
143+
rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases))
144+
.withFlags(Synthetic | Private)
145+
fromOrdinalDef :: Nil
146+
125147
/** A creation method for a value of enum type `E`, which is defined as follows:
126148
*
127149
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
@@ -256,16 +278,22 @@ object DesugarEnums {
256278
* - scaffolding containing the necessary definitions for singleton enum cases
257279
* unless that scaffolding was already generated by a previous call to `nextEnumKind`.
258280
*/
259-
def nextOrdinal(kind: CaseKind.Value)(using Context): (Int, List[Tree]) = {
260-
val (count, seenKind) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class))
261-
val minKind = if (kind < seenKind) kind else seenKind
262-
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
263-
val scaffolding =
281+
def nextOrdinal(name: Name, kind: CaseKind.Value, definesLookups: Boolean)(using Context): (Int, List[Tree]) = {
282+
val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, Nil))
283+
val minKind = if kind < seenKind then kind else seenKind
284+
val cases = name match
285+
case name: TermName => (ordinal, name) :: seenCases
286+
case _ => seenCases
287+
ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, cases))
288+
val scaffolding0 =
264289
if (kind >= seenKind) Nil
265290
else if (kind == CaseKind.Object) enumScaffolding
266291
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
267292
else enumScaffolding :+ enumValueCreator
268-
(count, scaffolding)
293+
val scaffolding =
294+
if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse)
295+
else scaffolding0
296+
(ordinal, scaffolding)
269297
}
270298

271299
def param(name: TermName, typ: Type)(using Context) =
@@ -286,13 +314,13 @@ object DesugarEnums {
286314
enumLabelMeth(Literal(Constant(name)))
287315

288316
/** Expand a module definition representing a parameterless enum case */
289-
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
317+
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree = {
290318
assert(impl.body.isEmpty)
291319
if (!enumClass.exists) EmptyTree
292320
else if (impl.parents.isEmpty)
293-
expandSimpleEnumCase(name, mods, span)
321+
expandSimpleEnumCase(name, mods, definesLookups, span)
294322
else {
295-
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
323+
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups)
296324
val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
297325
val enumLabelDef = enumLabelLit(name.toString)
298326
val impl1 = cpy.Template(impl)(
@@ -305,15 +333,15 @@ object DesugarEnums {
305333
}
306334

307335
/** Expand a simple enum case */
308-
def expandSimpleEnumCase(name: TermName, mods: Modifiers, span: Span)(using Context): Tree =
336+
def expandSimpleEnumCase(name: TermName, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree =
309337
if (!enumClass.exists) EmptyTree
310338
else if (enumClass.typeParams.nonEmpty) {
311339
val parent = interpolatedEnumParent(span)
312340
val impl = Template(emptyConstructor, parent :: Nil, Nil, EmptyValDef, Nil)
313-
expandEnumModule(name, impl, mods, span)
341+
expandEnumModule(name, impl, mods, definesLookups, span)
314342
}
315343
else {
316-
val (tag, scaffolding) = nextOrdinal(CaseKind.Simple)
344+
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Simple, definesLookups)
317345
val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), Literal(Constant(name.toString))))
318346
val vdef = ValDef(name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue, span))
319347
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/core/StdNames.scala

+2
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ object StdNames {
615615
val using: N = "using"
616616
val value: N = "value"
617617
val valueOf : N = "valueOf"
618+
val fromOrdinalDollar: N = "$fromOrdinal"
618619
val values: N = "values"
619620
val view_ : N = "view"
620621
val wait_ : N = "wait"
@@ -623,6 +624,7 @@ object StdNames {
623624
val WorksheetWrapper: N = "WorksheetWrapper"
624625
val wrap: N = "wrap"
625626
val writeReplace: N = "writeReplace"
627+
val readResolve: N = "readResolve"
626628
val zero: N = "zero"
627629
val zip: N = "zip"
628630
val nothingRuntimeClass: N = "scala.runtime.Nothing$"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

+16-7
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,19 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
373373
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
374374
.exists
375375

376+
private def hasReadResolve(clazz: ClassSymbol)(using Context): Boolean =
377+
clazz.membersNamed(nme.readResolve)
378+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
379+
.exists
380+
376381
private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
377382
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
378383
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
379384

385+
private def readResolveDef(clazz: ClassSymbol)(using Context): TermSymbol =
386+
newSymbol(clazz, nme.readResolve, Method | Private | Synthetic,
387+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
388+
380389
/** If this is a static object `Foo`, add the method:
381390
*
382391
* private def writeReplace(): AnyRef =
@@ -405,22 +414,22 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
405414
/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
406415
* and not deriving from `java.lang.Enum` add the method:
407416
*
408-
* private def writeReplace(): AnyRef =
409-
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
417+
* private def readResolve(): AnyRef =
418+
* MyEnum.$fromOrdinal(this.ordinal)
410419
*
411420
* unless an implementation already exists, otherwise do nothing.
412421
*/
413422
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
414423
if clazz.isEnumValueImplementation
415424
&& !clazz.derivesFrom(defn.JavaEnumClass)
416425
&& clazz.isSerializable
417-
&& !hasWriteReplace(clazz)
426+
&& !hasReadResolve(clazz)
418427
then
419428
List(
420-
DefDef(writeReplaceDef(clazz),
421-
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
422-
defn.EnumValueSerializationProxyConstructor,
423-
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
429+
DefDef(readResolveDef(clazz),
430+
_ => ref(clazz.owner.owner.sourceModule)
431+
.select(nme.fromOrdinalDollar)
432+
.appliedTo(This(clazz).select(nme.ordinal).ensureApplied))
424433
.withSpan(ctx.owner.span.focus))
425434
else
426435
Nil

tests/run/enums-serialization-compat.scala

+35-8
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,49 @@ import java.io._
22
import scala.util.Using
33

44
enum JColor extends java.lang.Enum[JColor]:
5-
case Red
5+
case Red // java enum has magic JVM support
66

77
enum SColor:
8-
case Green
8+
case Green // simple case last
99

1010
enum SColorTagged[T]:
11-
case Blue extends SColorTagged[Unit]
11+
case Blue extends SColorTagged[Unit]
12+
case Rgb(r: Byte, g: Byte, b: Byte) extends SColorTagged[(Byte, Byte, Byte)] // mixing pattern kinds
13+
case Indigo extends SColorTagged[Unit]
14+
case Cmyk(c: Byte, m: Byte, y: Byte, k: Byte) extends SColorTagged[(Byte, Byte, Byte, Byte)] // class case last
15+
16+
enum Nucleobase:
17+
case A,C,G,T // patdef last
18+
19+
enum MyClassTag[T](wrapped: Class[?]):
20+
case IntTag extends MyClassTag[Int](classOf[Int])
21+
case UnitTag extends MyClassTag[Unit](classOf[Unit]) // value case last
22+
23+
extension (ref: AnyRef) def aliases(compare: AnyRef) = assert(ref eq compare, compare)
1224

1325
@main def Test = Using.Manager({ use =>
1426
val buf = use(ByteArrayOutputStream())
1527
val out = use(ObjectOutputStream(buf))
16-
Seq(JColor.Red, SColor.Green, SColorTagged.Blue).foreach(out.writeObject)
28+
Seq(JColor.Red, SColor.Green, SColorTagged.Blue, SColorTagged.Indigo).foreach(out.writeObject)
29+
Seq(Nucleobase.A, Nucleobase.C, Nucleobase.G, Nucleobase.T).foreach(out.writeObject)
30+
Seq(MyClassTag.IntTag, MyClassTag.UnitTag).foreach(out.writeObject)
1731
val read = use(ByteArrayInputStream(buf.toByteArray))
1832
val in = use(ObjectInputStream(read))
19-
val Seq(Red @ _, Green @ _, Blue @ _) = (1 to 3).map(_ => in.readObject)
20-
assert(Red eq JColor.Red, JColor.Red)
21-
assert(Green eq SColor.Green, SColor.Green)
22-
assert(Blue eq SColorTagged.Blue, SColorTagged.Blue)
33+
34+
val Seq(Red @ _, Green @ _, Blue @ _, Indigo @ _) = (1 to 4).map(_ => in.readObject)
35+
Red aliases JColor.Red
36+
Green aliases SColor.Green
37+
Blue aliases SColorTagged.Blue
38+
Indigo aliases SColorTagged.Indigo
39+
40+
val Seq(A @ _, C @ _, G @ _, T @ _) = (1 to 4).map(_ => in.readObject)
41+
A aliases Nucleobase.A
42+
C aliases Nucleobase.C
43+
G aliases Nucleobase.G
44+
T aliases Nucleobase.T
45+
46+
val Seq(IntTag @ _, UnitTag @ _) = (1 to 2).map(_ => in.readObject)
47+
IntTag aliases MyClassTag.IntTag
48+
UnitTag aliases MyClassTag.UnitTag
49+
2350
}).get

tests/semanticdb/metac.expect

+15-1
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ Schema => SemanticDB v4
641641
Uri => Enums.scala
642642
Text => empty
643643
Language => Scala
644-
Symbols => 169 entries
644+
Symbols => 183 entries
645645
Occurrences => 203 entries
646646

647647
Symbols:
@@ -651,6 +651,8 @@ _empty_/Enums.Coin#`<init>`(). => primary ctor <init>
651651
_empty_/Enums.Coin#`<init>`().(value) => param value
652652
_empty_/Enums.Coin#value. => val method value
653653
_empty_/Enums.Coin. => final object Coin
654+
_empty_/Enums.Coin.$fromOrdinal(). => method $fromOrdinal
655+
_empty_/Enums.Coin.$fromOrdinal().(_$ordinal) => param _$ordinal
654656
_empty_/Enums.Coin.$values. => val method $values
655657
_empty_/Enums.Coin.Dime. => case val static enum method Dime
656658
_empty_/Enums.Coin.Dollar. => case val static enum method Dollar
@@ -663,6 +665,8 @@ _empty_/Enums.Coin.values(). => method values
663665
_empty_/Enums.Colour# => abstract sealed enum class Colour
664666
_empty_/Enums.Colour#`<init>`(). => primary ctor <init>
665667
_empty_/Enums.Colour. => final object Colour
668+
_empty_/Enums.Colour.$fromOrdinal(). => method $fromOrdinal
669+
_empty_/Enums.Colour.$fromOrdinal().(_$ordinal) => param _$ordinal
666670
_empty_/Enums.Colour.$new(). => method $new
667671
_empty_/Enums.Colour.$new().($name) => param $name
668672
_empty_/Enums.Colour.$new().(_$ordinal) => param _$ordinal
@@ -676,6 +680,8 @@ _empty_/Enums.Colour.values(). => method values
676680
_empty_/Enums.Directions# => abstract sealed enum class Directions
677681
_empty_/Enums.Directions#`<init>`(). => primary ctor <init>
678682
_empty_/Enums.Directions. => final object Directions
683+
_empty_/Enums.Directions.$fromOrdinal(). => method $fromOrdinal
684+
_empty_/Enums.Directions.$fromOrdinal().(_$ordinal) => param _$ordinal
679685
_empty_/Enums.Directions.$new(). => method $new
680686
_empty_/Enums.Directions.$new().($name) => param $name
681687
_empty_/Enums.Directions.$new().(_$ordinal) => param _$ordinal
@@ -691,6 +697,8 @@ _empty_/Enums.Maybe# => abstract sealed enum class Maybe
691697
_empty_/Enums.Maybe#[A] => covariant typeparam A
692698
_empty_/Enums.Maybe#`<init>`(). => primary ctor <init>
693699
_empty_/Enums.Maybe. => final object Maybe
700+
_empty_/Enums.Maybe.$fromOrdinal(). => method $fromOrdinal
701+
_empty_/Enums.Maybe.$fromOrdinal().(_$ordinal) => param _$ordinal
694702
_empty_/Enums.Maybe.$values. => val method $values
695703
_empty_/Enums.Maybe.Just# => final case enum class Just
696704
_empty_/Enums.Maybe.Just#[A] => typeparam A
@@ -743,6 +751,8 @@ _empty_/Enums.Planet.values(). => method values
743751
_empty_/Enums.Suits# => abstract sealed enum class Suits
744752
_empty_/Enums.Suits#`<init>`(). => primary ctor <init>
745753
_empty_/Enums.Suits. => final object Suits
754+
_empty_/Enums.Suits.$fromOrdinal(). => method $fromOrdinal
755+
_empty_/Enums.Suits.$fromOrdinal().(_$ordinal) => param _$ordinal
746756
_empty_/Enums.Suits.$new(). => method $new
747757
_empty_/Enums.Suits.$new().($name) => param $name
748758
_empty_/Enums.Suits.$new().(_$ordinal) => param _$ordinal
@@ -763,6 +773,8 @@ _empty_/Enums.Tag# => abstract sealed enum class Tag
763773
_empty_/Enums.Tag#[A] => typeparam A
764774
_empty_/Enums.Tag#`<init>`(). => primary ctor <init>
765775
_empty_/Enums.Tag. => final object Tag
776+
_empty_/Enums.Tag.$fromOrdinal(). => method $fromOrdinal
777+
_empty_/Enums.Tag.$fromOrdinal().(_$ordinal) => param _$ordinal
766778
_empty_/Enums.Tag.$values. => val method $values
767779
_empty_/Enums.Tag.BooleanTag. => case val static enum method BooleanTag
768780
_empty_/Enums.Tag.IntTag. => case val static enum method IntTag
@@ -772,6 +784,8 @@ _empty_/Enums.Tag.values(). => method values
772784
_empty_/Enums.WeekDays# => abstract sealed enum class WeekDays
773785
_empty_/Enums.WeekDays#`<init>`(). => primary ctor <init>
774786
_empty_/Enums.WeekDays. => final object WeekDays
787+
_empty_/Enums.WeekDays.$fromOrdinal(). => method $fromOrdinal
788+
_empty_/Enums.WeekDays.$fromOrdinal().(_$ordinal) => param _$ordinal
775789
_empty_/Enums.WeekDays.$new(). => method $new
776790
_empty_/Enums.WeekDays.$new().($name) => param $name
777791
_empty_/Enums.WeekDays.$new().(_$ordinal) => param _$ordinal

0 commit comments

Comments
 (0)