Skip to content

Commit 6fce4ca

Browse files
committed
add version of Nat to int using enums
1 parent 4ec5446 commit 6fce4ca

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,11 @@ trait ConstraintHandling {
354354
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
355355
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)
356356

357+
def isEnumBound(tp: Type): Boolean =
358+
bound.isValueType && bound.typeSymbol.is(Enum, butNot=JavaDefined) && tp <:< bound
359+
357360
val wideInst =
358-
if isSingleton(bound) then inst
361+
if isSingleton(bound) || isEnumBound(bound) then inst
359362
else dropSuperTraits(widenOr(widenSingle(inst)))
360363
wideInst match
361364
case wideInst: TypeRef if wideInst.symbol.is(Module) =>

tests/neg/enum-widen.scala

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import Nat._
2+
3+
enum Nat:
4+
case Zero
5+
case Succ[N <: Nat](n: N)
6+
7+
enum MyEnum:
8+
case A[E <: Enum](e: E)
9+
10+
final case class Foo[T](t: T)
11+
12+
inline def willNotReduce1 = inline Foo(Zero) match // assert that enums are widened when the bound is not a parent enum type
13+
case Foo(Zero) => ()
14+
15+
inline def willNotReduce2 = inline MyEnum.A(Zero) match // assert that enums are only narrowed when bound is own enum type
16+
case MyEnum.A(Zero) => ()
17+
18+
val foo = willNotReduce1 // error: cannot reduce inline match with scrutinee: Foo.apply[Nat](Nat$#Zero) : Foo[Nat]
19+
val bar = willNotReduce2 // error: cannot reduce inline match with scrutinee: MyEnum.A.apply[Nat](Nat$#Zero): MyEnum.A[Nat]

tests/run/enum-nat.scala

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import Nat._
2+
import compiletime._
3+
4+
enum Nat:
5+
case Zero
6+
case Succ[N <: Nat](n: N)
7+
8+
inline def toIntTypeLevel[N <: Nat]: Int = inline erasedValue[N] match
9+
case _: Zero.type => 0
10+
case _: Succ[n] => toIntTypeLevel[n] + 1
11+
12+
inline def toInt(inline nat: Nat): Int = inline nat match
13+
case nat: Zero.type => 0
14+
case nat: Succ[n] => toInt(nat.n) + 1
15+
16+
inline def toIntTypeTailRec[N <: Nat, Acc <: Int]: Int = inline erasedValue[N] match
17+
case _: Zero.type => constValue[Acc]
18+
case _: Succ[n] => toIntTypeTailRec[n, S[Acc]]
19+
20+
inline def toIntErased[N <: Nat](inline nat: N): Int = toIntTypeTailRec[N, 0]
21+
22+
@main def Test: Unit =
23+
assert(toIntTypeLevel[Succ[Succ[Zero.type]]] == 2)
24+
assert(toInt(Succ(Succ(Zero))) == 2)
25+
assert(toIntErased(Succ(Succ(Zero))) == 2)

0 commit comments

Comments
 (0)