Skip to content

Fix #7499: Prevent extending java.lang.Enum except from an enum #9487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,18 @@ class Definitions {

@tu lazy val JavaEnumClass: ClassSymbol = {
val cls = requiredClass("java.lang.Enum")
// jl.Enum has a single constructor protected(name: String, ordinal: Int).
// We remove the arguments from the primary constructor, and enter
// a new constructor symbol with 2 arguments, so that both
// `X extends jl.Enum[X]` and `X extends jl.Enum[X](name, ordinal)`
// pass typer and go through jl.Enum-specific checks in RefChecks.
cls.infoOrCompleter match {
case completer: ClassfileLoader =>
cls.info = new ClassfileLoader(completer.classfile) {
override def complete(root: SymDenotation)(using Context): Unit = {
super.complete(root)
val constr = cls.primaryConstructor
val newInfo = constr.info match {
val noArgInfo = constr.info match {
case info: PolyType =>
info.resType match {
case meth: MethodType =>
Expand All @@ -600,7 +605,8 @@ class Definitions {
paramNames = Nil, paramInfos = Nil))
}
}
constr.info = newInfo
val argConstr = constr.copy().entered
constr.info = noArgInfo
constr.termRef.recomputeDenot()
}
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ enum ErrorMessageID extends java.lang.Enum[ErrorMessageID] {
UnexpectedPatternForSummonFromID,
AnonymousInstanceCannotBeEmptyID,
TypeSpliceInValPatternID,
ModifierNotAllowedForDefinitionID
ModifierNotAllowedForDefinitionID,
CannotExtendJavaEnumID

def errorNumber = ordinal - 2
}
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,12 @@ import ast.tpd
|"""
}

class CannotExtendJavaEnum(sym: Symbol)(using Context)
extends SyntaxMsg(CannotExtendJavaEnumID) {
def msg = em"""$sym cannot extend ${hl("java.lang.Enum")}: only enums defined with the ${hl("enum")} syntax can"""
def explain = ""
}

class CannotHaveSameNameAs(sym: Symbol, cls: Symbol, reason: CannotHaveSameNameAs.Reason)(using Context)
extends SyntaxMsg(CannotHaveSameNameAsID) {
import CannotHaveSameNameAs._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
*/
private def addEnumConstrArgs(targetCls: Symbol, parents: List[Tree], args: List[Tree])(using Context): List[Tree] =
parents.map {
case app @ Apply(fn, args0) if fn.symbol.owner == targetCls => cpy.Apply(app)(fn, args0 ++ args)
case app @ Apply(fn, args0) if fn.symbol.owner == targetCls =>
if args0.nonEmpty && targetCls == defn.JavaEnumClass then
report.error("the constructor of java.lang.Enum cannot be called explicitly", app.sourcePos)
cpy.Apply(app)(fn, args0 ++ args)
case p => p
}

Expand Down
18 changes: 14 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import scala.util.{Try, Failure, Success}
import config.{ScalaVersion, NoScalaVersion}
import Decorators._
import typer.ErrorReporting._
import config.Feature.warnOnMigration
import config.Feature.{warnOnMigration, migrateTo3}
import reporting._

object RefChecks {
Expand Down Expand Up @@ -88,8 +88,9 @@ object RefChecks {
cls.thisType
}

/** Check that self type of this class conforms to self types of parents.
* and required classes.
/** Check that self type of this class conforms to self types of parents
* and required classes. Also check that only `enum` constructs extend
* `java.lang.Enum`.
*/
private def checkParents(cls: Symbol)(using Context): Unit = cls.info match {
case cinfo: ClassInfo =>
Expand All @@ -99,10 +100,19 @@ object RefChecks {
report.error(DoesNotConformToSelfType(category, cinfo.selfType, cls, otherSelf, relation, other.classSymbol),
cls.sourcePos)
}
for (parent <- cinfo.classParents)
val parents = cinfo.classParents
for (parent <- parents)
checkSelfConforms(parent.classSymbol.asClass, "illegal inheritance", "parent")
for (reqd <- cinfo.cls.givenSelfType.classSymbols)
checkSelfConforms(reqd, "missing requirement", "required")

// Prevent wrong `extends` of java.lang.Enum
if !migrateTo3 &&
!cls.isOneOf(Enum | Trait) &&
parents.exists(_.classSymbol == defn.JavaEnumClass)
then
report.error(CannotExtendJavaEnum(cls), cls.sourcePos)

case _ =>
}

Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class CompilationTests extends ParallelTesting {
compileFile("tests/pos-custom-args/i5498-postfixOps.scala", defaultOptions withoutLanguageFeature "postfixOps"),
compileFile("tests/pos-custom-args/i8875.scala", defaultOptions.and("-Xprint:getters")),
compileFile("tests/pos-custom-args/i9267.scala", defaultOptions.and("-Ystop-after:erasure")),
compileFile("tests/pos-special/extend-java-enum.scala", defaultOptions.and("-source", "3.0-migration")),
).checkCompile()
}

Expand Down
3 changes: 3 additions & 0 deletions tests/neg/cannot-call-java-enum-constructor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
enum E extends java.lang.Enum[E]("name", 0) { // error
case A, B
}
9 changes: 9 additions & 0 deletions tests/neg/extend-java-enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import java.{lang => jl}

class C1 extends jl.Enum[C1] // error: class C1 cannot extend java.lang.Enum

class C2(name: String, ordinal: Int) extends jl.Enum[C2](name, ordinal) // error: class C2 cannot extend java.lang.Enum

trait T extends jl.Enum[T] // ok

class C3 extends T // error: class C3 cannot extend java.lang.Enum
21 changes: 21 additions & 0 deletions tests/pos-special/extend-java-enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import java.{lang => jl}

final class ConfigSyntax private (name: String, ordinal: Int)
extends jl.Enum[ConfigSyntax](name, ordinal)

object ConfigSyntax {

final val JSON = new ConfigSyntax("JSON", 0)
final val CONF = new ConfigSyntax("CONF", 1)
final val PROPERTIES = new ConfigSyntax("PROPERTIES", 2)

private[this] final val _values: Array[ConfigSyntax] =
Array(JSON, CONF, PROPERTIES)

def values: Array[ConfigSyntax] = _values.clone()

def valueOf(name: String): ConfigSyntax =
_values.find(_.name == name).getOrElse {
throw new IllegalArgumentException("No enum const ConfigSyntax." + name)
}
}
5 changes: 5 additions & 0 deletions tests/pos/trait-java-enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
trait T extends java.lang.Enum[T]

enum MyEnum extends T {
case A, B
}