Skip to content

Commit ee0dd7a

Browse files
authored
Fix #21619: Refactor NotNullInfo to record every reference which is retracted once. (#21624)
This PR improves the flow typing for returning and exceptions. The `NotNullInfo` is defined as following now: ```scala case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): ``` * `retracted` contains variable references that are ever assigned to null; * if `asserted` is not `null`, it contains `val` or `var` references that are known to be not null, after the tree finishes executing normally (non-exceptionally); * if `asserted` is `null`, the tree is know to terminate, by throwing, returning, or calling a function with `Nothing` type. Hence, it acts like a universal set. `alt` is defined as `<a1,r1>.alt(<a2,r2>) = <a1 intersect a2, r1 union r2>`. The difficult part is the `try ... catch ... finally ...`. We don't know at which point an exception is thrown in the body, and the catch cases may be not exhaustive, we have to collect any reference that is once retracted. Fix #21619
2 parents e6b4222 + 200c038 commit ee0dd7a

File tree

9 files changed

+257
-78
lines changed

9 files changed

+257
-78
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -777,13 +777,13 @@ object Contexts {
777777

778778
extension (c: Context)
779779
def addNotNullInfo(info: NotNullInfo) =
780-
c.withNotNullInfos(c.notNullInfos.extendWith(info))
780+
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c
781781

782782
def addNotNullRefs(refs: Set[TermRef]) =
783-
c.addNotNullInfo(NotNullInfo(refs, Set()))
783+
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c
784784

785785
def withNotNullInfos(infos: List[NotNullInfo]): Context =
786-
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
786+
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)
787787

788788
def relaxedOverrideContext: Context =
789789
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)

compiler/src/dotty/tools/dotc/typer/Applications.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ trait Applications extends Compatibility {
11341134
case _ => ()
11351135
else ()
11361136

1137-
fun1.tpe match {
1137+
val result = fun1.tpe match {
11381138
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
11391139
case TryDynamicCallType =>
11401140
val isInsertedApply = fun1 match {
@@ -1208,6 +1208,11 @@ trait Applications extends Compatibility {
12081208
else tryWithImplicitOnQualifier(fun1, proto).getOrElse(fail))
12091209
}
12101210
}
1211+
1212+
if result.tpe.isNothingType then
1213+
val nnInfo = result.notNullInfo
1214+
result.withNotNullInfo(nnInfo.terminatedInfo)
1215+
else result
12111216
}
12121217

12131218
/** Convert expression like

compiler/src/dotty/tools/dotc/typer/Nullables.scala

+76-40
Original file line numberDiff line numberDiff line change
@@ -52,34 +52,46 @@ object Nullables:
5252
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
5353
TypeBoundsTree(lo, hiTree, alias)
5454

55-
/** A set of val or var references that are known to be not null, plus a set of
56-
* variable references that are not known (anymore) to be not null
55+
/** A set of val or var references that are known to be not null
56+
* after the tree finishes executing normally (non-exceptionally),
57+
* plus a set of variable references that are ever assigned to null,
58+
* and may therefore be null if execution of the tree is interrupted
59+
* by an exception.
5760
*/
58-
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
59-
assert((asserted & retracted).isEmpty)
60-
61+
case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
6162
def isEmpty = this eq NotNullInfo.empty
6263

6364
def retractedInfo = NotNullInfo(Set(), retracted)
6465

66+
def terminatedInfo = NotNullInfo(null, retracted)
67+
6568
/** The sequential combination with another not-null info */
6669
def seq(that: NotNullInfo): NotNullInfo =
6770
if this.isEmpty then that
6871
else if that.isEmpty then this
69-
else NotNullInfo(
70-
this.asserted.union(that.asserted).diff(that.retracted),
71-
this.retracted.union(that.retracted).diff(that.asserted))
72+
else
73+
val newAsserted =
74+
if this.asserted == null || that.asserted == null then null
75+
else this.asserted.diff(that.retracted).union(that.asserted)
76+
val newRetracted = this.retracted.union(that.retracted)
77+
NotNullInfo(newAsserted, newRetracted)
7278

7379
/** The alternative path combination with another not-null info. Used to merge
74-
* the nullability info of the two branches of an if.
80+
* the nullability info of the branches of an if or match.
7581
*/
7682
def alt(that: NotNullInfo): NotNullInfo =
77-
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))
83+
val newAsserted =
84+
if this.asserted == null then that.asserted
85+
else if that.asserted == null then this.asserted
86+
else this.asserted.intersect(that.asserted)
87+
val newRetracted = this.retracted.union(that.retracted)
88+
NotNullInfo(newAsserted, newRetracted)
89+
end NotNullInfo
7890

7991
object NotNullInfo:
8092
val empty = new NotNullInfo(Set(), Set())
81-
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
82-
if asserted.isEmpty && retracted.isEmpty then empty
93+
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
94+
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
8395
else new NotNullInfo(asserted, retracted)
8496
end NotNullInfo
8597

@@ -223,7 +235,7 @@ object Nullables:
223235
*/
224236
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
225237
case info :: infos1 =>
226-
if info.asserted.contains(ref) then true
238+
if info.asserted == null || info.asserted.contains(ref) then true
227239
else if info.retracted.contains(ref) then false
228240
else infos1.impliesNotNull(ref)
229241
case _ =>
@@ -233,16 +245,15 @@ object Nullables:
233245
* or retractions in `info` supersede infos in existing entries of `infos`.
234246
*/
235247
def extendWith(info: NotNullInfo) =
236-
if info.isEmpty
237-
|| info.asserted.forall(infos.impliesNotNull(_))
238-
&& !info.retracted.exists(infos.impliesNotNull(_))
239-
then infos
248+
if info.isEmpty then infos
240249
else info :: infos
241250

242251
/** Retract all references to mutable variables */
243252
def retractMutables(using Context) =
244-
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
245-
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
253+
val mutables = infos.foldLeft(Set[TermRef]()):
254+
(ms, info) => ms.union(
255+
if info.asserted == null then Set.empty
256+
else info.asserted.filter(_.symbol.is(Mutable)))
246257
infos.extendWith(NotNullInfo(Set(), mutables))
247258

248259
end extension
@@ -304,15 +315,35 @@ object Nullables:
304315
extension (tree: Tree)
305316

306317
/* The `tree` with added nullability attachment */
307-
def withNotNullInfo(info: NotNullInfo): tree.type =
308-
if !info.isEmpty then tree.putAttachment(NNInfo, info)
318+
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
319+
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
309320
tree
310321

322+
/* Collect the nullability info from parts of `tree` */
323+
def collectNotNullInfo(using Context): NotNullInfo = tree match
324+
case Typed(expr, _) =>
325+
expr.notNullInfo
326+
case Apply(fn, args) =>
327+
val argsInfo = args.map(_.notNullInfo)
328+
val fnInfo = fn.notNullInfo
329+
argsInfo.foldLeft(fnInfo)(_ seq _)
330+
case TypeApply(fn, _) =>
331+
fn.notNullInfo
332+
case _ =>
333+
// Other cases are handled specially in typer.
334+
NotNullInfo.empty
335+
311336
/* The nullability info of `tree` */
312337
def notNullInfo(using Context): NotNullInfo =
313-
stripInlined(tree).getAttachment(NNInfo) match
314-
case Some(info) if !ctx.erasedTypes => info
315-
case _ => NotNullInfo.empty
338+
if !ctx.explicitNulls then NotNullInfo.empty
339+
else
340+
val tree1 = stripInlined(tree)
341+
tree1.getAttachment(NNInfo) match
342+
case Some(info) if !ctx.erasedTypes => info
343+
case _ =>
344+
val nnInfo = tree1.collectNotNullInfo
345+
tree1.withNotNullInfo(nnInfo)
346+
nnInfo
316347

317348
/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
318349
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =
@@ -393,21 +424,23 @@ object Nullables:
393424
end extension
394425

395426
extension (tree: Assign)
396-
def computeAssignNullable()(using Context): tree.type = tree.lhs match
397-
case TrackedRef(ref) =>
398-
val rhstp = tree.rhs.typeOpt
399-
if ctx.explicitNulls && ref.isNullableUnion then
400-
if rhstp.isNullType || rhstp.isNullableUnion then
401-
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
402-
// lhs variable is no longer trackable. We don't need to check whether the type `T`
403-
// is correct here, as typer will check it.
404-
tree.withNotNullInfo(NotNullInfo(Set(), Set(ref)))
405-
else
406-
// If the initial type is nullable and the assigned value is non-null,
407-
// we add it to the NotNull.
408-
tree.withNotNullInfo(NotNullInfo(Set(ref), Set()))
409-
else tree
410-
case _ => tree
427+
def computeAssignNullable()(using Context): tree.type =
428+
var nnInfo = tree.rhs.notNullInfo
429+
tree.lhs match
430+
case TrackedRef(ref) if ctx.explicitNulls && ref.isNullableUnion =>
431+
nnInfo = nnInfo.seq:
432+
val rhstp = tree.rhs.typeOpt
433+
if rhstp.isNullType || rhstp.isNullableUnion then
434+
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
435+
// lhs variable is no longer trackable. We don't need to check whether the type `T`
436+
// is correct here, as typer will check it.
437+
NotNullInfo(Set(), Set(ref))
438+
else
439+
// If the initial type is nullable and the assigned value is non-null,
440+
// we add it to the NotNull.
441+
NotNullInfo(Set(ref), Set())
442+
case _ =>
443+
tree.withNotNullInfo(nnInfo)
411444
end extension
412445

413446
private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!)
@@ -515,7 +548,10 @@ object Nullables:
515548
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
516549
&& ctx.notNullInfos.impliesNotNull(ref)
517550

518-
val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
551+
val retractedVars = ctx.notNullInfos.flatMap(info =>
552+
if info.asserted == null then Set.empty
553+
else info.asserted.filter(isRetracted)
554+
).toSet
519555
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
520556
end whileContext
521557

compiler/src/dotty/tools/dotc/typer/Typer.scala

+36-25
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12011201
untpd.unsplice(tree.expr).putAttachment(AscribedToUnit, ())
12021202
typed(tree.expr, underlyingTreeTpe.tpe.widenSkolem)
12031203
assignType(cpy.Typed(tree)(expr1, tpt), underlyingTreeTpe)
1204-
.withNotNullInfo(expr1.notNullInfo)
12051204
}
12061205

12071206
if (untpd.isWildcardStarArg(tree)) {
@@ -1551,11 +1550,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15511550

15521551
def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
15531552
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
1554-
result.withNotNullInfo(
1555-
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
1556-
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
1557-
else thenPathInfo.alt(elsePathInfo)
1558-
)
1553+
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
15591554
end typedIf
15601555

15611556
/** Decompose function prototype into a list of parameter prototypes and a result
@@ -2139,20 +2134,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
21392134
case1
21402135
}
21412136
.asInstanceOf[List[CaseDef]]
2142-
var nni = sel.notNullInfo
2143-
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
2144-
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt).withNotNullInfo(nni)
2137+
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
2138+
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
21452139
}
21462140

21472141
// Overridden in InlineTyper for inline matches
21482142
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
21492143
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
21502144
.asInstanceOf[List[CaseDef]]
2151-
var nni = sel.notNullInfo
2152-
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
2153-
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
2145+
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
2146+
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
21542147
}
21552148

2149+
private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
2150+
if cases.isEmpty then
2151+
// Empty cases is not allowed for match tree in the source code,
2152+
// but it can be generated by inlining: `tests/pos/i19198.scala`.
2153+
initInfo
2154+
else cases.map(_.notNullInfo).reduce(_.alt(_))
2155+
21562156
def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
21572157
var caseCtx = ctx
21582158
var wideSelType = wideSelType0
@@ -2241,7 +2241,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22412241
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
22422242
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
22432243
val expr1 = typed(tree.expr, bind1.symbol.info)
2244-
assignType(cpy.Labeled(tree)(bind1, expr1))
2244+
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
22452245
}
22462246

22472247
/** Type a case of a type match */
@@ -2291,7 +2291,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22912291
// Hence no adaptation is possible, and we assume WildcardType as prototype.
22922292
(from, proto)
22932293
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
2294-
assignType(cpy.Return(tree)(expr1, from))
2294+
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
22952295
end typedReturn
22962296

22972297
def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
@@ -2332,7 +2332,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23322332
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
23332333
untpd.Block(makeCanThrow(capabilityProof), expr)
23342334

2335-
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
2335+
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
2336+
var nnInfo = NotNullInfo.empty
23362337
val expr2 :: cases2x = harmonic(harmonize, pt) {
23372338
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
23382339
// uses the types of patterns in `tree.cases` to determine the capabilities.
@@ -2344,18 +2345,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23442345
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
23452346
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
23462347
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
2347-
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)
2348+
2349+
// Since we don't know at which point the the exception is thrown in the body,
2350+
// we have to collect any reference that is once retracted.
2351+
nnInfo = expr1.notNullInfo.retractedInfo
2352+
2353+
val casesCtx = ctx.addNotNullInfo(nnInfo)
23482354
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
23492355
expr1 :: cases1
23502356
}: @unchecked
23512357
val cases2 = cases2x.asInstanceOf[List[CaseDef]]
23522358

2353-
var nni = expr2.notNullInfo.retractedInfo
2354-
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
2355-
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
2356-
nni = nni.seq(finalizer1.notNullInfo)
2357-
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
2358-
}
2359+
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
2360+
// Therefore, the code in the finalizer and after the try block can only rely on the retracted
2361+
// info from the cases' body.
2362+
if cases2.nonEmpty then
2363+
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
2364+
2365+
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
2366+
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
2367+
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)
23592368

23602369
def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
23612370
val cases: List[untpd.CaseDef] = tree.handler match
@@ -2369,15 +2378,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23692378
def typedThrow(tree: untpd.Throw)(using Context): Tree =
23702379
val expr1 = typed(tree.expr, defn.ThrowableType)
23712380
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
2372-
val res = Throw(expr1).withSpan(tree.span)
2381+
var res = Throw(expr1).withSpan(tree.span)
23732382
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
23742383
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
23752384
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
2376-
Typed(res,
2385+
res = Typed(res,
23772386
TypeTree(
23782387
AnnotatedType(res.tpe,
23792388
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
2380-
else res
2389+
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)
23812390

23822391
def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
23832392
val elemProto = pt.stripNull().elemType match {
@@ -2842,6 +2851,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28422851
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
28432852
postProcessInfo(vdef1, sym)
28442853
vdef1.setDefTree
2854+
val nnInfo = rhs1.notNullInfo
2855+
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
28452856
}
28462857

28472858
private def retractDefDef(sym: Symbol)(using Context): Tree =

tests/explicit-nulls/neg/i21380b.scala

+18
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,22 @@ def test3(i: Int) =
1818
i match
1919
case 1 if x != null => ()
2020
case _ => x = " "
21+
x.trim() // ok
22+
23+
def test4(i: Int) =
24+
var x: String | Null = null
25+
var y: String | Null = null
26+
i match
27+
case 1 => x = "1"
28+
case _ => y = " "
29+
x.trim() // error
30+
31+
def test5(i: Int): String =
32+
var x: String | Null = null
33+
var y: String | Null = null
34+
i match
35+
case 1 => x = "1"
36+
case _ =>
37+
y = " "
38+
return y
2139
x.trim() // ok

tests/explicit-nulls/neg/i21380c.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test4: Int =
3232
case npe: NullPointerException => x = ""
3333
case _ => x = ""
3434
x.length // error
35-
// Although the catch block here is exhaustive,
36-
// it is possible that the exception is thrown and not caught.
37-
// Therefore, the code after the try block can only rely on the retracted info.
35+
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
36+
// and some exceptions are thrown and not caught. Therefore, the code in the finalizer and
37+
// after the try block can only rely on the retracted info from the cases' body.
3838

3939
def test5: Int =
4040
var x: String | Null = null

0 commit comments

Comments
 (0)