Skip to content

Commit 15b4385

Browse files
committed
Fix i11694: extract function type and SAM in union type
1 parent 77b0ae0 commit 15b4385

File tree

3 files changed

+58
-10
lines changed

3 files changed

+58
-10
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

+35-10
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,20 @@ class Typer extends Namer
11041104
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
11051105
case _ => mapOver(t)
11061106
}
1107+
def extractInUnion(t: Type): Seq[Type] = t match {
1108+
case t: OrType =>
1109+
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
1110+
case t: TypeParamRef =>
1111+
extractInUnion(ctx.typerState.constraint.entry(t).bounds.hi)
1112+
case t if defn.isNonRefinedFunction(t) =>
1113+
Seq(t)
1114+
case SAMType(_: MethodType) =>
1115+
Seq(t)
1116+
case _ =>
1117+
Nil
1118+
}
1119+
def defaultResult = (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1120+
11071121
val pt1 = pt.stripTypeVar.dealias
11081122
if (pt1 ne pt1.dropDependentRefinement)
11091123
&& defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
@@ -1112,22 +1126,25 @@ class Typer extends Namer
11121126
i"""Implementation restriction: Expected result type $pt1
11131127
|is a curried dependent context function type. Such types are not yet supported.""",
11141128
tree.srcPos)
1115-
pt1 match {
1129+
1130+
val elems = extractInUnion(pt1)
1131+
if elems.length != 1 then
1132+
// The union type containing multiple function types is ignored
1133+
defaultResult
1134+
else elems.head match {
11161135
case pt1 if defn.isNonRefinedFunction(pt1) =>
11171136
// if expected parameter type(s) are wildcards, approximate from below.
11181137
// if expected result type is a wildcard, approximate from above.
11191138
// this can type the greatest set of admissible closures.
11201139
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
11211140
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
11221141
(formals,
1123-
if (sam.isResultDependent)
1124-
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
1125-
else
1126-
typeTree(restpe))
1127-
case tp: TypeParamRef =>
1128-
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
1142+
if (sam.isResultDependent)
1143+
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
1144+
else
1145+
typeTree(restpe))
11291146
case _ =>
1130-
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1147+
defaultResult
11311148
}
11321149
}
11331150

@@ -1355,14 +1372,22 @@ class Typer extends Namer
13551372
}
13561373

13571374
def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
1375+
def extractInUnion(t: Type): Seq[Type] = t match {
1376+
case t: OrType =>
1377+
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
1378+
case SAMType(_) =>
1379+
Seq(t)
1380+
case _ =>
1381+
Nil
1382+
}
13581383
val env1 = tree.env mapconserve (typed(_))
13591384
val meth1 = typedUnadapted(tree.meth)
13601385
val target =
13611386
if (tree.tpt.isEmpty)
13621387
meth1.tpe.widen match {
13631388
case mt: MethodType =>
1364-
pt.stripNull match {
1365-
case pt @ SAMType(sam)
1389+
extractInUnion(pt) match {
1390+
case Seq(pt @ SAMType(sam))
13661391
if !defn.isFunctionType(pt) && mt <:< sam =>
13671392
// SAMs of the form C[?] where C is a class cannot be conversion targets.
13681393
// The resulting class `class $anon extends C[?] {...}` would be illegal,

tests/explicit-nulls/pos/i11694.scala

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = {
2+
val x = new java.util.ArrayList[String]()
3+
val y = x.stream().nn.filter(s => s.nn.length > 0)
4+
}

tests/neg/i11694.scala

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def test1 = {
2+
def f11: (Int => Int) | Unit = x => x + 1
3+
def f12: Null | (Int => Int) = x => x + 1
4+
5+
def f21: (Int => Int) | Null = x => x + 1
6+
def f22: Null | (Int => Int) = x => x + 1
7+
}
8+
9+
def test2 = {
10+
def f1: (Int => String) | (Int => Int) | Null = x => x + 1 // error
11+
def f2: (Int => String) | Function[String, Int] | Null = x => "" + x // error
12+
def f3: Function[Int, Int] | Function[String, Int] | Null = x => x + 1 // error
13+
}
14+
15+
def test3 = {
16+
import java.util.function.Function
17+
val f1: Function[String, Int] | Unit = x => x.length
18+
val f2: Function[String, Int] | Null = x => x.length
19+
}

0 commit comments

Comments
 (0)