Skip to content

Minimal enums with precise apply methods #9922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ object desugar {
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree))
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, classTypeRef, creatorExpr)
.withMods(Modifiers(Synthetic | constr1.mods.flags & copiedAccessFlags, constr1.mods.privateWithin)) :: Nil
}
}
Expand Down Expand Up @@ -658,15 +658,6 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

// true if access to the apply method has to be restricted
// i.e. if the case class constructor is either private or qualified private
Expand Down Expand Up @@ -697,8 +688,6 @@ object desugar {
then anyRef
else
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods.is(Abstract)) Nil
else {
Expand All @@ -711,9 +700,8 @@ object desugar {
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
DefDef(nme.apply, derivedTparams, appParamss, classTypeRef, creatorExpr)
.withMods(appMods) :: Nil
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
Expand All @@ -722,7 +710,7 @@ object desugar {
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
val unapplyResTp = if arity == 0 then Literal(Constant(true)) else classTypeRef
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, unapplyResTp, unapplyRHS)
.withMods(synthetic)
}
Expand Down
46 changes: 0 additions & 46 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,52 +201,6 @@ object DesugarEnums {
TypeTree(), creator).withFlags(Private | Synthetic)
}

/** The return type of an enum case apply method and any widening methods in which
* the apply's right hand side will be wrapped. For parents of the form
*
* extends E(args) with T1(args1) with ... TN(argsN)
*
* and type parameters `tparams` the generated widen method is
*
* def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1
*
* @param cdef The case definition
* @param parents The declared parents of the enum case
* @param tparams The type parameters of the enum case
* @param appliedEnumRef The enum class applied to `tparams`.
*/
def enumApplyResult(
cdef: TypeDef,
parents: List[Tree],
tparams: List[TypeDef],
appliedEnumRef: Tree)(using Context): (Tree, List[DefDef]) = {

def extractType(t: Tree): Tree = t match {
case Apply(t1, _) => extractType(t1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's nice that we can simplify desugaring in a significant way.

case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
case New(t1) => t1
case t1 => t1
}

val parentTypes = parents.map(extractType)
parentTypes.head match {
case parent: RefTree if parent.name == enumClass.name =>
// need a widen method to compute correct type parameters for enum base class
val widenParamType = parentTypes.tail.foldLeft(appliedEnumRef)(makeAndType)
val widenParam = makeSyntheticParameter(tpt = widenParamType)
val widenDef = DefDef(
name = s"${cdef.name}$$to$$${enumClass.name}".toTermName,
tparams = tparams,
vparamss = (widenParam :: Nil) :: Nil,
tpt = TypeTree(),
rhs = Ident(widenParam.name))
(TypeTree(), widenDef :: Nil)
case _ =>
(parentTypes.reduceLeft(makeAndType), Nil)
}
}

/** Is a type parameter in `enumTypeParams` referenced from an enum class case that has
* given type parameters `caseTypeParams`, value parameters `vparamss` and parents `parents`?
* Issues an error if that is the case but the reference is illegal.
Expand Down
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,24 @@ trait ConstraintHandling {
val tpw = tp.widenSingletons
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenEnum(tp: Type) =
val tpw = tp.widenEnumCase
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def isSingleton(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)

def isEnumCase(tp: Type): Boolean = tp match
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want to extract the commonalities of isEnumCase and isSingleton into a method

def widenWildcard(tp: Type) = tp match
  case WildcardType(optBounds) => optBounds
  case _ => tp

case WildcardType(optBounds) => optBounds.exists && isEnumCase(optBounds.bounds.hi)
case _ => tp.classSymbol.isAllOf(EnumCase, butNot=JavaDefined)

val wideInst =
if isSingleton(bound) then inst
else dropSuperTraits(widenOr(widenSingle(inst)))
else
val lub = widenOr(widenSingle(inst))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nicer if widenEnum simply returns the argument type if it is not an enum case. Then we could just use

dropSuperTraits(widenEnum(widenOr(widenSingle(inst))))

here

val asAdt = if isEnumCase(bound) then lub else widenEnum(lub)
dropSuperTraits(asAdt)
wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,11 @@ object Types {
case _ => this
}

/** if this type is a reference to a class case of an enum, replace it by its first parent */
final def widenEnumCase(using Context): Type = this match
case tp: (TypeRef | AppliedType) if tp.classSymbol.isAllOf(EnumCase) => tp.parents.head
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the restriction to TypeRefs and AppliedTypes? What should happen with TypeVars, annotated types, or aliases?
A more pervasive alternative would be:

if tp.classSymbol.isAllOf(EnumCase) then tp.parents.head else tp

case _ => this

/** Widen this type and if the result contains embedded union types, replace
* them by their joins.
* "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Scanners {

object IndentWidth {
private inline val MaxCached = 40
private val spaces = Array.tabulate(MaxCached + 1)(new Run(' ', _))
private val tabs = Array.tabulate(MaxCached + 1)(new Run('\t', _))
private val spaces = Array.tabulate[Run](MaxCached + 1)(new Run(' ', _)) // TODO: remove new after bootstrap
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, this gives me pause. Is this really what we want? I find it surprising and unnatural that one needs to add [Run] here. /cc @smarter

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative design would be more local: Instead of widening enum cases everywhere where they take part in inference we only decide at the application itself. I.e. new Run(...) is always a Run and will stay that way. Run(...) is of type IndentWidth unless its expected type is Run. We should evaluate which scheme is preferable. My tendency is to prefer the alternative.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want an Array of Run you can write explicitly val spaces: Array[Run] = Array.tabulate(...)(Run(...)), that sounds OK to me and not much longer than adding a new.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I mean the worrying thing is that the Run type is in a sense provisional, and can be revoked at each type inference step. That's analogous to union and singleton types of course. But I would argue that the behavior in these cases can also be suprising and it's a necessary tradeoff. Only for enum the tradeoff does not look to be so necessary or natural.

private val tabs = Array.tabulate[Run](MaxCached + 1)(new Run('\t', _)) // TODO: remove new after bootstrap

def Run(ch: Char, n: Int): Run =
if (n <= MaxCached && ch == ' ') spaces(n)
Expand Down
12 changes: 12 additions & 0 deletions compiler/test-resources/type-printer/enum-precise
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
scala> enum Maybe[+T] { case Something(value: T); case EmptyValue }; import Maybe._
// defined class Maybe

scala> List(Something(1))
val res0: List[Maybe[Int]] = List(Something(1))

scala> def listOfSomething[O <: Maybe.Something[_]](listOfSomething: List[O]): listOfSomething.type = listOfSomething
def listOfSomething
[O <: Maybe.Something[?]](listOfSomething: List[O]): listOfSomething.type

scala> listOfSomething(List(Something(1)))
val res1: List[Maybe.Something[Int]] = List(Something(1))
12 changes: 6 additions & 6 deletions tests/patmat/i7186.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,16 @@ object printMips {

def oneAddr[O]
( a: O, indent: String)
( r: O => Dest,
( r: a.type => Dest,
): String = {
val name = a.getClass.getSimpleName.toLowerCase
s"${indent}$name ${rsrc(r(a))}$endl"
}

def twoAddr[O]
( a: O, indent: String)
( d: O => Register,
r: O => Dest | Constant
( d: a.type => Register,
r: a.type => Dest | Constant
): String = {
val name = a.getClass.getSimpleName.toLowerCase
s"${indent}$name ${registers(d(a))}, ${rsrc(r(a))}$endl"
Expand All @@ -190,9 +190,9 @@ object printMips {
def threeAddr[O]
( a: O,
indent: String )
( d: O => Register,
l: O => Register,
r: O => Src
( d: a.type => Register,
l: a.type => Register,
r: a.type => Src
): String = {
val name = a.getClass.getSimpleName.toLowerCase
s"${indent}$name ${registers(d(a))}, ${registers(l(a))}, ${rsrc(r(a))}$endl"
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/i3935.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
enum Foo3[T](x: T) {
case Bar[S, T](y: T) extends Foo3[y.type](y)
}

val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
val bar = foo

def baz[T](f: Foo3[T]): f.type = f

val qux = baz(bar) // existentials are back in Dotty?
30 changes: 30 additions & 0 deletions tests/run-macros/enum-nat-macro/Macros_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import Nat._

inline def toIntMacro(inline nat: Nat): Int = ${ Macros.toIntImpl('nat) }
inline def ZeroMacro: Zero.type = ${ Macros.natZero }
transparent inline def toNatMacro(inline int: Int): Nat = ${ Macros.toNatImpl('int) }

object Macros:
import quoted._

def toIntImpl(nat: Expr[Nat])(using QuoteContext): Expr[Int] =

def inner(nat: Expr[Nat], acc: Int): Int = nat match
case '{ Succ($nat) } => inner(nat, acc + 1)
case '{ Zero } => acc

Expr(inner(nat, 0))

def natZero(using QuoteContext): Expr[Nat.Zero.type] = '{Zero}

def toNatImpl(int: Expr[Int])(using QuoteContext): Expr[Nat] =

// it seems even with the bound that the arg will always widen to Expr[Nat] unless explicit

def inner[N <: Nat: Type](int: Int, acc: Expr[N]): Expr[Nat] = int match
case 0 => acc
case n => inner[Succ[N]](n - 1, '{Succ($acc)})

val Const(i) = int
require(i >= 0)
inner[Zero.type](i, '{Zero})
3 changes: 3 additions & 0 deletions tests/run-macros/enum-nat-macro/Nat_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
enum Nat:
case Zero
case Succ[N <: Nat](n: N)
9 changes: 9 additions & 0 deletions tests/run-macros/enum-nat-macro/Test_3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import Nat._

@main def Test: Unit =
assert(toIntMacro(Succ(Succ(Succ(Zero)))) == 3)
assert(toNatMacro(3) == Succ(Succ(Succ(Zero))))
val zero: Zero.type = ZeroMacro
assert(zero == Zero)
assert(toIntMacro(toNatMacro(3)) == 3)
val n: Succ[Succ[Succ[Zero.type]]] = toNatMacro(3)
2 changes: 2 additions & 0 deletions tests/run-macros/i8007.check
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@ true

true

true

false

4 changes: 2 additions & 2 deletions tests/run-macros/i8007/Macro_3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ object Eq {
}

object Macro3 {
extension [T](x: =>T) inline def === (y: =>T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)
extension [T](inline x: T) inline def === (inline y: T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)

implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
}
}
12 changes: 11 additions & 1 deletion tests/run-macros/i8007/Test_4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ import Macro3.eqGen
case class Person(name: String, age: Int)

enum Opt[+T] {
case Sm[T](t: T) extends Opt[T]
case Sm(t: T)
case Nn
}

enum OptInv[+T] {
case Sm[T](t: T) extends OptInv[T]
case Nn
}

Expand Down Expand Up @@ -34,6 +39,11 @@ enum Opt[+T] {
println(t5) // true
println

// Here invariant case without explicit type parameter will instantiate T to OptInv[Any]
val t5_2 = OptInv.Sm[Int](23) === OptInv.Sm(23)
println(t5) // true
println

val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23))
println(t6) // true
println
Expand Down
31 changes: 31 additions & 0 deletions tests/run/enums-precise.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
enum NonEmptyList[+T]:
case Many[+U](head: U, tail: NonEmptyList[U]) extends NonEmptyList[U]
case One [+U](value: U) extends NonEmptyList[U]

enum Ast:
case Binding(name: String, tpe: String)
case Lambda(args: NonEmptyList[Binding], rhs: Ast) // reference to another case of the enum
case Ident(name: String)
case Apply(fn: Ast, args: NonEmptyList[Ast])

import NonEmptyList._
import Ast._

// This example showcases the widening when inferring enum case types.
// With scala 2 case class hierarchies, if One.apply(1) returns One[Int] and Many.apply(2, One(3)) returns Many[Int]
// then the `foldRight` expression below would complain that Many[Binding] is not One[Binding]. With Scala 3 enums,
// .apply on the companion returns the precise class, but type inference will widen to NonEmptyList[Binding] unless
// the precise class is expected.
def Bindings(arg: (String, String), args: (String, String)*): NonEmptyList[Binding] =
def Bind(arg: (String, String)): Binding =
val (name, tpe) = arg
Binding(name, tpe)

args.foldRight(One[Binding](Bind(arg)))((arg, acc) => Many(Bind(arg), acc))

@main def Test: Unit =
val OneOfOne: One[1] = One[1](1)
val True = Lambda(Bindings("x" -> "T", "y" -> "T"), Ident("x"))
val Const = Lambda(One(Binding("x", "T")), Lambda(One(Binding("y", "U")), Ident("x"))) // precise type is forwarded

assert(OneOfOne.value == 1)