Skip to content

Commit 656304b

Browse files
committed
An alternative scheme for precise apply methods of enums
Keeps many elements from #9922 but the modality where we do the widening is different. The new rule is as follows: In an application of a compiler-generated apply or copy method of an enum case, widen its type to the underlying supertype of the enum case by means of a type ascription, unless the expected type is an enum case itself.
1 parent 09eaed7 commit 656304b

File tree

16 files changed

+165
-71
lines changed

16 files changed

+165
-71
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+2-14
Original file line numberDiff line numberDiff line change
@@ -658,15 +658,6 @@ object desugar {
658658
// For all other classes, the parent is AnyRef.
659659
val companions =
660660
if (isCaseClass) {
661-
// The return type of the `apply` method, and an (empty or singleton) list
662-
// of widening coercions
663-
val (applyResultTpt, widenDefs) =
664-
if (!isEnumCase)
665-
(TypeTree(), Nil)
666-
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
667-
(enumClassTypeRef, Nil)
668-
else
669-
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))
670661

671662
// true if access to the apply method has to be restricted
672663
// i.e. if the case class constructor is either private or qualified private
@@ -697,8 +688,6 @@ object desugar {
697688
then anyRef
698689
else
699690
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
700-
def widenedCreatorExpr =
701-
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
702691
val applyMeths =
703692
if (mods.is(Abstract)) Nil
704693
else {
@@ -711,9 +700,8 @@ object desugar {
711700
val appParamss =
712701
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
713702
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
714-
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
715-
.withMods(appMods)
716-
app :: widenDefs
703+
DefDef(nme.apply, derivedTparams, appParamss, TypeTree(), creatorExpr)
704+
.withMods(appMods) :: Nil
717705
}
718706
val unapplyMeth = {
719707
val hasRepeatedParam = constrVparamss.head.exists {

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

-46
Original file line numberDiff line numberDiff line change
@@ -201,52 +201,6 @@ object DesugarEnums {
201201
TypeTree(), creator).withFlags(Private | Synthetic)
202202
}
203203

204-
/** The return type of an enum case apply method and any widening methods in which
205-
* the apply's right hand side will be wrapped. For parents of the form
206-
*
207-
* extends E(args) with T1(args1) with ... TN(argsN)
208-
*
209-
* and type parameters `tparams` the generated widen method is
210-
*
211-
* def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1
212-
*
213-
* @param cdef The case definition
214-
* @param parents The declared parents of the enum case
215-
* @param tparams The type parameters of the enum case
216-
* @param appliedEnumRef The enum class applied to `tparams`.
217-
*/
218-
def enumApplyResult(
219-
cdef: TypeDef,
220-
parents: List[Tree],
221-
tparams: List[TypeDef],
222-
appliedEnumRef: Tree)(using Context): (Tree, List[DefDef]) = {
223-
224-
def extractType(t: Tree): Tree = t match {
225-
case Apply(t1, _) => extractType(t1)
226-
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
227-
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
228-
case New(t1) => t1
229-
case t1 => t1
230-
}
231-
232-
val parentTypes = parents.map(extractType)
233-
parentTypes.head match {
234-
case parent: RefTree if parent.name == enumClass.name =>
235-
// need a widen method to compute correct type parameters for enum base class
236-
val widenParamType = parentTypes.tail.foldLeft(appliedEnumRef)(makeAndType)
237-
val widenParam = makeSyntheticParameter(tpt = widenParamType)
238-
val widenDef = DefDef(
239-
name = s"${cdef.name}$$to$$${enumClass.name}".toTermName,
240-
tparams = tparams,
241-
vparamss = (widenParam :: Nil) :: Nil,
242-
tpt = TypeTree(),
243-
rhs = Ident(widenParam.name))
244-
(TypeTree(), widenDef :: Nil)
245-
case _ =>
246-
(parentTypes.reduceLeft(makeAndType), Nil)
247-
}
248-
}
249-
250204
/** Is a type parameter in `enumTypeParams` referenced from an enum class case that has
251205
* given type parameters `caseTypeParams`, value parameters `vparamss` and parents `parents`?
252206
* Issues an error if that is the case but the reference is illegal.

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

+3
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ object SymUtils {
160160
def isField(using Context): Boolean =
161161
self.isTerm && !self.is(Method)
162162

163+
def isEnumCase(using Context): Boolean =
164+
self.isAllOf(EnumCase, butNot = JavaDefined)
165+
163166
def annotationsCarrying(meta: ClassSymbol)(using Context): List[Annotation] =
164167
self.annotations.filter(_.symbol.hasAnnotation(meta))
165168

compiler/src/dotty/tools/dotc/typer/Applications.scala

+25-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import ProtoTypes._
2525
import Inferencing._
2626
import reporting._
2727
import transform.TypeUtils._
28+
import transform.SymUtils._
2829
import Nullables.{postProcessByNameArgs, given _}
2930
import config.Feature
3031

@@ -891,7 +892,9 @@ trait Applications extends Compatibility {
891892
case funRef: TermRef =>
892893
val app = ApplyTo(tree, fun1, funRef, proto, pt)
893894
convertNewGenericArray(
894-
postProcessByNameArgs(funRef, app).computeNullable())
895+
widenEnumCase(
896+
postProcessByNameArgs(funRef, app).computeNullable(),
897+
pt))
895898
case _ =>
896899
handleUnexpectedFunType(tree, fun1)
897900
}
@@ -1091,7 +1094,7 @@ trait Applications extends Compatibility {
10911094
* It is performed during typer as creation of generic arrays needs a classTag.
10921095
* we rely on implicit search to find one.
10931096
*/
1094-
def convertNewGenericArray(tree: Tree)(using Context): Tree = tree match {
1097+
def convertNewGenericArray(tree: Tree)(using Context): Tree = tree match {
10951098
case Apply(TypeApply(tycon, targs@(targ :: Nil)), args) if tycon.symbol == defn.ArrayConstructor =>
10961099
fullyDefinedType(tree.tpe, "array", tree.span)
10971100

@@ -1107,6 +1110,26 @@ trait Applications extends Compatibility {
11071110
tree
11081111
}
11091112

1113+
/** If `tree` is a complete application of a compiler-generated `apply`
1114+
* or `copy` method of an enum case, widen its type to the underlying
1115+
* type by means of a type ascription, unless the expected type is an
1116+
* enum case itself.
1117+
* The underlying type is the intersection of all class parents of the
1118+
* orginal type.
1119+
*/
1120+
def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree =
1121+
val sym = tree.symbol
1122+
def isEnumCopy = sym.name == nme.copy && sym.owner.isEnumCase
1123+
def isEnumApply = sym.name == nme.apply && sym.owner.linkedClass.isEnumCase
1124+
if sym.is(Synthetic) && (isEnumApply || isEnumCopy)
1125+
&& tree.tpe.classSymbol.isEnumCase
1126+
&& !pt.isInstanceOf[FunProto]
1127+
&& !pt.classSymbol.isEnumCase
1128+
then
1129+
Typed(tree, TypeTree(tree.tpe.parents.reduceLeft(TypeComparer.andType(_, _))))
1130+
else
1131+
tree
1132+
11101133
/** Does `state` contain a "NotAMember" or "MissingIdent" message as
11111134
* first pending error message? That message would be
11121135
* `$memberName is not a member of ...` or `Not found: $memberName`.

compiler/src/dotty/tools/dotc/typer/ReTyper.scala

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ class ReTyper extends Typer with ReChecking {
132132
override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult =
133133
Implicits.NoMatchingImplicitsFailure
134134
override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = ()
135+
136+
override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree
137+
135138
override protected def addAccessorDefs(cls: Symbol, body: List[Tree])(using Context): List[Tree] = body
136139
override protected def checkEqualityEvidence(tree: tpd.Tree, pt: Type)(using Context): Unit = ()
137140
override protected def matchingApply(methType: MethodOrPoly, pt: FunProto)(using Context): Boolean = true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
scala> enum Maybe[+T] { case Something(value: T); case EmptyValue }; import Maybe._
2+
// defined class Maybe
3+
4+
scala> List(Something(1))
5+
val res0: List[Maybe[Int]] = List(Something(1))
6+
7+
scala> def listOfSomething[O <: Maybe.Something[_]](listOfSomething: List[O]): listOfSomething.type = listOfSomething
8+
def listOfSomething
9+
[O <: Maybe.Something[?]](listOfSomething: List[O]): listOfSomething.type
10+
11+
scala> listOfSomething(List(Something(1)))
12+
val res1: List[Maybe.Something[Int]] = List(Something(1))

tests/patmat/i7186.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -172,16 +172,16 @@ object printMips {
172172

173173
def oneAddr[O]
174174
( a: O, indent: String)
175-
( r: O => Dest,
175+
( r: a.type => Dest,
176176
): String = {
177177
val name = a.getClass.getSimpleName.toLowerCase
178178
s"${indent}$name ${rsrc(r(a))}$endl"
179179
}
180180

181181
def twoAddr[O]
182182
( a: O, indent: String)
183-
( d: O => Register,
184-
r: O => Dest | Constant
183+
( d: a.type => Register,
184+
r: a.type => Dest | Constant
185185
): String = {
186186
val name = a.getClass.getSimpleName.toLowerCase
187187
s"${indent}$name ${registers(d(a))}, ${rsrc(r(a))}$endl"
@@ -190,9 +190,9 @@ object printMips {
190190
def threeAddr[O]
191191
( a: O,
192192
indent: String )
193-
( d: O => Register,
194-
l: O => Register,
195-
r: O => Src
193+
( d: a.type => Register,
194+
l: a.type => Register,
195+
r: a.type => Src
196196
): String = {
197197
val name = a.getClass.getSimpleName.toLowerCase
198198
s"${indent}$name ${registers(d(a))}, ${registers(l(a))}, ${rsrc(r(a))}$endl"

tests/pos/enum-widen.scala

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object test:
2+
3+
enum Option[+T]:
4+
case Some[T](x: T) extends Option[T]
5+
case None
6+
7+
import Option._
8+
9+
var x = Some(1)
10+
val y: Some[Int] = Some(2)
11+
var xc = y.copy(3)
12+
val yc: Some[Int] = y.copy(3)
13+
x = None
14+
xc = None
15+
16+

tests/pos/i3935.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
enum Foo3[T](x: T) {
2+
case Bar[S, T](y: T) extends Foo3[y.type](y)
3+
}
4+
5+
val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
6+
val bar = foo
7+
8+
def baz[T](f: Foo3[T]): f.type = f
9+
10+
val qux = baz(bar) // existentials are back in Dotty?
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import Nat._
2+
3+
inline def toIntMacro(inline nat: Nat): Int = ${ Macros.toIntImpl('nat) }
4+
inline def ZeroMacro: Zero.type = ${ Macros.natZero }
5+
transparent inline def toNatMacro(inline int: Int): Nat = ${ Macros.toNatImpl('int) }
6+
7+
object Macros:
8+
import quoted._
9+
10+
def toIntImpl(nat: Expr[Nat])(using QuoteContext): Expr[Int] =
11+
12+
def inner(nat: Expr[Nat], acc: Int): Int = nat match
13+
case '{ Succ($nat) } => inner(nat, acc + 1)
14+
case '{ Zero } => acc
15+
16+
Expr(inner(nat, 0))
17+
18+
def natZero(using QuoteContext): Expr[Nat.Zero.type] = '{Zero}
19+
20+
def toNatImpl(int: Expr[Int])(using QuoteContext): Expr[Nat] =
21+
22+
// it seems even with the bound that the arg will always widen to Expr[Nat] unless explicit
23+
24+
def inner[N <: Nat: Type](int: Int, acc: Expr[N]): Expr[Nat] = int match
25+
case 0 => acc
26+
case n => inner[Succ[N]](n - 1, '{Succ($acc)})
27+
28+
val Const(i) = int
29+
require(i >= 0)
30+
inner[Zero.type](i, '{Zero})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
enum Nat:
2+
case Zero
3+
case Succ[N <: Nat](n: N)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import Nat._
2+
3+
@main def Test: Unit =
4+
assert(toIntMacro(Succ(Succ(Succ(Zero)))) == 3)
5+
assert(toNatMacro(3) == Succ(Succ(Succ(Zero))))
6+
val zero: Zero.type = ZeroMacro
7+
assert(zero == Zero)
8+
assert(toIntMacro(toNatMacro(3)) == 3)
9+
val n: Succ[Succ[Succ[Zero.type]]] = toNatMacro(3)

tests/run-macros/i8007.check

+2
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,7 @@ true
1111

1212
true
1313

14+
true
15+
1416
false
1517

tests/run-macros/i8007/Macro_3.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ object Eq {
7373
}
7474

7575
object Macro3 {
76-
extension [T](x: =>T) inline def === (y: =>T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)
76+
extension [T](inline x: T) inline def === (inline y: T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)
7777

7878
implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
79-
}
79+
}

tests/run-macros/i8007/Test_4.scala

+11-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import Macro3.eqGen
66
case class Person(name: String, age: Int)
77

88
enum Opt[+T] {
9-
case Sm[T](t: T) extends Opt[T]
9+
case Sm(t: T)
10+
case Nn
11+
}
12+
13+
enum OptInv[+T] {
14+
case Sm[T](t: T) extends OptInv[T]
1015
case Nn
1116
}
1217

@@ -34,6 +39,11 @@ enum Opt[+T] {
3439
println(t5) // true
3540
println
3641

42+
// Here invariant case without explicit type parameter will instantiate T to OptInv[Any]
43+
val t5_2 = OptInv.Sm[Int](23) === OptInv.Sm(23)
44+
println(t5) // true
45+
println
46+
3747
val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23))
3848
println(t6) // true
3949
println

tests/run/enums-precise.scala

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
enum NonEmptyList[+T]:
2+
case Many[+U](head: U, tail: NonEmptyList[U]) extends NonEmptyList[U]
3+
case One [+U](value: U) extends NonEmptyList[U]
4+
5+
enum Ast:
6+
case Binding(name: String, tpe: String)
7+
case Lambda(args: NonEmptyList[Binding], rhs: Ast) // reference to another case of the enum
8+
case Ident(name: String)
9+
case Apply(fn: Ast, args: NonEmptyList[Ast])
10+
11+
import NonEmptyList._
12+
import Ast._
13+
14+
// This example showcases the widening when inferring enum case types.
15+
// With scala 2 case class hierarchies, if One.apply(1) returns One[Int] and Many.apply(2, One(3)) returns Many[Int]
16+
// then the `foldRight` expression below would complain that Many[Binding] is not One[Binding]. With Scala 3 enums,
17+
// .apply on the companion returns the precise class, but type inference will widen to NonEmptyList[Binding] unless
18+
// the precise class is expected.
19+
def Bindings(arg: (String, String), args: (String, String)*): NonEmptyList[Binding] =
20+
def Bind(arg: (String, String)): Binding =
21+
val (name, tpe) = arg
22+
Binding(name, tpe)
23+
24+
args.foldRight(One[Binding](Bind(arg)))((arg, acc) => Many(Bind(arg), acc))
25+
26+
@main def Test: Unit =
27+
val OneOfOne: One[1] = One[1](1)
28+
val True = Lambda(Bindings("x" -> "T", "y" -> "T"), Ident("x"))
29+
val Const = Lambda(One(Binding("x", "T")), Lambda(One(Binding("y", "U")), Ident("x"))) // precise type is forwarded
30+
31+
assert(OneOfOne.value == 1)

0 commit comments

Comments
 (0)