Skip to content

Precise apply for enum companion objects #9728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
20 changes: 4 additions & 16 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ object desugar {
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree))
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, classTypeRef, creatorExpr)
.withMods(Modifiers(Synthetic | constr1.mods.flags & copiedAccessFlags, constr1.mods.privateWithin)) :: Nil
}
}
Expand Down Expand Up @@ -656,15 +656,6 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

// true if access to the apply method has to be restricted
// i.e. if the case class constructor is either private or qualified private
Expand Down Expand Up @@ -695,8 +686,6 @@ object desugar {
then anyRef
else
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods.is(Abstract)) Nil
else {
Expand All @@ -709,9 +698,8 @@ object desugar {
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
DefDef(nme.apply, derivedTparams, appParamss, classTypeRef, creatorExpr)
.withMods(appMods) :: Nil
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
Expand All @@ -720,7 +708,7 @@ object desugar {
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
val unapplyResTp = if arity == 0 then Literal(Constant(true)) else classTypeRef
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, unapplyResTp, unapplyRHS)
.withMods(synthetic)
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ object DesugarEnums {
CaseDef(Ident(nme.WILDCARD), EmptyTree,
Throw(New(TypeTree(defn.IllegalArgumentExceptionType), List(msg :: Nil))))
val stringCases = enumValues.map(enumValue =>
CaseDef(Literal(Constant(enumValue.name.toString)), EmptyTree, enumValue)
CaseDef(Literal(Constant(enumValue.name.toString)), EmptyTree, Typed(enumValue, rawEnumClassRef))
) ::: defaultCase :: Nil
Match(Ident(nme.nameDollar), stringCases)
val valueOfDef = DefDef(nme.valueOf, Nil, List(param(nme.nameDollar, defn.StringType) :: Nil),
Expand All @@ -157,12 +157,13 @@ object DesugarEnums {
def byOrdinal: List[Tree] =
if isJavaEnum || !constraints.cached then Nil
else
val rawEnumClassRef = rawRef(enumClass.typeRef)
val defaultCase =
val ord = Ident(nme.ordinal)
val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil)))
CaseDef(ord, EmptyTree, err)
val valueCases = constraints.enumCases.map((i, enumValue) =>
CaseDef(Literal(Constant(i)), EmptyTree, enumValue)
CaseDef(Literal(Constant(i)), EmptyTree, Typed(enumValue, rawEnumClassRef))
) ::: defaultCase :: Nil
val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil),
rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases))
Expand Down
19 changes: 14 additions & 5 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ trait ConstraintHandling {
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenEnum(tp: Type) =
val tpw = tp.widenEnumCase
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
Expand All @@ -354,14 +358,19 @@ trait ConstraintHandling {
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)

def isEnum(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isEnum(optBounds.bounds.hi)
case _ => tp.typeSymbol.is(Enum, butNot=JavaDefined)

val wideInst =
if isSingleton(bound) then inst
else dropSuperTraits(widenOr(widenSingle(inst)))
else
val lub = widenOr(widenSingle(inst))
val asAdt = if isEnum(bound) then lub else widenEnum(lub)
dropSuperTraits(asAdt)
wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
case _ =>
wideInst.dropRepeatedAnnot
case wideInst @ ModuleOrEnumValueRef() => wideInst
case wideInst => wideInst.dropRepeatedAnnot
end widenInferred

/** The instance type of `param` in the current constraint (which contains `param`).
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeErrors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,3 @@ object CyclicReference {
ex
}
}

31 changes: 27 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1099,10 +1099,22 @@ object Types {
case _ => this
}

/** same as widen, but preserves modules and singleton enum values */
final def widenToModule(using Context): Type =
def widenSingletonToModule(self: Type)(using Context): Type = self.stripTypeVar.stripAnnots match
case tp @ ModuleOrEnumValueRef() => tp
case tp: SingletonType if !tp.isOverloaded => widenSingletonToModule(tp.underlying)
case _ => self
widenSingletonToModule(this) match
case tp: ExprType => tp.resultType.widenToModule
case tp => tp

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping Expr types.
* Preserves references to modules or singleton enum values
*/
final def widenTermRefExpr(using Context): Type = stripTypeVar match {
case tp @ ModuleOrEnumValueRef() => tp
case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenTermRefExpr
case _ => this
}
Expand Down Expand Up @@ -1145,7 +1157,7 @@ object Types {
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
*/
def widenUnion(using Context): Type = widen match {
def widenUnion(using Context): Type = widenToModule match {
case tp @ OrNull(tp1): OrType =>
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
val tp1Widen = tp1.widenUnionWithoutNull
Expand All @@ -1155,11 +1167,11 @@ object Types {
tp.widenUnionWithoutNull
}

def widenUnionWithoutNull(using Context): Type = widen match {
def widenUnionWithoutNull(using Context): Type = widenToModule match {
case tp @ OrType(lhs, rhs) =>
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match {
case union: OrType => union.join
case res => res
case res => res
}
case tp @ AndType(tp1, tp2) =>
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
Expand All @@ -1173,13 +1185,19 @@ object Types {
tp
}

def widenEnumCase(using Context): Type = dealias match {
case tp: (TypeRef | AppliedType) if tp.typeSymbol.isAllOf(EnumCase) => tp.parents.head
case tp: TermRef if tp.termSymbol.isAllOf(EnumCase, butNot=JavaDefined) => tp.underlying.widenExpr
case _ => this
}

/** Widen all top-level singletons reachable by dealiasing
* and going to the operands of & and |.
* Overridden and cached in OrType.
*/
def widenSingletons(using Context): Type = dealias match {
case tp: SingletonType =>
tp.widen
tp.widenToModule
case tp: OrType =>
val tp1w = tp.widenSingletons
if (tp1w eq tp) this else tp1w
Expand Down Expand Up @@ -2548,6 +2566,11 @@ object Types {
apply(prefix, designatorFor(prefix, name, denot)).withDenot(denot)
}

object ModuleOrEnumValueRef:
def unapply(tp: TermRef)(using Context): Boolean =
val sym = tp.termSymbol
sym.isAllOf(EnumCase, butNot=JavaDefined) || sym.is(Module)

object TypeRef {

/** Create a type ref with given prefix and name */
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Scanners {

object IndentWidth {
private inline val MaxCached = 40
private val spaces = Array.tabulate(MaxCached + 1)(new Run(' ', _))
private val tabs = Array.tabulate(MaxCached + 1)(new Run('\t', _))
private val spaces = Array.tabulate[Run](MaxCached + 1)(new Run(' ', _)) // TODO: remove new after bootstrap
private val tabs = Array.tabulate[Run](MaxCached + 1)(new Run('\t', _)) // TODO: remove new after bootstrap

def Run(ch: Char, n: Int): Run =
if (n <= MaxCached && ch == ' ') spaces(n)
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
toTextRHS(tp)
case tp: TermRef
if !tp.denotationIsCurrent && !homogenizedView || // always print underlying when testing picklers
tp.symbol.is(Module) || tp.symbol.name == nme.IMPORT =>
tp.symbol.is(Module) || tp.symbol.isAllOf(EnumCase) || tp.symbol.name == nme.IMPORT =>
toTextRef(tp) ~ ".type"
case tp: TermRef if tp.denot.isOverloaded =>
"<overloaded " ~ toTextRef(tp) ~ ">"
Expand Down Expand Up @@ -598,4 +598,3 @@ class PlainPrinter(_ctx: Context) extends Printer {
protected def coloredText(text: Text, color: String): Text =
if (ctx.useColors) color ~ text ~ SyntaxHighlighting.NoColor else text
}

7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,10 @@ class ReifyQuotes extends MacroTransform {
val meth =
if (isType) ref(defn.Unpickler_unpickleType).appliedToType(originalTp)
else
val tpe =
if originalTp =:= defn.NilModule.termRef then originalTp // Workaround #4987
else originalTp.widen.dealias
val tpe = originalTp.widenToModule.dealias
// val tpe =
// if originalTp =:= defn.NilModule.termRef then originalTp // Workaround #4987
// else originalTp.widen.dealias
ref(defn.Unpickler_unpickleExpr).appliedToType(tpe)
val pickledQuoteStrings = liftList(PickledQuotes.pickleQuote(body).map(x => Literal(Constant(x))), defn.StringType)
val splicesList = liftList(splices, defn.FunctionType(1).appliedTo(defn.SeqType.appliedTo(defn.AnyType), defn.AnyType))
Expand Down
94 changes: 47 additions & 47 deletions tests/patmat/i7186.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,78 +125,78 @@ object printMips {
case Syscall =>
s"${indent}syscall$endl"
case jal: Jal =>
oneAddr(jal,indent)(_.dest)
oneAddr(jal,indent)(_.enumLabel,_.dest)
case jr: Jr =>
oneAddr(jr,indent)(_.dest)
oneAddr(jr,indent)(_.enumLabel,_.dest)
case j: J =>
oneAddr(j,indent)(_.dest)
oneAddr(j,indent)(_.enumLabel,_.dest)
case li: Li =>
twoAddr(li,indent)(_.dest,_.source)
twoAddr(li,indent)(_.enumLabel,_.dest,_.source)
case lw: Lw =>
twoAddr(lw,indent)(_.dest,_.source)
twoAddr(lw,indent)(_.enumLabel,_.dest,_.source)
case neg: Neg =>
twoAddr(neg,indent)(_.dest,_.r)
twoAddr(neg,indent)(_.enumLabel,_.dest,_.r)
case not: Not =>
twoAddr(not,indent)(_.dest,_.r)
twoAddr(not,indent)(_.enumLabel,_.dest,_.r)
case move: Move =>
twoAddr(move,indent)(_.dest,_.source)
twoAddr(move,indent)(_.enumLabel,_.dest,_.source)
case beqz: Beqz =>
twoAddr(beqz,indent)(_.source,_.breakTo)
twoAddr(beqz,indent)(_.enumLabel,_.source,_.breakTo)
case sw: Sw =>
twoAddr(sw,indent)(_.source,_.dest)
twoAddr(sw,indent)(_.enumLabel,_.source,_.dest)
case add: Add =>
threeAddr(add,indent)(_.dest,_.l,_.r)
threeAddr(add,indent)(_.enumLabel,_.dest,_.l,_.r)
case sub: Sub =>
threeAddr(sub,indent)(_.dest,_.l,_.r)
threeAddr(sub,indent)(_.enumLabel,_.dest,_.l,_.r)
case mul: Mul =>
threeAddr(mul,indent)(_.dest,_.l,_.r)
threeAddr(mul,indent)(_.enumLabel,_.dest,_.l,_.r)
case div: Div =>
threeAddr(div,indent)(_.dest,_.l,_.r)
threeAddr(div,indent)(_.enumLabel,_.dest,_.l,_.r)
case rem: Rem =>
threeAddr(rem,indent)(_.dest,_.l,_.r)
threeAddr(rem,indent)(_.enumLabel,_.dest,_.l,_.r)
case seq: Seq =>
threeAddr(seq,indent)(_.dest,_.l,_.r)
threeAddr(seq,indent)(_.enumLabel,_.dest,_.l,_.r)
case sne: Sne =>
threeAddr(sne,indent)(_.dest,_.l,_.r)
threeAddr(sne,indent)(_.enumLabel,_.dest,_.l,_.r)
case slt: Slt =>
threeAddr(slt,indent)(_.dest,_.l,_.r)
threeAddr(slt,indent)(_.enumLabel,_.dest,_.l,_.r)
case sgt: Sgt =>
threeAddr(sgt,indent)(_.dest,_.l,_.r)
threeAddr(sgt,indent)(_.enumLabel,_.dest,_.l,_.r)
case sle: Sle =>
threeAddr(sle,indent)(_.dest,_.l,_.r)
threeAddr(sle,indent)(_.enumLabel,_.dest,_.l,_.r)
case sge: Sge =>
threeAddr(sge,indent)(_.dest,_.l,_.r)
threeAddr(sge,indent)(_.enumLabel,_.dest,_.l,_.r)
case _ => s"${indent}???$endl"
}
}

def oneAddr[O]
( a: O, indent: String)
( r: O => Dest,
): String = {
val name = a.getClass.getSimpleName.toLowerCase
s"${indent}$name ${rsrc(r(a))}$endl"
}

def twoAddr[O]
( a: O, indent: String)
( d: O => Register,
r: O => Dest | Constant
): String = {
val name = a.getClass.getSimpleName.toLowerCase
s"${indent}$name ${registers(d(a))}, ${rsrc(r(a))}$endl"
}

def threeAddr[O]
( a: O,
def oneAddr[T]
( a: T, indent: String )
( c: a.type => String,
r: a.type => Dest,
): String = (
s"${indent}${c(a)} ${rsrc(r(a))}$endl"
)

def twoAddr[T]
( a: T, indent: String )
( c: a.type => String,
d: a.type => Register,
r: a.type => Dest | Constant
): String = (
s"${indent}${c(a)} ${registers(d(a))}, ${rsrc(r(a))}$endl"
)

def threeAddr[T]
( a: T,
indent: String )
( d: O => Register,
l: O => Register,
r: O => Src
): String = {
val name = a.getClass.getSimpleName.toLowerCase
s"${indent}$name ${registers(d(a))}, ${registers(l(a))}, ${rsrc(r(a))}$endl"
}
( c: a.type => String,
d: a.type => Register,
l: a.type => Register,
r: a.type => Src
): String = (
s"${indent}${c(a)} ${registers(d(a))}, ${registers(l(a))}, ${rsrc(r(a))}$endl"
)

def rsrc(v: Constant | Dest): String = v match {
case Constant(c) => c.toString
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/i3935.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
enum Foo3[T](x: T) {
case Bar[S, T](y: T) extends Foo3[y.type](y)
}

val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
val bar = foo

def baz[T](f: Foo3[T]): f.type = f

val qux = baz(bar) // existentials are back in Dotty?
12 changes: 12 additions & 0 deletions tests/pos/i6781.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
enum Nat {
case Zero
case Succ[N <: Nat](n: N)
}
import Nat._

inline def toInt(n: => Nat): Int = inline n match {
case Zero => 0
case Succ(n1) => toInt(n1) + 1
}

val natTwo = toInt(Succ(Succ(Zero)))
Loading