Skip to content

fix #7227: allow custom toString on enum #9549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ object desugar {
yield syntheticProperty(selName, caseParams(i).tpt,
Select(This(EmptyTypeIdent), caseParams(i).name))

def ordinalMeths = if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: Nil else Nil
def enumMeths =
if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil
else Nil
def copyMeths = {
val hasRepeatedParam = constrVparamss.exists(_.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
Expand All @@ -605,7 +607,7 @@ object desugar {
}

if (isCaseClass)
copyMeths ::: ordinalMeths ::: productElemMeths
copyMeths ::: enumMeths ::: productElemMeths
else Nil
}

Expand Down
37 changes: 18 additions & 19 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,21 @@ object DesugarEnums {
/** A creation method for a value of enum type `E`, which is defined as follows:
*
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
* def ordinal = _$ordinal // if `E` does not derive from jl.Enum
* override def toString = $name // if `E` does not derive from jl.Enum
* def ordinal = _$ordinal // if `E` does not derive from `java.lang.Enum`
* def enumLabel = $name // if `E` does not derive from `java.lang.Enum`
* def enumLabel = this.name // if `E` derives from `java.lang.Enum`
* $values.register(this)
* }
*/
private def enumValueCreator(using Context) = {
val fieldMethods =
if isJavaEnum then Nil
else
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
val toStringDef = toStringMeth(Ident(nme.nameDollar))
List(ordinalDef, toStringDef)
if isJavaEnum then
val enumLabelDef = enumLabelMeth(Select(This(Ident(tpnme.EMPTY)), nme.name))
enumLabelDef :: Nil
else
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
val enumLabelDef = enumLabelMeth(Ident(nme.nameDollar))
ordinalDef :: enumLabelDef :: Nil
val creator = New(Template(
constr = emptyConstructor,
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
Expand Down Expand Up @@ -273,14 +276,14 @@ object DesugarEnums {
def ordinalMeth(body: Tree)(using Context): DefDef =
DefDef(nme.ordinal, Nil, Nil, TypeTree(defn.IntType), body)

def toStringMeth(body: Tree)(using Context): DefDef =
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override)
def enumLabelMeth(body: Tree)(using Context): DefDef =
DefDef(nme.enumLabel, Nil, Nil, TypeTree(defn.StringType), body)

def ordinalMethLit(ord: Int)(using Context): DefDef =
ordinalMeth(Literal(Constant(ord)))

def toStringMethLit(name: String)(using Context): DefDef =
toStringMeth(Literal(Constant(name)))
def enumLabelLit(name: String)(using Context): DefDef =
enumLabelMeth(Literal(Constant(name)))

/** Expand a module definition representing a parameterless enum case */
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
Expand All @@ -290,16 +293,12 @@ object DesugarEnums {
expandSimpleEnumCase(name, mods, span)
else {
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
val fieldMethods =
if isJavaEnum then Nil
else
val ordinalDef = ordinalMethLit(tag)
val toStringDef = toStringMethLit(name.toString)
List(ordinalDef, toStringDef)
val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
val enumLabelDef = enumLabelLit(name.toString)
val impl1 = cpy.Template(impl)(
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
body = fieldMethods ::: registerCall :: Nil)
.withAttachment(ExtendsSingletonMirror, ())
body = ordinalDef ::: enumLabelDef :: registerCall :: Nil
).withAttachment(ExtendsSingletonMirror, ())
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
}
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,6 @@ class Definitions {
@tu lazy val NoneModule: Symbol = requiredModule("scala.None")

@tu lazy val EnumClass: ClassSymbol = requiredClass("scala.Enum")
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)

@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ object StdNames {
val emptyValDef: N = "emptyValDef"
val end: N = "end"
val ensureAccessible : N = "ensureAccessible"
val enumLabel: N = "enumLabel"
val eq: N = "eq"
val eqInstance: N = "eqInstance"
val equalsNumChar : N = "equalsNumChar"
Expand Down
21 changes: 20 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
private var myValueSymbols: List[Symbol] = Nil
private var myCaseSymbols: List[Symbol] = Nil
private var myCaseModuleSymbols: List[Symbol] = Nil
private var myEnumValueSymbols: List[Symbol] = Nil
private var myNonJavaEnumValueSymbols: List[Symbol] = Nil

private def initSymbols(using Context) =
if (myValueSymbols.isEmpty) {
Expand All @@ -65,11 +67,15 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement,
defn.Product_productElementName)
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
myEnumValueSymbols = List(defn.Product_productPrefix)
myNonJavaEnumValueSymbols = myEnumValueSymbols :+ defn.Any_toString
}

def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
def caseSymbols(using Context): List[Symbol] = { initSymbols; myCaseSymbols }
def caseModuleSymbols(using Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
def enumValueSymbols(using Context): List[Symbol] = { initSymbols; myEnumValueSymbols }
def nonJavaEnumValueSymbols(using Context): List[Symbol] = { initSymbols; myNonJavaEnumValueSymbols }

private def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = {
val existing = sym.matchingMember(clazz.thisType)
Expand All @@ -89,11 +95,15 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
else clazz.caseAccessors
val isEnumCase = clazz.derivesFrom(defn.EnumClass) && clazz != defn.EnumClass
val isEnumValue = isEnumCase && clazz.isAnonymousClass && clazz.classParents.head.classSymbol.is(Enum)
val isNonJavaEnumValue = isEnumValue && !clazz.derivesFrom(defn.JavaEnumClass)

val symbolsToSynthesize: List[Symbol] =
if (clazz.is(Case))
if (clazz.is(Module)) caseModuleSymbols
else caseSymbols
else if (isNonJavaEnumValue) nonJavaEnumValueSymbols
else if (isEnumValue) enumValueSymbols
else if (isDerivedValueClass(clazz)) valueSymbols
else Nil

Expand All @@ -113,13 +123,22 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
def ownName: Tree =
Literal(Constant(clazz.name.stripModuleClassSuffix.toString))

def callEnumLabel: Tree =
Select(This(clazz), nme.enumLabel).ensureApplied

def toStringBody(vrefss: List[List[Tree]]): Tree =
if (clazz.is(ModuleClass)) ownName
else if (isNonJavaEnumValue) callEnumLabel
else forwardToRuntime(vrefss.head)

def syntheticRHS(vrefss: List[List[Tree]])(using Context): Tree = synthetic.name match {
case nme.hashCode_ if isDerivedValueClass(clazz) => valueHashCodeBody
case nme.hashCode_ => chooseHashcode
case nme.toString_ => if (clazz.is(ModuleClass)) ownName else forwardToRuntime(vrefss.head)
case nme.toString_ => toStringBody(vrefss)
case nme.equals_ => equalsBody(vrefss.head.head)
case nme.canEqual_ => canEqualBody(vrefss.head.head)
case nme.productArity => Literal(Constant(accessors.length))
case nme.productPrefix if isEnumValue => callEnumLabel
case nme.productPrefix => ownName
case nme.productElement => productElementBody(accessors.length, vrefss.head.head)
case nme.productElementName => productElementNameBody(accessors.length, vrefss.head.head)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ trait Checking {

end checkEnumParent


/** Check that all references coming from enum cases in an enum companion object
* are legal.
* @param cdef the enum companion object class
Expand Down
12 changes: 8 additions & 4 deletions docs/docs/reference/enums/desugarEnums.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,19 @@ If `E` contains at least one simple case, its companion object will define in ad
follows.
```scala
private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue {
def ordinal = _$ordinal // if `E` does not have `java.lang.Enum` as a parent
override def toString = $name // if `E` does not have `java.lang.Enum` as a parent
def ordinal = _$ordinal
def enumLabel = $name
override def productPrefix = enumLabel // if not overridden in `E`
override def toString = enumLabel // if not overridden in `E`
$values.register(this) // register enum value so that `valueOf` and `values` can return it.
}
```

The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in
`java.lang.Enum`.
The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend
`java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as
`java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in
`java.lang.Enum`. Finally, `enumLabel` will call `this.name` when `E` extends `java.lang.Enum`.

### Scopes for Enum Cases

Expand Down
9 changes: 7 additions & 2 deletions docs/docs/reference/enums/enums.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,17 @@ For a more in-depth example of using Scala 3 enums from Java, see [this test](ht
### Implementation

Enums are represented as `sealed` classes that extend the `scala.Enum` trait.
This trait defines a single public method, `ordinal`:
This trait defines two public methods, `ordinal` and `enumLabel`:

```scala
package scala

/** A base trait of all enum classes */
trait Enum extends Product with Serializable {

/** A string uniquely identifying a case of an enum */
def enumLabel: String

/** A number uniquely identifying a case of an enum */
def ordinal: Int
}
Expand All @@ -130,7 +133,9 @@ For instance, the `Venus` value above would be defined like this:
val Venus: Planet =
new Planet(4.869E24, 6051800.0) {
def ordinal: Int = 1
override def toString: String = "Venus"
def enumLabel: String = "Venus"
override def productPrefix: String = enumLabel
override def toString: String = enumLabel
// internal code to register value
}
```
Expand Down
3 changes: 3 additions & 0 deletions library/src-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@ package scala
/** A base trait of all enum classes */
trait Enum extends Product, Serializable:

/** A string uniquely identifying a case of an enum */
def enumLabel: String

/** A number uniquely identifying a case of an enum */
def ordinal: Int
21 changes: 21 additions & 0 deletions library/src-bootstrapped/scala/runtime/EnumValues.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package scala.runtime

import scala.collection.immutable.TreeMap

class EnumValues[E <: Enum] {
private[this] var myMap: Map[Int, E] = TreeMap.empty
private[this] var fromNameCache: Map[String, E] = null

def register(v: E) = {
require(!myMap.contains(v.ordinal))
myMap = myMap.updated(v.ordinal, v)
fromNameCache = null
}

def fromInt: Map[Int, E] = myMap
def fromName: Map[String, E] = {
if (fromNameCache == null) fromNameCache = myMap.values.map(v => v.enumLabel -> v).toMap
fromNameCache
}
def values: Iterable[E] = myMap.values
}
1 change: 0 additions & 1 deletion library/src/scala/runtime/EnumValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package scala.runtime
super trait EnumValue extends Product, Serializable:
override def canEqual(that: Any) = this eq that.asInstanceOf[AnyRef]
override def productArity: Int = 0
override def productPrefix: String = toString
override def productElement(n: Int): Any =
throw IndexOutOfBoundsException(n.toString)
override def productElementName(n: Int): String =
Expand Down
19 changes: 19 additions & 0 deletions tests/neg/enumsLabelDef.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
enum Labelled {

case A // error overriding method enumLabel in class Labelled of type => String;
case B(arg: Int) // error overriding method enumLabel in class Labelled of type => String;

def enumLabel: String = "nolabel"
}

trait Mixin { def enumLabel: String = "mixin" }

enum Mixed extends Mixin {
case C // error overriding method enumLabel in trait Mixin of type => String;
}

trait HasEnumLabel { def enumLabel: String }

enum MyEnum extends HasEnumLabel {
case D // ok
}
2 changes: 2 additions & 0 deletions tests/pos/enum-List-control.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ abstract sealed class List[T] extends Enum
object List {
final class Cons[T](x: T, xs: List[T]) extends List[T] {
def ordinal = 0
def enumLabel = "Cons"
def canEqual(that: Any): Boolean = that.isInstanceOf[Cons[_]]
def productArity: Int = 2
def productElement(n: Int): Any = n match
Expand All @@ -13,6 +14,7 @@ object List {
}
final class Nil[T]() extends List[T], runtime.EnumValue {
def ordinal = 1
def enumLabel = "Nil"
}
object Nil {
def apply[T](): List[T] = new Nil()
Expand Down
92 changes: 92 additions & 0 deletions tests/run/enum-custom-toString.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
enum ES:
case A
override def toString: String = "overridden"

enum EJ extends java.lang.Enum[EJ]:
case B
override def toString: String = "overridden"

trait Mixin extends Enum:
override def productPrefix: String = "noprefix"
override def toString: String = "overridden"

enum EM extends Mixin:
case C

enum ET[T] extends java.lang.Enum[ET[_]]:
case D extends ET[Unit]
override def toString: String = "overridden"

enum EZ:
case E(arg: Int)
override def toString: String = "overridden"

enum EC: // control case
case F
case G(arg: Int)

enum EO:
case H
case I(arg: Int)
override def productPrefix: String = "noprefix"
override def toString: String = "overridden"
end EO

enum EQ:
case J extends EQ with Mixin
case K(arg: Int) extends EQ with Mixin

abstract class Tag[T] extends Enum
object Tag:
private final class IntTagImpl extends Tag[Int] with runtime.EnumValue:
def ordinal = 0
def enumLabel = "IntTag"
override def hashCode = 123
final val IntTag: Tag[Int] = IntTagImpl()

@main def Test =
assert(ES.A.toString == "overridden", s"ES.A.toString = ${ES.A.toString}")
assert(ES.A.productPrefix == "A", s"ES.A.productPrefix = ${ES.A.productPrefix}")
assert(ES.A.enumLabel == "A", s"ES.A.enumLabel = ${ES.A.enumLabel}")
assert(ES.valueOf("A") == ES.A, s"ES.valueOf(A) = ${ES.valueOf("A")}")
assert(EJ.B.toString == "overridden", s"EJ.B.toString = ${EJ.B.toString}")
assert(EJ.B.productPrefix == "B", s"EJ.B.productPrefix = ${EJ.B.productPrefix}")
assert(EJ.B.enumLabel == "B", s"EJ.B.enumLabel = ${EJ.B.enumLabel}")
assert(EJ.valueOf("B") == EJ.B, s"EJ.valueOf(B) = ${EJ.valueOf("B")}")
assert(EM.C.toString == "overridden", s"EM.C.toString = ${EM.C.toString}")
assert(EM.C.productPrefix == "noprefix", s"EM.C.productPrefix = ${EM.C.productPrefix}")
assert(EM.C.enumLabel == "C", s"EM.C.enumLabel = ${EM.C.enumLabel}")
assert(EM.valueOf("C") == EM.C, s"EM.valueOf(C) = ${EM.valueOf("C")}")
assert(ET.D.toString == "overridden", s"ET.D.toString = ${ET.D.toString}")
assert(ET.D.productPrefix == "D", s"ET.D.productPrefix = ${ET.D.productPrefix}")
assert(ET.D.enumLabel == "D", s"ET.D.enumLabel = ${ET.D.enumLabel}")
assert(EZ.E(0).toString == "overridden", s"EZ.E(0).toString = ${EZ.E(0).toString}")
assert(EZ.E(0).productPrefix == "E", s"EZ.E(0).productPrefix = ${EZ.E(0).productPrefix}")
assert(EZ.E(0).enumLabel == "E", s"EZ.E(0).enumLabel = ${EZ.E(0).enumLabel}")
assert(EC.F.toString == "F", s"EC.F.toString = ${EC.F.toString}")
assert(EC.F.productPrefix == "F", s"EC.F.productPrefix = ${EC.F.productPrefix}")
assert(EC.F.enumLabel == "F", s"EC.F.enumLabel = ${EC.F.enumLabel}")
assert(EC.valueOf("F") == EC.F, s"EC.valueOf(F) = ${EC.valueOf("F")}")
assert(EC.G(0).toString == "G(0)", s"EC.G(0).toString = ${EC.G(0).toString}")
assert(EC.G(0).productPrefix == "G", s"EC.G(0).productPrefix = ${EC.G(0).productPrefix}")
assert(EC.G(0).enumLabel == "G", s"EC.G(0).enumLabel = ${EC.G(0).enumLabel}")
assert(EO.H.toString == "overridden", s"EO.H.toString = ${EO.H.toString}")
assert(EO.H.productPrefix == "noprefix", s"EO.H.productPrefix = ${EO.H.productPrefix}")
assert(EO.H.enumLabel == "H", s"EO.H.enumLabel = ${EO.H.enumLabel}")
assert(EO.valueOf("H") == EO.H, s"EO.valueOf(H) = ${EO.valueOf("H")}")
assert(EO.I(0).toString == "overridden", s"EO.I(0).toString = ${EO.I(0).toString}")
assert(EO.I(0).productPrefix == "noprefix", s"EO.I(0).productPrefix = ${EO.I(0).productPrefix}")
assert(EO.I(0).enumLabel == "I", s"EO.I(0).enumLabel = ${EO.I(0).enumLabel}")
assert(EQ.J.toString == "overridden", s"EQ.J.toString = ${EQ.J.toString}")
assert(EQ.J.productPrefix == "noprefix", s"EQ.J.productPrefix = ${EQ.J.productPrefix}")
assert(EQ.J.enumLabel == "J", s"EQ.J.enumLabel = ${EQ.J.enumLabel}")
assert(EQ.valueOf("J") == EQ.J, s"EQ.valueOf(J) = ${EQ.valueOf("J")}")
assert(EQ.K(0).toString == "overridden", s"EQ.K(0).toString = ${EQ.K(0).toString}")
assert(EQ.K(0).productPrefix == "noprefix", s"EQ.K(0).productPrefix = ${EQ.K(0).productPrefix}")
assert(EQ.K(0).enumLabel == "K", s"EQ.K(0).enumLabel = ${EQ.K(0).enumLabel}")
assert(Tag.IntTag.productPrefix == "", s"Tag.IntTag.productPrefix = ${Tag.IntTag.productPrefix}")
assert(Tag.IntTag.enumLabel == "IntTag", s"Tag.IntTag.enumLabel = ${Tag.IntTag.enumLabel}")
assert(
assertion = Tag.IntTag.toString == s"${Tag.IntTag.getClass.getName}@${Integer.toHexString(123)}",
message = s"Tag.IntTag.toString = ${Tag.IntTag.toString}"
)
Loading