Skip to content

Commit bd54027

Browse files
Merge pull request #6319 from dotty-staging/fix-6314
Fix #6314: match type unsoundness
2 parents 8ffed08 + 74ca923 commit bd54027

17 files changed

+305
-65
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,6 @@ class Definitions {
564564
lazy val BoxedUnitModule: TermSymbol = ctx.requiredModule("java.lang.Void")
565565

566566
lazy val ByNameParamClass2x: ClassSymbol = enterSpecialPolyClass(tpnme.BYNAME_PARAM_CLASS, Covariant, Seq(AnyType))
567-
lazy val EqualsPatternClass: ClassSymbol = enterSpecialPolyClass(tpnme.EQUALS_PATTERN, EmptyFlags, Seq(AnyType))
568567

569568
lazy val RepeatedParamClass: ClassSymbol = enterSpecialPolyClass(tpnme.REPEATED_PARAM_CLASS, Covariant, Seq(ObjectType, SeqType))
570569

@@ -1385,8 +1384,7 @@ class Definitions {
13851384
AnyValClass,
13861385
NullClass,
13871386
NothingClass,
1388-
SingletonClass,
1389-
EqualsPatternClass)
1387+
SingletonClass)
13901388

13911389
lazy val syntheticCoreClasses: List[Symbol] = syntheticScalaClasses ++ List(
13921390
EmptyPackageVal,

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ class TypeApplications(val self: Type) extends AnyVal {
405405
case dealiased: LazyRef =>
406406
LazyRef(c => dealiased.ref(c).appliedTo(args))
407407
case dealiased: WildcardType =>
408-
WildcardType(dealiased.optBounds.appliedTo(args).bounds)
408+
WildcardType(dealiased.optBounds.orElse(TypeBounds.empty).appliedTo(args).bounds)
409409
case dealiased: TypeRef if dealiased.symbol == defn.NothingClass =>
410410
dealiased
411411
case dealiased =>

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+61-50
Original file line numberDiff line numberDiff line change
@@ -2057,9 +2057,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
20572057
/** Returns last check's debug mode, if explicitly enabled. */
20582058
def lastTrace(): String = ""
20592059

2060-
/** Do `tp1` and `tp2` share a non-null inhabitant?
2060+
/** Are `tp1` and `tp2` disjoint types?
20612061
*
2062-
* `false` implies that we found a proof; uncertainty default to `true`.
2062+
* `true` implies that we found a proof; uncertainty default to `false`.
20632063
*
20642064
* Proofs rely on the following properties of Scala types:
20652065
*
@@ -2068,8 +2068,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
20682068
* 3. ConstantTypes with distinc values are non intersecting
20692069
* 4. There is no value of type Nothing
20702070
*/
2071-
def intersecting(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = {
2072-
// println(s"intersecting(${tp1.show}, ${tp2.show})")
2071+
def disjoint(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = {
2072+
// println(s"disjoint(${tp1.show}, ${tp2.show})")
20732073
/** Can we enumerate all instantiations of this type? */
20742074
def isClosedSum(tp: Symbol): Boolean =
20752075
tp.is(Sealed) && tp.is(AbstractOrTrait) && !tp.hasAnonymousChild
@@ -2082,46 +2082,35 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
20822082
sym.children.map(x => ctx.refineUsingParent(tp, x)).filter(_.exists)
20832083

20842084
(tp1.dealias, tp2.dealias) match {
2085+
case (tp1: TypeRef, tp2: TypeRef) if tp1.symbol == defn.SingletonClass || tp2.symbol == defn.SingletonClass =>
2086+
false
20852087
case (tp1: ConstantType, tp2: ConstantType) =>
2086-
tp1 == tp2
2088+
tp1 != tp2
20872089
case (tp1: TypeRef, tp2: TypeRef) if tp1.symbol.isClass && tp2.symbol.isClass =>
20882090
val cls1 = tp1.classSymbol
20892091
val cls2 = tp2.classSymbol
20902092
if (cls1.derivesFrom(cls2) || cls2.derivesFrom(cls1)) {
2091-
true
2093+
false
20922094
} else {
20932095
if (cls1.is(Final) || cls2.is(Final))
20942096
// One of these types is final and they are not mutually
20952097
// subtype, so they must be unrelated.
2096-
false
2098+
true
20972099
else if (!cls2.is(Trait) && !cls1.is(Trait))
20982100
// Both of these types are classes and they are not mutually
20992101
// subtype, so they must be unrelated by single inheritance
21002102
// of classes.
2101-
false
2103+
true
21022104
else if (isClosedSum(cls1))
2103-
decompose(cls1, tp1).exists(x => intersecting(x, tp2))
2105+
decompose(cls1, tp1).forall(x => disjoint(x, tp2))
21042106
else if (isClosedSum(cls2))
2105-
decompose(cls2, tp2).exists(x => intersecting(x, tp1))
2107+
decompose(cls2, tp2).forall(x => disjoint(x, tp1))
21062108
else
2107-
true
2109+
false
21082110
}
21092111
case (AppliedType(tycon1, args1), AppliedType(tycon2, args2)) if tycon1 == tycon2 =>
2110-
// Unboxed xs.zip(ys).zip(zs).forall { case ((a, b), c) => f(a, b, c) }
2111-
def zip_zip_forall[A, B, C](xs: List[A], ys: List[B], zs: List[C])(f: (A, B, C) => Boolean): Boolean = {
2112-
xs match {
2113-
case x :: xs => ys match {
2114-
case y :: ys => zs match {
2115-
case z :: zs => f(x, y, z) && zip_zip_forall(xs, ys, zs)(f)
2116-
case _ => true
2117-
}
2118-
case _ => true
2119-
}
2120-
case _ => true
2121-
}
2122-
}
2123-
def covariantIntersecting(tp1: Type, tp2: Type, tparam: TypeParamInfo): Boolean = {
2124-
intersecting(tp1, tp2) || {
2112+
def covariantDisjoint(tp1: Type, tp2: Type, tparam: TypeParamInfo): Boolean = {
2113+
disjoint(tp1, tp2) && {
21252114
// We still need to proof that `Nothing` is not a valid
21262115
// instantiation of this type parameter. We have two ways
21272116
// to get to that conclusion:
@@ -2139,24 +2128,24 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
21392128
case _ =>
21402129
false
21412130
}
2142-
lowerBoundedByNothing && !typeUsedAsField
2131+
!lowerBoundedByNothing || typeUsedAsField
21432132
}
21442133
}
21452134

2146-
zip_zip_forall(args1, args2, tycon1.typeParams) {
2135+
(args1, args2, tycon1.typeParams).zipped.exists {
21472136
(arg1, arg2, tparam) =>
21482137
val v = tparam.paramVariance
21492138
if (v > 0)
2150-
covariantIntersecting(arg1, arg2, tparam)
2139+
covariantDisjoint(arg1, arg2, tparam)
21512140
else if (v < 0)
21522141
// Contravariant case: a value where this type parameter is
21532142
// instantiated to `Any` belongs to both types.
2154-
true
2143+
false
21552144
else
2156-
covariantIntersecting(arg1, arg2, tparam) && (isSameType(arg1, arg2) || {
2145+
covariantDisjoint(arg1, arg2, tparam) || !isSameType(arg1, arg2) && {
21572146
// We can only trust a "no" from `isSameType` when both
21582147
// `arg1` and `arg2` are fully instantiated.
2159-
val fullyInstantiated = new TypeAccumulator[Boolean] {
2148+
def fullyInstantiated(tp: Type): Boolean = new TypeAccumulator[Boolean] {
21602149
override def apply(x: Boolean, t: Type) =
21612150
x && {
21622151
t match {
@@ -2165,34 +2154,36 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
21652154
case _ => foldOver(x, t)
21662155
}
21672156
}
2168-
}
2169-
!(fullyInstantiated.apply(true, arg1) &&
2170-
fullyInstantiated.apply(true, arg2))
2171-
})
2157+
}.apply(true, tp)
2158+
fullyInstantiated(arg1) && fullyInstantiated(arg2)
2159+
}
21722160
}
21732161
case (tp1: HKLambda, tp2: HKLambda) =>
2174-
intersecting(tp1.resType, tp2.resType)
2162+
disjoint(tp1.resType, tp2.resType)
21752163
case (_: HKLambda, _) =>
2176-
// The intersection is ill kinded and therefore empty.
2177-
false
2164+
// The intersection of these two types would be ill kinded, they are therefore disjoint.
2165+
true
21782166
case (_, _: HKLambda) =>
2179-
false
2167+
true
21802168
case (tp1: OrType, _) =>
2181-
intersecting(tp1.tp1, tp2) || intersecting(tp1.tp2, tp2)
2169+
disjoint(tp1.tp1, tp2) && disjoint(tp1.tp2, tp2)
21822170
case (_, tp2: OrType) =>
2183-
intersecting(tp1, tp2.tp1) || intersecting(tp1, tp2.tp2)
2171+
disjoint(tp1, tp2.tp1) && disjoint(tp1, tp2.tp2)
2172+
case (tp1: AndType, tp2: AndType) =>
2173+
(disjoint(tp1.tp1, tp2.tp1) || disjoint(tp1.tp2, tp2.tp2)) &&
2174+
(disjoint(tp1.tp1, tp2.tp2) || disjoint(tp1.tp2, tp2.tp1))
21842175
case (tp1: AndType, _) =>
2185-
intersecting(tp1.tp1, tp2) && intersecting(tp1.tp2, tp2) && intersecting(tp1.tp1, tp1.tp2)
2176+
disjoint(tp1.tp2, tp2) || disjoint(tp1.tp1, tp2)
21862177
case (_, tp2: AndType) =>
2187-
intersecting(tp1, tp2.tp1) && intersecting(tp1, tp2.tp2) && intersecting(tp2.tp1, tp2.tp2)
2178+
disjoint(tp1, tp2.tp2) || disjoint(tp1, tp2.tp1)
21882179
case (tp1: TypeProxy, tp2: TypeProxy) =>
2189-
intersecting(tp1.underlying, tp2) && intersecting(tp1, tp2.underlying)
2180+
disjoint(tp1.underlying, tp2) || disjoint(tp1, tp2.underlying)
21902181
case (tp1: TypeProxy, _) =>
2191-
intersecting(tp1.underlying, tp2)
2182+
disjoint(tp1.underlying, tp2)
21922183
case (_, tp2: TypeProxy) =>
2193-
intersecting(tp1, tp2.underlying)
2184+
disjoint(tp1, tp2.underlying)
21942185
case _ =>
2195-
true
2186+
false
21962187
}
21972188
}
21982189
}
@@ -2321,6 +2312,24 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
23212312
case _ =>
23222313
cas
23232314
}
2315+
def widenAbstractTypes(tp: Type): Type = new TypeMap {
2316+
def apply(tp: Type) = tp match {
2317+
case tp: TypeRef =>
2318+
if (tp.symbol.isAbstractOrParamType | tp.symbol.isOpaqueAlias)
2319+
WildcardType
2320+
else tp.info match {
2321+
case TypeAlias(alias) =>
2322+
val alias1 = widenAbstractTypes(alias)
2323+
if (alias1 ne alias) alias1 else tp
2324+
case _ => mapOver(tp)
2325+
}
2326+
2327+
case tp: TypeVar if !tp.isInstantiated => WildcardType
2328+
case _: TypeParamRef => WildcardType
2329+
case _ => mapOver(tp)
2330+
}
2331+
}.apply(tp)
2332+
23242333
val defn.MatchCase(pat, body) = cas1
23252334
if (isSubType(scrut, pat))
23262335
// `scrut` is a subtype of `pat`: *It's a Match!*
@@ -2333,12 +2342,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
23332342
body
23342343
}
23352344
}
2336-
else if (intersecting(scrut, pat))
2345+
else if (isSubType(widenAbstractTypes(scrut), widenAbstractTypes(pat)))
23372346
Some(NoType)
2338-
else
2347+
else if (disjoint(scrut, pat))
23392348
// We found a proof that `scrut` and `pat` are incompatible.
23402349
// The search continues.
23412350
None
2351+
else
2352+
Some(NoType)
23422353
}
23432354

23442355
def recur(cases: List[Type]): Type = cases match {

compiler/src/dotty/tools/dotc/core/Types.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -3913,8 +3913,7 @@ object Types {
39133913
myReduced =
39143914
trace(i"reduce match type $this $hashCode", typr, show = true) {
39153915
try
3916-
if (defn.isBottomType(scrutinee)) defn.NothingType
3917-
else typeComparer.matchCases(scrutinee, cases)(trackingCtx)
3916+
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
39183917
catch {
39193918
case ex: Throwable =>
39203919
handleRecursive("reduce type ", i"$scrutinee match ...", ex)

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,11 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
313313
// Since projections of types don't include null, intersection with null is empty.
314314
return Empty
315315
}
316-
val res = ctx.typeComparer.intersecting(tp1, tp2)
316+
val res = ctx.typeComparer.disjoint(tp1, tp2)
317317

318-
debug.println(s"atomic intersection: ${AndType(tp1, tp2).show} = ${res}")
318+
debug.println(s"atomic intersection: ${AndType(tp1, tp2).show} = ${!res}")
319319

320-
if (!res) Empty
320+
if (res) Empty
321321
else if (tp1.isSingleton) Typ(tp1, true)
322322
else if (tp2.isSingleton) Typ(tp2, true)
323323
else Typ(AndType(tp1, tp2), true)
@@ -503,7 +503,7 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
503503

504504
def inhabited(tp: Type): Boolean =
505505
tp.dealias match {
506-
case AndType(tp1, tp2) => ctx.typeComparer.intersecting(tp1, tp2)
506+
case AndType(tp1, tp2) => !ctx.typeComparer.disjoint(tp1, tp2)
507507
case OrType(tp1, tp2) => inhabited(tp1) || inhabited(tp2)
508508
case tp: RefinedType => inhabited(tp.parent)
509509
case tp: TypeRef => inhabited(tp.prefix)

library/src-3.x/scala/compiletime/package.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ package object compiletime {
1111
inline def constValue[T]: T = ???
1212

1313
type S[X <: Int] <: Int
14-
}
14+
}

tests/neg/6314-1.scala

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object G {
2+
final class X
3+
final class Y
4+
5+
trait FooSig {
6+
type Type
7+
def apply[F[_]](fa: F[X & Y]): F[Y & Type]
8+
}
9+
val Foo: FooSig = new FooSig {
10+
type Type = X & Y
11+
def apply[F[_]](fa: F[X & Y]): F[Y & Type] = fa
12+
}
13+
type Foo = Foo.Type
14+
15+
type Bar[A] = A match {
16+
case X & Y => String
17+
case Y => Int
18+
}
19+
20+
def main(args: Array[String]): Unit = {
21+
val a: Bar[X & Y] = "hello"
22+
val i: Bar[Y & Foo] = Foo.apply[Bar](a)
23+
val b: Int = i // error
24+
println(b + 1)
25+
}
26+
}

tests/neg/6314-2.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
object G {
2+
final class X
3+
final class Y
4+
5+
opaque type Foo = Nothing // or X & Y
6+
object Foo {
7+
def apply[F[_]](fa: F[X & Foo]): F[Y & Foo] = fa
8+
}
9+
10+
type Bar[A] = A match {
11+
case X => String
12+
case Y => Int
13+
}
14+
15+
val a: Bar[X & Foo] = "hello"
16+
val b: Bar[Y & Foo] = 1 // error
17+
18+
def main(args: Array[String]): Unit = {
19+
val a: Bar[X & Foo] = "hello"
20+
val i: Bar[Y & Foo] = Foo.apply[Bar](a)
21+
val b: Int = i // error
22+
println(b + 1)
23+
}
24+
}

tests/neg/6314-3.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
object G {
2+
trait Wizzle[L <: Int with Singleton] {
3+
type Bar[A] = A match {
4+
case 0 => String
5+
case L => Int
6+
}
7+
8+
def left(fa: String): Bar[0] = fa
9+
def right(fa: Bar[L]): Int = fa // error
10+
11+
def center[F[_]](fa: F[0]): F[L]
12+
13+
def run: String => Int = left andThen center[Bar] andThen right
14+
}
15+
16+
class Wozzle extends Wizzle[0] {
17+
def center[F[_]](fa: F[0]): F[0] = fa
18+
}
19+
20+
def main(args: Array[String]): Unit = {
21+
val coerce: String => Int = (new Wozzle).run
22+
println(coerce("hello") + 1)
23+
}
24+
}

tests/neg/6314-4.scala

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
object G {
2+
trait Wizzle {
3+
type X <: Int with Singleton
4+
type Y <: Int with Singleton
5+
6+
type Bar[A] = A match {
7+
case X => String
8+
case Y => Int
9+
}
10+
11+
def left(fa: String): Bar[X] = fa
12+
def center[F[_]](fa: F[X]): F[Y]
13+
def right(fa: Bar[Y]): Int = fa // error
14+
15+
def run: String => Int = left andThen center[Bar] andThen right
16+
}
17+
18+
class Wozzle extends Wizzle {
19+
type X = 0
20+
type Y = 0
21+
def center[F[_]](fa: F[X]): F[Y] = fa
22+
}
23+
24+
def main(args: Array[String]): Unit = {
25+
val coerce: String => Int = (new Wozzle).run
26+
println(coerce("hello") + 1)
27+
}
28+
}

0 commit comments

Comments
 (0)