Skip to content

Commit 9e6b993

Browse files
LucySMartinLucy Martin
authored and
Lucy Martin
committed
First pass of combining multiple union types
1 parent a6c40b1 commit 9e6b993

File tree

3 files changed

+114
-7
lines changed

3 files changed

+114
-7
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,20 +3414,85 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
34143414

34153415
def typedType(tree: untpd.Tree, pt: Type = WildcardType, mapPatternBounds: Boolean = false)(using Context): Tree =
34163416
val tree1 = withMode(Mode.Type) { typed(tree, pt) }
3417+
val tree2 = tree1 match
3418+
case inferredTree: InferredTypeTree if inferredTree.hasType =>
3419+
inferredTree.tpe match
3420+
case or: OrType =>
3421+
val res = flattenOr(or)
3422+
tree1.withType(res)
3423+
case _ =>
3424+
tree1
3425+
case _ =>
3426+
tree1
34173427
if mapPatternBounds && ctx.mode.is(Mode.Pattern) && !ctx.isAfterTyper then
3418-
tree1 match
3419-
case tree1: TypeBoundsTree =>
3428+
tree2 match
3429+
case tree2: TypeBoundsTree =>
34203430
// Associate a pattern-bound type symbol with the wildcard.
34213431
// The bounds of the type symbol can be constrained when comparing a pattern type
34223432
// with an expected type in typedTyped. The type symbol and the defining Bind node
34233433
// are eliminated once the enclosing pattern has been typechecked; see `indexPattern`
34243434
// in `typedCase`.
34253435
val boundName = WildcardParamName.fresh().toTypeName
3426-
val wildcardSym = newPatternBoundSymbol(boundName, tree1.tpe & pt, tree.span)
3427-
untpd.Bind(boundName, tree1).withType(wildcardSym.typeRef)
3428-
case tree1 =>
3429-
tree1
3430-
else tree1
3436+
val wildcardSym = newPatternBoundSymbol(boundName, tree2.tpe & pt, tree.span)
3437+
untpd.Bind(boundName, tree2).withType(wildcardSym.typeRef)
3438+
case tree2 =>
3439+
tree2
3440+
else tree2
3441+
3442+
private def flattenOr(tp: Type)(using Context): Type =
3443+
var options: List[Type] = Nil
3444+
var doUpdate: Boolean = false
3445+
3446+
def offer(next: Type): Unit =
3447+
// By checking at insert time, we will never add an element to the internal state if it is invalidated by
3448+
// a later element. Thus as extract time, we only need to validate for those prepended after that point
3449+
next match
3450+
case OrType(o1, o2) =>
3451+
offer(o1)
3452+
offer(o2)
3453+
case _ =>
3454+
if (!options.exists(prior => next <:< prior))
3455+
options = next :: options
3456+
else
3457+
doUpdate = true
3458+
3459+
offer(tp)
3460+
if (doUpdate)
3461+
val typesToAdd = options.reverse.tails.flatMap {
3462+
case curr :: allLaterAdditions
3463+
if !allLaterAdditions.exists(later => curr <:< later) =>
3464+
Some(curr)
3465+
case _ =>
3466+
doUpdate = true
3467+
None
3468+
}
3469+
3470+
def addHelper(add: Type, orTree: List[Option[Type]], iter: Int = 0): List[Option[Type]] =
3471+
orTree match
3472+
case None :: more =>
3473+
var res = Some(add) :: more
3474+
for (i <- 1 to iter) {
3475+
res = None :: res
3476+
}
3477+
res
3478+
case Some(next) :: more =>
3479+
addHelper(add | next, more, iter + 1)
3480+
case Nil =>
3481+
var res: List[Option[Type]] = List(Some(add))
3482+
for (i <- 1 to iter) {
3483+
res = None :: res
3484+
}
3485+
res
3486+
3487+
val res = typesToAdd.foldLeft[List[Option[Type]]](Nil) {
3488+
case (orTree, add) =>
3489+
addHelper(add, orTree)
3490+
}.flatten.reduceLeft(_ | _)
3491+
//println(s"${tp.show} ===> ${res.show} (${ctx.tree.show})")
3492+
res
3493+
else
3494+
tp
3495+
34313496

34323497
def typedPattern(tree: untpd.Tree, selType: Type = WildcardType)(using Context): Tree =
34333498
withMode(Mode.Pattern)(typed(tree, selType))

compiler/test-resources/repl/10693

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
scala> def test[A, B](a: A, b: B): A | B = a
2+
def test[A, B](a: A, b: B): A | B
3+
4+
scala> def d0 = test("string", 1)
5+
def d0: String | Int
6+
7+
scala> def d1 = test(1, "string")
8+
def d1: Int | String
9+
10+
scala> def d2 = test(d0, d1)
11+
def d2: Int | String
12+
13+
scala> def d3 = test(d1, d0)
14+
def d3: String | Int
15+
16+
scala> def d4 = test(d2, d3)
17+
def d4: String | Int
18+
19+
scala> def d5 = test(d3, d2)
20+
def d5: Int | String
21+
22+
scala> def d6 = test(d4, d5)
23+
def d6: Int | String

tests/pos/i10693.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
object Example {
2+
def test[A, B](a: A, b: B): A | B = a
3+
4+
val v0 = test("string", 1)
5+
val v1 = test(1, "string")
6+
val v2 = test(v0, v1)
7+
val v3 = test(v1, v0)
8+
val v4 = test(v2, v3)
9+
val v5 = test(v3, v2)
10+
val v6 = test(v4, v5)
11+
12+
def d0 = test("string", 1)
13+
def d1 = test(1, "string")
14+
def d2 = test(d0, d1)
15+
def d3 = test(d1, d0)
16+
def d4 = test(d2, d3)
17+
def d5 = test(d3, d2)
18+
def d6 = test(d4, d5)
19+
}

0 commit comments

Comments
 (0)