Skip to content
20 changes: 4 additions & 16 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ object desugar {
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree))
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, classTypeRef, creatorExpr)
.withMods(Modifiers(Synthetic | constr1.mods.flags & copiedAccessFlags, constr1.mods.privateWithin)) :: Nil
}
}
Expand Down Expand Up @@ -656,15 +656,6 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

// true if access to the apply method has to be restricted
// i.e. if the case class constructor is either private or qualified private
Expand Down Expand Up @@ -695,8 +686,6 @@ object desugar {
then anyRef
else
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods.is(Abstract)) Nil
else {
Expand All @@ -709,9 +698,8 @@ object desugar {
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
DefDef(nme.apply, derivedTparams, appParamss, classTypeRef, creatorExpr)
.withMods(appMods) :: Nil
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
Expand All @@ -720,7 +708,7 @@ object desugar {
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
val unapplyResTp = if arity == 0 then Literal(Constant(true)) else classTypeRef
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, unapplyResTp, unapplyRHS)
.withMods(synthetic)
}
Expand Down
12 changes: 10 additions & 2 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ trait ConstraintHandling {
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenEnum(tp: Type) =
val tpw = tp.widenEnumClass
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
Expand All @@ -354,9 +358,13 @@ trait ConstraintHandling {
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)

def isEnum(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isEnum(optBounds.bounds.hi)
case _ => tp.typeSymbol.is(Enum, butNot=JavaDefined)

val wideInst =
if isSingleton(bound) then inst
else dropSuperTraits(widenOr(widenSingle(inst)))
if isSingleton(bound) || isEnum(bound) then inst
else dropSuperTraits(widenOr(widenEnum(widenSingle(inst))))
Copy link
Member

@smarter smarter Sep 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's still not quite right: if the upper bound is an enum, a singleton should still be widened unless it's the type of an enum case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean remove singletons except the term ref of a singleton enum case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to get a lot trickier when unions of singletons are involved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new steps are widenSingle > widenOr > [widenEnumCase] > dropSuperTraits, where widenOr is intercepted so that singletons of module or enum value do not widen

wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,10 @@ object Types {

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping Expr types.
* Preserves references to modules or singleton enum values
*/
final def widenTermRefExpr(using Context): Type = stripTypeVar match {
case tp: TermRef if tp.termSymbol.isAllOf(EnumCase) || tp.termSymbol.is(Module) => tp
case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenTermRefExpr
case _ => this
}
Expand Down Expand Up @@ -1173,6 +1175,13 @@ object Types {
tp
}

def widenEnumClass(using Context): Type = dealias match {
case tp: (TypeRef | AppliedType) if tp.typeSymbol.isAllOf(EnumCase) =>
tp.parents.head
case _ =>
this
}

/** Widen all top-level singletons reachable by dealiasing
* and going to the operands of & and |.
* Overridden and cached in OrType.
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Scanners {

object IndentWidth {
private inline val MaxCached = 40
private val spaces = Array.tabulate(MaxCached + 1)(new Run(' ', _))
private val tabs = Array.tabulate(MaxCached + 1)(new Run('\t', _))
private val spaces = Array.tabulate[Run](MaxCached + 1)(new Run(' ', _)) // TODO: remove new after bootstrap
private val tabs = Array.tabulate[Run](MaxCached + 1)(new Run('\t', _)) // TODO: remove new after bootstrap

def Run(ch: Char, n: Int): Run =
if (n <= MaxCached && ch == ' ') spaces(n)
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
toTextRHS(tp)
case tp: TermRef
if !tp.denotationIsCurrent && !homogenizedView || // always print underlying when testing picklers
tp.symbol.is(Module) || tp.symbol.name == nme.IMPORT =>
tp.symbol.is(Module) || tp.symbol.isAllOf(EnumCase) || tp.symbol.name == nme.IMPORT =>
toTextRef(tp) ~ ".type"
case tp: TermRef if tp.denot.isOverloaded =>
"<overloaded " ~ toTextRef(tp) ~ ">"
Expand Down Expand Up @@ -598,4 +598,3 @@ class PlainPrinter(_ctx: Context) extends Printer {
protected def coloredText(text: Text, color: String): Text =
if (ctx.useColors) color ~ text ~ SyntaxHighlighting.NoColor else text
}

10 changes: 10 additions & 0 deletions tests/pos/i3935.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
enum Foo3[T](x: T) {
case Bar[S, T](y: T) extends Foo3[y.type](y)
}

val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
val bar = foo

def baz[T](f: Foo3[T]): f.type = f

val qux = baz(bar) // existentials are back in Dotty?
4 changes: 4 additions & 0 deletions tests/run-macros/i8007.check
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,9 @@ true

true

true

true

false

3 changes: 1 addition & 2 deletions tests/run-macros/i8007/Macro_3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ object Eq {
$ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y)
}
}

'{
eqSum((x: T, y: T) => ${eqSumBody('x, 'y)})
}
Expand All @@ -76,4 +75,4 @@ object Macro3 {
extension [T](x: =>T) inline def === (y: =>T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)

implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
}
}
30 changes: 25 additions & 5 deletions tests/run-macros/i8007/Test_4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,22 @@ import Macro3.eqGen
case class Person(name: String, age: Int)

enum Opt[+T] {
case Sm(t: T)
case Sm[U](t: U) extends Opt[U]
case Nn
}

enum OptInfer[+T] {
case Sm[+U](t: U) extends OptInfer[U]
case Nn
}

// simulation of Opt using case class hierarchy
sealed abstract class OptCase[+T]
object OptCase {
final case class Sm[T](t: T) extends OptCase[T]
case object Nn extends OptCase[Nothing]
}

@main def Test() = {
import Opt._
import Eq.{given _, _}
Expand All @@ -30,15 +42,23 @@ enum Opt[+T] {
println(t4) // false
println

val t5 = Sm(23) === Sm(23)
val t5 = Opt.Sm[Int](23) === Opt.Sm(23) // same behaviour as case class when using apply
println(t5) // true
println

val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23))
val t5_2 = OptCase.Sm[Int](23) === OptCase.Sm(23)
println(t5_2) // true
println

val t5_3 = OptInfer.Sm(23) === OptInfer.Sm(23) // covariant `Sm` case means we can avoid explicit type parameter
println(t5_3) // true
println

val t6 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 23))
println(t6) // true
println

val t7 = Sm(Person("Test", 23)) === Sm(Person("Test", 24))
val t7 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 24))
println(t7) // false
println
}
}
34 changes: 34 additions & 0 deletions tests/run/enum-nat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import Nat._
import compiletime._

enum Nat:
case Zero
case Succ[N <: Nat](n: N)

inline def toIntTypeLevel[N <: Nat]: Int = inline erasedValue[N] match
case _: Zero.type => 0
case _: Succ[n] => toIntTypeLevel[n] + 1

inline def toInt(inline nat: Nat): Int = inline nat match
case nat: Zero.type => 0
case nat: Succ[n] => toInt(nat.n) + 1

inline def toIntUnapply(inline nat: Nat): Int = inline nat match
case Zero => 0
case Succ(n) => toIntUnapply(n) + 1

inline def toIntTypeTailRec[N <: Nat, Acc <: Int]: Int = inline erasedValue[N] match
case _: Zero.type => constValue[Acc]
case _: Succ[n] => toIntTypeTailRec[n, S[Acc]]

inline def toIntErased[N <: Nat](inline nat: N): Int = toIntTypeTailRec[N, 0]

@main def Test: Unit =
println("erased value:")
assert(toIntTypeLevel[Succ[Succ[Succ[Zero.type]]]] == 3)
println("type test:")
assert(toInt(Succ(Succ(Succ(Zero)))) == 3)
println("unapply:")
assert(toIntUnapply(Succ(Succ(Succ(Zero)))) == 3)
println("infer erased:")
assert(toIntErased(Succ(Succ(Succ(Zero)))) == 3)
31 changes: 31 additions & 0 deletions tests/run/enum-precise.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
enum NonEmptyList[+T]:
case Many[+U](head: U, tail: NonEmptyList[U]) extends NonEmptyList[U]
case One [+U](value: U) extends NonEmptyList[U]

enum Ast:
case Binding(name: String, tpe: String)
case Lambda(args: NonEmptyList[Binding], rhs: Ast) // reference to another case of the enum
case Ident(name: String)
case Apply(fn: Ast, args: NonEmptyList[Ast])

import NonEmptyList._
import Ast._

// This example showcases the widening when inferring enum case types.
// With scala 2 case class hierarchies, if One.apply(1) returns One[Int] and Many.apply(2, One(3)) returns Many[Int]
// then the `foldRight` expression below would complain that Many[Binding] is not One[Binding]. With Scala 3 enums,
// .apply on the companion returns the precise class, but type inference will widen to NonEmptyList[Binding] unless
// the precise class is expected.
def Bindings(arg: (String, String), args: (String, String)*): NonEmptyList[Binding] =
def Bind(arg: (String, String)): Binding =
val (name, tpe) = arg
Binding(name, tpe)

args.foldRight(One[Binding](Bind(arg)))((arg, acc) => Many(Bind(arg), acc))

@main def Test: Unit =
val OneOfOne: One[1] = One[1](1)
val True = Lambda(Bindings("x" -> "T", "y" -> "T"), Ident("x"))
val Const = Lambda(One(Binding("x", "T")), Lambda(One(Binding("y", "U")), Ident("x"))) // precise type is forwarded

assert(OneOfOne.value == 1)