Skip to content

Fix #6314: match type unsoundness #6319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 17, 2019
4 changes: 1 addition & 3 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,6 @@ class Definitions {
lazy val BoxedUnitModule: TermSymbol = ctx.requiredModule("java.lang.Void")

lazy val ByNameParamClass2x: ClassSymbol = enterSpecialPolyClass(tpnme.BYNAME_PARAM_CLASS, Covariant, Seq(AnyType))
lazy val EqualsPatternClass: ClassSymbol = enterSpecialPolyClass(tpnme.EQUALS_PATTERN, EmptyFlags, Seq(AnyType))

lazy val RepeatedParamClass: ClassSymbol = enterSpecialPolyClass(tpnme.REPEATED_PARAM_CLASS, Covariant, Seq(ObjectType, SeqType))

Expand Down Expand Up @@ -1375,8 +1374,7 @@ class Definitions {
AnyValClass,
NullClass,
NothingClass,
SingletonClass,
EqualsPatternClass)
SingletonClass)

lazy val syntheticCoreClasses: List[Symbol] = syntheticScalaClasses ++ List(
EmptyPackageVal,
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ class TypeApplications(val self: Type) extends AnyVal {
case dealiased: LazyRef =>
LazyRef(c => dealiased.ref(c).appliedTo(args))
case dealiased: WildcardType =>
WildcardType(dealiased.optBounds.appliedTo(args).bounds)
WildcardType(dealiased.optBounds.orElse(TypeBounds.empty).appliedTo(args).bounds)
case dealiased: TypeRef if dealiased.symbol == defn.NothingClass =>
dealiased
case dealiased =>
Expand Down
111 changes: 61 additions & 50 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1895,9 +1895,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
/** Returns last check's debug mode, if explicitly enabled. */
def lastTrace(): String = ""

/** Do `tp1` and `tp2` share a non-null inhabitant?
/** Are `tp1` and `tp2` disjoint types?
*
* `false` implies that we found a proof; uncertainty default to `true`.
* `true` implies that we found a proof; uncertainty default to `false`.
*
* Proofs rely on the following properties of Scala types:
*
Expand All @@ -1906,8 +1906,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
* 3. ConstantTypes with distinc values are non intersecting
* 4. There is no value of type Nothing
*/
def intersecting(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = {
// println(s"intersecting(${tp1.show}, ${tp2.show})")
def disjoint(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = {
// println(s"disjoint(${tp1.show}, ${tp2.show})")
/** Can we enumerate all instantiations of this type? */
def isClosedSum(tp: Symbol): Boolean =
tp.is(Sealed) && tp.is(AbstractOrTrait) && !tp.hasAnonymousChild
Expand All @@ -1920,46 +1920,35 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
sym.children.map(x => ctx.refineUsingParent(tp, x)).filter(_.exists)

(tp1.dealias, tp2.dealias) match {
case (tp1: TypeRef, tp2: TypeRef) if tp1.symbol == defn.SingletonClass || tp2.symbol == defn.SingletonClass =>
false
case (tp1: ConstantType, tp2: ConstantType) =>
tp1 == tp2
tp1 != tp2
case (tp1: TypeRef, tp2: TypeRef) if tp1.symbol.isClass && tp2.symbol.isClass =>
val cls1 = tp1.classSymbol
val cls2 = tp2.classSymbol
if (cls1.derivesFrom(cls2) || cls2.derivesFrom(cls1)) {
true
false
} else {
if (cls1.is(Final) || cls2.is(Final))
// One of these types is final and they are not mutually
// subtype, so they must be unrelated.
false
true
else if (!cls2.is(Trait) && !cls1.is(Trait))
// Both of these types are classes and they are not mutually
// subtype, so they must be unrelated by single inheritance
// of classes.
false
true
else if (isClosedSum(cls1))
decompose(cls1, tp1).exists(x => intersecting(x, tp2))
decompose(cls1, tp1).forall(x => disjoint(x, tp2))
else if (isClosedSum(cls2))
decompose(cls2, tp2).exists(x => intersecting(x, tp1))
decompose(cls2, tp2).forall(x => disjoint(x, tp1))
else
true
false
}
case (AppliedType(tycon1, args1), AppliedType(tycon2, args2)) if tycon1 == tycon2 =>
// Unboxed xs.zip(ys).zip(zs).forall { case ((a, b), c) => f(a, b, c) }
def zip_zip_forall[A, B, C](xs: List[A], ys: List[B], zs: List[C])(f: (A, B, C) => Boolean): Boolean = {
xs match {
case x :: xs => ys match {
case y :: ys => zs match {
case z :: zs => f(x, y, z) && zip_zip_forall(xs, ys, zs)(f)
case _ => true
}
case _ => true
}
case _ => true
}
}
def covariantIntersecting(tp1: Type, tp2: Type, tparam: TypeParamInfo): Boolean = {
intersecting(tp1, tp2) || {
def covariantDisjoint(tp1: Type, tp2: Type, tparam: TypeParamInfo): Boolean = {
disjoint(tp1, tp2) && {
// We still need to proof that `Nothing` is not a valid
// instantiation of this type parameter. We have two ways
// to get to that conclusion:
Expand All @@ -1977,24 +1966,24 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
case _ =>
false
}
lowerBoundedByNothing && !typeUsedAsField
!lowerBoundedByNothing || typeUsedAsField
}
}

zip_zip_forall(args1, args2, tycon1.typeParams) {
(args1, args2, tycon1.typeParams).zipped.exists {
(arg1, arg2, tparam) =>
val v = tparam.paramVariance
if (v > 0)
covariantIntersecting(arg1, arg2, tparam)
covariantDisjoint(arg1, arg2, tparam)
else if (v < 0)
// Contravariant case: a value where this type parameter is
// instantiated to `Any` belongs to both types.
true
false
else
covariantIntersecting(arg1, arg2, tparam) && (isSameType(arg1, arg2) || {
covariantDisjoint(arg1, arg2, tparam) || !isSameType(arg1, arg2) && {
// We can only trust a "no" from `isSameType` when both
// `arg1` and `arg2` are fully instantiated.
val fullyInstantiated = new TypeAccumulator[Boolean] {
def fullyInstantiated(tp: Type): Boolean = new TypeAccumulator[Boolean] {
override def apply(x: Boolean, t: Type) =
x && {
t match {
Expand All @@ -2003,34 +1992,36 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
case _ => foldOver(x, t)
}
}
}
!(fullyInstantiated.apply(true, arg1) &&
fullyInstantiated.apply(true, arg2))
})
}.apply(true, tp)
fullyInstantiated(arg1) && fullyInstantiated(arg2)
}
}
case (tp1: HKLambda, tp2: HKLambda) =>
intersecting(tp1.resType, tp2.resType)
disjoint(tp1.resType, tp2.resType)
case (_: HKLambda, _) =>
// The intersection is ill kinded and therefore empty.
false
// The intersection of these two types would be ill kinded, they are therefore disjoint.
true
case (_, _: HKLambda) =>
false
true
case (tp1: OrType, _) =>
intersecting(tp1.tp1, tp2) || intersecting(tp1.tp2, tp2)
disjoint(tp1.tp1, tp2) && disjoint(tp1.tp2, tp2)
case (_, tp2: OrType) =>
intersecting(tp1, tp2.tp1) || intersecting(tp1, tp2.tp2)
disjoint(tp1, tp2.tp1) && disjoint(tp1, tp2.tp2)
case (tp1: AndType, tp2: AndType) =>
(disjoint(tp1.tp1, tp2.tp1) || disjoint(tp1.tp2, tp2.tp2)) &&
(disjoint(tp1.tp1, tp2.tp2) || disjoint(tp1.tp2, tp2.tp1))
case (tp1: AndType, _) =>
intersecting(tp1.tp1, tp2) && intersecting(tp1.tp2, tp2) && intersecting(tp1.tp1, tp1.tp2)
disjoint(tp1.tp2, tp2) || disjoint(tp1.tp1, tp2)
case (_, tp2: AndType) =>
intersecting(tp1, tp2.tp1) && intersecting(tp1, tp2.tp2) && intersecting(tp2.tp1, tp2.tp2)
disjoint(tp1, tp2.tp2) || disjoint(tp1, tp2.tp1)
case (tp1: TypeProxy, tp2: TypeProxy) =>
intersecting(tp1.underlying, tp2) && intersecting(tp1, tp2.underlying)
disjoint(tp1.underlying, tp2) || disjoint(tp1, tp2.underlying)
case (tp1: TypeProxy, _) =>
intersecting(tp1.underlying, tp2)
disjoint(tp1.underlying, tp2)
case (_, tp2: TypeProxy) =>
intersecting(tp1, tp2.underlying)
disjoint(tp1, tp2.underlying)
case _ =>
true
false
}
}
}
Expand Down Expand Up @@ -2159,6 +2150,24 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
case _ =>
cas
}
def widenAbstractTypes(tp: Type): Type = new TypeMap {
def apply(tp: Type) = tp match {
case tp: TypeRef =>
if (tp.symbol.isAbstractOrParamType | tp.symbol.isOpaqueAlias)
WildcardType
else tp.info match {
case TypeAlias(alias) =>
val alias1 = widenAbstractTypes(alias)
if (alias1 ne alias) alias1 else tp
case _ => mapOver(tp)
}

case tp: TypeVar if !tp.isInstantiated => WildcardType
case _: TypeParamRef => WildcardType
case _ => mapOver(tp)
}
}.apply(tp)

val defn.MatchCase(pat, body) = cas1
if (isSubType(scrut, pat))
// `scrut` is a subtype of `pat`: *It's a Match!*
Expand All @@ -2171,12 +2180,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
body
}
}
else if (intersecting(scrut, pat))
else if (isSubType(widenAbstractTypes(scrut), widenAbstractTypes(pat)))
Some(NoType)
else
else if (disjoint(scrut, pat))
// We found a proof that `scrut` and `pat` are incompatible.
// The search continues.
None
else
Some(NoType)
}

def recur(cases: List[Type]): Type = cases match {
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3823,8 +3823,7 @@ object Types {
myReduced =
trace(i"reduce match type $this $hashCode", typr, show = true) {
try
if (defn.isBottomType(scrutinee)) defn.NothingType
else typeComparer.matchCases(scrutinee, cases)(trackingCtx)
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
catch {
case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,11 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
// Since projections of types don't include null, intersection with null is empty.
return Empty
}
val res = ctx.typeComparer.intersecting(tp1, tp2)
val res = ctx.typeComparer.disjoint(tp1, tp2)

debug.println(s"atomic intersection: ${AndType(tp1, tp2).show} = ${res}")
debug.println(s"atomic intersection: ${AndType(tp1, tp2).show} = ${!res}")

if (!res) Empty
if (res) Empty
else if (tp1.isSingleton) Typ(tp1, true)
else if (tp2.isSingleton) Typ(tp2, true)
else Typ(AndType(tp1, tp2), true)
Expand Down Expand Up @@ -498,7 +498,7 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {

def inhabited(tp: Type): Boolean =
tp.dealias match {
case AndType(tp1, tp2) => ctx.typeComparer.intersecting(tp1, tp2)
case AndType(tp1, tp2) => !ctx.typeComparer.disjoint(tp1, tp2)
case OrType(tp1, tp2) => inhabited(tp1) || inhabited(tp2)
case tp: RefinedType => inhabited(tp.parent)
case tp: TypeRef => inhabited(tp.prefix)
Expand Down
2 changes: 1 addition & 1 deletion library/src-3.x/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ package object compiletime {
inline def constValue[T]: T = ???

type S[X <: Int] <: Int
}
}
26 changes: 26 additions & 0 deletions tests/neg/6314-1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
object G {
final class X
final class Y

trait FooSig {
type Type
def apply[F[_]](fa: F[X & Y]): F[Y & Type]
}
val Foo: FooSig = new FooSig {
type Type = X & Y
def apply[F[_]](fa: F[X & Y]): F[Y & Type] = fa
}
type Foo = Foo.Type

type Bar[A] = A match {
case X & Y => String
case Y => Int
}

def main(args: Array[String]): Unit = {
val a: Bar[X & Y] = "hello"
val i: Bar[Y & Foo] = Foo.apply[Bar](a)
val b: Int = i // error
println(b + 1)
}
}
24 changes: 24 additions & 0 deletions tests/neg/6314-2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
object G {
final class X
final class Y

opaque type Foo = Nothing // or X & Y
object Foo {
def apply[F[_]](fa: F[X & Foo]): F[Y & Foo] = fa
}

type Bar[A] = A match {
case X => String
case Y => Int
}

val a: Bar[X & Foo] = "hello"
val b: Bar[Y & Foo] = 1 // error

def main(args: Array[String]): Unit = {
val a: Bar[X & Foo] = "hello"
val i: Bar[Y & Foo] = Foo.apply[Bar](a)
val b: Int = i // error
println(b + 1)
}
}
24 changes: 24 additions & 0 deletions tests/neg/6314-3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
object G {
trait Wizzle[L <: Int with Singleton] {
type Bar[A] = A match {
case 0 => String
case L => Int
}

def left(fa: String): Bar[0] = fa
def right(fa: Bar[L]): Int = fa // error

def center[F[_]](fa: F[0]): F[L]

def run: String => Int = left andThen center[Bar] andThen right
}

class Wozzle extends Wizzle[0] {
def center[F[_]](fa: F[0]): F[0] = fa
}

def main(args: Array[String]): Unit = {
val coerce: String => Int = (new Wozzle).run
println(coerce("hello") + 1)
}
}
28 changes: 28 additions & 0 deletions tests/neg/6314-4.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
object G {
trait Wizzle {
type X <: Int with Singleton
type Y <: Int with Singleton

type Bar[A] = A match {
case X => String
case Y => Int
}

def left(fa: String): Bar[X] = fa
def center[F[_]](fa: F[X]): F[Y]
def right(fa: Bar[Y]): Int = fa // error

def run: String => Int = left andThen center[Bar] andThen right
}

class Wozzle extends Wizzle {
type X = 0
type Y = 0
def center[F[_]](fa: F[X]): F[Y] = fa
}

def main(args: Array[String]): Unit = {
val coerce: String => Int = (new Wozzle).run
println(coerce("hello") + 1)
}
}
Loading