Skip to content

An alternative scheme for precise apply methods of enums #9932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -658,15 +658,6 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

// true if access to the apply method has to be restricted
// i.e. if the case class constructor is either private or qualified private
Expand Down Expand Up @@ -697,8 +688,6 @@ object desugar {
then anyRef
else
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods.is(Abstract)) Nil
else {
Expand All @@ -711,9 +700,8 @@ object desugar {
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
DefDef(nme.apply, derivedTparams, appParamss, TypeTree(), creatorExpr)
.withMods(appMods) :: Nil
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
Expand Down
46 changes: 0 additions & 46 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,52 +201,6 @@ object DesugarEnums {
TypeTree(), creator).withFlags(Private | Synthetic)
}

/** The return type of an enum case apply method and any widening methods in which
* the apply's right hand side will be wrapped. For parents of the form
*
* extends E(args) with T1(args1) with ... TN(argsN)
*
* and type parameters `tparams` the generated widen method is
*
* def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1
*
* @param cdef The case definition
* @param parents The declared parents of the enum case
* @param tparams The type parameters of the enum case
* @param appliedEnumRef The enum class applied to `tparams`.
*/
def enumApplyResult(
cdef: TypeDef,
parents: List[Tree],
tparams: List[TypeDef],
appliedEnumRef: Tree)(using Context): (Tree, List[DefDef]) = {

def extractType(t: Tree): Tree = t match {
case Apply(t1, _) => extractType(t1)
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
case New(t1) => t1
case t1 => t1
}

val parentTypes = parents.map(extractType)
parentTypes.head match {
case parent: RefTree if parent.name == enumClass.name =>
// need a widen method to compute correct type parameters for enum base class
val widenParamType = parentTypes.tail.foldLeft(appliedEnumRef)(makeAndType)
val widenParam = makeSyntheticParameter(tpt = widenParamType)
val widenDef = DefDef(
name = s"${cdef.name}$$to$$${enumClass.name}".toTermName,
tparams = tparams,
vparamss = (widenParam :: Nil) :: Nil,
tpt = TypeTree(),
rhs = Ident(widenParam.name))
(TypeTree(), widenDef :: Nil)
case _ =>
(parentTypes.reduceLeft(makeAndType), Nil)
}
}

/** Is a type parameter in `enumTypeParams` referenced from an enum class case that has
* given type parameters `caseTypeParams`, value parameters `vparamss` and parents `parents`?
* Issues an error if that is the case but the reference is illegal.
Expand Down
83 changes: 42 additions & 41 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,53 @@ trait ConstraintHandling {
}
}

/** If `tp` is an intersection such that some operands are super trait instances
* and others are not, replace as many super trait instances as possible with Any
* as long as the result is still a subtype of `bound`. But fall back to the
* original type if the resulting widened type is a supertype of all dropped
* types (since in this case the type was not a true intersection of super traits
* and other types to start with).
*/
def dropSuperTraits(tp: Type, bound: Type)(using Context): Type =
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
var dropped: List[Type] = List() // the types dropped so far, last one on top

def dropOneSuperTrait(tp: Type): Type =
val tpd = tp.dealias
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
dropped = tpd :: dropped
defn.AnyType
else tpd match
case AndType(tp1, tp2) =>
val tp1w = dropOneSuperTrait(tp1)
if tp1w ne tp1 then tp1w & tp2
else
val tp2w = dropOneSuperTrait(tp2)
if tp2w ne tp2 then tp1 & tp2w
else tpd
case _ =>
tp

def recur(tp: Type): Type =
val tpw = dropOneSuperTrait(tp)
if tpw eq tp then tp
else if tpw <:< bound then recur(tpw)
else
kept += dropped.head
dropped = dropped.tail
recur(tp)

val tpw = recur(tp)
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
end dropSuperTraits

/** Widen inferred type `inst` with upper `bound`, according to the following rules:
* 1. If `inst` is a singleton type, or a union containing some singleton types,
* widen (all) the singleton type(s), provided the result is a subtype of `bound`
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
* 2. If `inst` is a union type, approximate the union type from above by an intersection
* of all common base types, provided the result is a subtype of `bound`.
* 3. If `inst` is an intersection such that some operands are super trait instances
* and others are not, replace as many super trait instances as possible with Any
* as long as the result is still a subtype of `bound`. But fall back to the
* original type if the resulting widened type is a supertype of all dropped
* types (since in this case the type was not a true intersection of super traits
* and other types to start with).
* 3. drop super traits from intersections (see @dropSuperTraits)
*
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
* Also, if the result of these widenings is a TypeRef to a module class,
Expand All @@ -308,40 +343,6 @@ trait ConstraintHandling {
* as those could leak the annotation to users (see run/inferred-repeated-result).
*/
def widenInferred(inst: Type, bound: Type)(using Context): Type =

def dropSuperTraits(tp: Type): Type =
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
var dropped: List[Type] = List() // the types dropped so far, last one on top

def dropOneSuperTrait(tp: Type): Type =
val tpd = tp.dealias
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
dropped = tpd :: dropped
defn.AnyType
else tpd match
case AndType(tp1, tp2) =>
val tp1w = dropOneSuperTrait(tp1)
if tp1w ne tp1 then tp1w & tp2
else
val tp2w = dropOneSuperTrait(tp2)
if tp2w ne tp2 then tp1 & tp2w
else tpd
case _ =>
tp

def recur(tp: Type): Type =
val tpw = dropOneSuperTrait(tp)
if tpw eq tp then tp
else if tpw <:< bound then recur(tpw)
else
kept += dropped.head
dropped = dropped.tail
recur(tp)

val tpw = recur(tp)
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
end dropSuperTraits

def widenOr(tp: Type) =
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
Expand All @@ -356,7 +357,7 @@ trait ConstraintHandling {

val wideInst =
if isSingleton(bound) then inst
else dropSuperTraits(widenOr(widenSingle(inst)))
else dropSuperTraits(widenOr(widenSingle(inst)), bound)
wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2625,6 +2625,9 @@ object TypeComparer {
def widenInferred(inst: Type, bound: Type)(using Context): Type =
comparing(_.widenInferred(inst, bound))

def dropSuperTraits(tp: Type, bound: Type)(using Context): Type =
comparing(_.dropSuperTraits(tp, bound))

def constrainPatternType(pat: Type, scrut: Type)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut))

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ object SymUtils {
def isField(using Context): Boolean =
self.isTerm && !self.is(Method)

def isEnumCase(using Context): Boolean =
self.isAllOf(EnumCase, butNot = JavaDefined)

def annotationsCarrying(meta: ClassSymbol)(using Context): List[Annotation] =
self.annotations.filter(_.symbol.hasAnnotation(meta))

Expand Down
29 changes: 27 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import ProtoTypes._
import Inferencing._
import reporting._
import transform.TypeUtils._
import transform.SymUtils._
import Nullables.{postProcessByNameArgs, given _}
import config.Feature

Expand Down Expand Up @@ -891,7 +892,9 @@ trait Applications extends Compatibility {
case funRef: TermRef =>
val app = ApplyTo(tree, fun1, funRef, proto, pt)
convertNewGenericArray(
postProcessByNameArgs(funRef, app).computeNullable())
widenEnumCase(
postProcessByNameArgs(funRef, app).computeNullable(),
pt))
case _ =>
handleUnexpectedFunType(tree, fun1)
}
Expand Down Expand Up @@ -1091,7 +1094,7 @@ trait Applications extends Compatibility {
* It is performed during typer as creation of generic arrays needs a classTag.
* we rely on implicit search to find one.
*/
def convertNewGenericArray(tree: Tree)(using Context): Tree = tree match {
def convertNewGenericArray(tree: Tree)(using Context): Tree = tree match {
case Apply(TypeApply(tycon, targs@(targ :: Nil)), args) if tycon.symbol == defn.ArrayConstructor =>
fullyDefinedType(tree.tpe, "array", tree.span)

Expand All @@ -1107,6 +1110,28 @@ trait Applications extends Compatibility {
tree
}

/** If `tree` is a complete application of a compiler-generated `apply`
* or `copy` method of an enum case, widen its type to the underlying
* type by means of a type ascription, as long as the widened type is
* still compatible with the expected type.
* The underlying type is the intersection of all class parents of the
* orginal type.
*/
def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree =
val sym = tree.symbol
def isEnumCopy = sym.name == nme.copy && sym.owner.isEnumCase
def isEnumApply = sym.name == nme.apply && sym.owner.linkedClass.isEnumCase
if sym.is(Synthetic) && (isEnumApply || isEnumCopy)
&& tree.tpe.classSymbol.isEnumCase
&& tree.tpe.widen.isValueType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this check needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that we don't widen partially applied applications.

then
val widened = TypeComparer.dropSuperTraits(
tree.tpe.parents.reduceLeft(TypeComparer.andType(_, _)),
pt)
if widened <:< pt then Typed(tree, TypeTree(widened))
else tree
else tree

/** Does `state` contain a "NotAMember" or "MissingIdent" message as
* first pending error message? That message would be
* `$memberName is not a member of ...` or `Not found: $memberName`.
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class ReTyper extends Typer with ReChecking {
override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult =
Implicits.NoMatchingImplicitsFailure
override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = ()

override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree

override protected def addAccessorDefs(cls: Symbol, body: List[Tree])(using Context): List[Tree] = body
override protected def checkEqualityEvidence(tree: tpd.Tree, pt: Type)(using Context): Unit = ()
override protected def matchingApply(methType: MethodOrPoly, pt: FunProto)(using Context): Boolean = true
Expand Down
10 changes: 4 additions & 6 deletions docs/docs/reference/enums/adts.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,13 @@ scala> Option.None
val res2: t2.Option[Nothing] = None
```

Note that the type of the expressions above is always `Option`. That
is, the implementation case classes are not visible in the result
types of their `apply` methods. This is a subtle difference with
respect to normal case classes. The classes making up the cases do
exist, and can be unveiled by constructing them directly with a `new`.
Note that the type of the expressions above is always `Option`. Generally, the type of a enum case constructor application will be widened to the underlying enum type, unless a more specific type is expected. This is a subtle difference with respect to normal case classes. The classes making up the cases do exist, and can be unveiled, either by constructing them directly with a `new`, or by explicitly providing an expected type.

```scala
scala> new Option.Some(2)
val res3: t2.Option.Some[Int] = Some(2)
val res3: Option.Some[Int] = Some(2)
scala> val x: Option.Some[Int] = Option.Some(3)
val res4: Option.Some[Int] = Some(3)
```

As all other enums, ADTs can define methods. For instance, here is `Option` again, with an
Expand Down
13 changes: 9 additions & 4 deletions docs/docs/reference/enums/desugarEnums.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ map into `case class`es or `val`s.
```scala
final case class C <params> extends <parents>
```
However, unlike for a regular case class, the return type of the associated
`apply` method is a fully parameterized type instance of the enum class `E`
itself instead of `C`. Also the enum case defines an `ordinal` method of
the form
The enum case defines an `ordinal` method of the form
```scala
def ordinal = n
```
Expand All @@ -153,6 +150,14 @@ map into `case class`es or `val`s.
in a parameter type in `<params>` or in a type argument of `<parents>`, unless that parameter is already
a type parameter of the case, i.e. the parameter name is defined in `<params>`.

The compiler-generated `apply` and `copy` methods of an enum case
```scala
case C(ps) extends P1, ..., Pn
```
are treated specially. A call `C(ts)` of the apply method is ascribed the underlying type
`P1 & ... & Pn` (dropping any [super traits](../other-new-features/super-traits.html))
as long as that type is still compatible with the expected type at the point of application.
A call `t.copy(ts)` of `C`'s `copy` method is treated in the same way.

### Translation of Enums with Singleton Cases

Expand Down
21 changes: 21 additions & 0 deletions tests/pos/enum-widen.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
object test:

enum Option[+T]:
case Some[T](x: T) extends Option[T]
case None

import Option._

var x = Some(1)
val y: Some[Int] = Some(2)
var xc = y.copy(3)
val yc: Some[Int] = y.copy(3)
x = None
xc = None

enum Nat:
case Z
case S[N <: Z.type | S[_]](pred: N)
import Nat._

val two = S(S(Z))
10 changes: 10 additions & 0 deletions tests/pos/i3935.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
enum Foo3[T](x: T) {
case Bar[S, T](y: T) extends Foo3[y.type](y)
}

val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
val bar = foo

def baz[T](f: Foo3[T]): f.type = f

val qux = baz(bar) // existentials are back in Dotty?
30 changes: 30 additions & 0 deletions tests/run-macros/enum-nat-macro/Macros_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import Nat._

inline def toIntMacro(inline nat: Nat): Int = ${ Macros.toIntImpl('nat) }
inline def ZeroMacro: Zero.type = ${ Macros.natZero }
transparent inline def toNatMacro(inline int: Int): Nat = ${ Macros.toNatImpl('int) }

object Macros:
import quoted._

def toIntImpl(nat: Expr[Nat])(using QuoteContext): Expr[Int] =

def inner(nat: Expr[Nat], acc: Int): Int = nat match
case '{ Succ($nat) } => inner(nat, acc + 1)
case '{ Zero } => acc

Expr(inner(nat, 0))

def natZero(using QuoteContext): Expr[Nat.Zero.type] = '{Zero}

def toNatImpl(int: Expr[Int])(using QuoteContext): Expr[Nat] =

// it seems even with the bound that the arg will always widen to Expr[Nat] unless explicit

def inner[N <: Nat: Type](int: Int, acc: Expr[N]): Expr[Nat] = int match
case 0 => acc
case n => inner[Succ[N]](n - 1, '{Succ($acc)})

val Const(i) = int
require(i >= 0)
inner[Zero.type](i, '{Zero})
Loading