Skip to content

Commit 86565a4

Browse files
authored
Make reach refinement shallow (#19171)
This is to address the following soundness issue: ```scala trait File val useFile: [R] -> (path: String) -> (op: File^ -> R) -> R = ??? def main(): Unit = val f: [R] -> (path: String) -> (op: File^ -> R) -> R = useFile val g: [R] -> (path: String) -> (op: File^{f*} -> R) -> R = f // should be an error val leaked = g[File^{f*}]("test")(f => f) // boom ```
2 parents 98184f1 + 64242f4 commit 86565a4

File tree

3 files changed

+73
-16
lines changed

3 files changed

+73
-16
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -236,35 +236,66 @@ extension (tp: Type)
236236
* (2) all covariant occurrences of cap replaced by `x*`, provided there
237237
* are no occurrences in `T` at other variances. (1) is standard, whereas
238238
* (2) is new.
239+
*
240+
* For (2), multiple-flipped covariant occurrences of cap won't be replaced.
241+
* In other words,
242+
*
243+
* - For xs: List[File^] ==> List[File^{xs*}], the cap is replaced;
244+
* - while f: [R] -> (op: File^ => R) -> R remains unchanged.
245+
*
246+
* Without this restriction, the signature of functions like withFile:
247+
*
248+
* (path: String) -> [R] -> (op: File^ => R) -> R
249+
*
250+
* could be refined to
251+
*
252+
* (path: String) -> [R] -> (op: File^{withFile*} => R) -> R
253+
*
254+
* which is clearly unsound.
239255
*
240256
* Why is this sound? Covariant occurrences of cap must represent capabilities
241257
* that are reachable from `x`, so they are included in the meaning of `{x*}`.
242258
* At the same time, encapsulation is still maintained since no covariant
243259
* occurrences of cap are allowed in instance types of type variables.
244260
*/
245261
def withReachCaptures(ref: Type)(using Context): Type =
246-
object narrowCaps extends TypeMap:
262+
class CheckContraCaps extends TypeTraverser:
247263
var ok = true
248-
def apply(t: Type) = t.dealias match
249-
case t1 @ CapturingType(p, cs) if cs.isUniversal =>
250-
if variance > 0 then
251-
t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet)
252-
else
253-
ok = false
254-
t
255-
case _ => t match
256-
case t @ CapturingType(p, cs) =>
257-
t.derivedCapturingType(apply(p), cs) // don't map capture set variables
258-
case t =>
259-
mapOver(t)
264+
def traverse(t: Type): Unit =
265+
if ok then
266+
t match
267+
case CapturingType(_, cs) if cs.isUniversal && variance <= 0 =>
268+
ok = false
269+
case _ =>
270+
traverseChildren(t)
271+
272+
object narrowCaps extends TypeMap:
273+
/** Has the variance been flipped at this point? */
274+
private var isFlipped: Boolean = false
275+
276+
def apply(t: Type) =
277+
val saved = isFlipped
278+
try
279+
if variance <= 0 then isFlipped = true
280+
t.dealias match
281+
case t1 @ CapturingType(p, cs) if cs.isUniversal && !isFlipped =>
282+
t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet)
283+
case _ => t match
284+
case t @ CapturingType(p, cs) =>
285+
t.derivedCapturingType(apply(p), cs) // don't map capture set variables
286+
case t =>
287+
mapOver(t)
288+
finally isFlipped = saved
260289
ref match
261290
case ref: CaptureRef if ref.isTrackableRef =>
262-
val tp1 = narrowCaps(tp)
263-
if narrowCaps.ok then
291+
val checker = new CheckContraCaps
292+
checker.traverse(tp)
293+
if checker.ok then
294+
val tp1 = narrowCaps(tp)
264295
if tp1 ne tp then capt.println(i"narrow $tp of $ref to $tp1")
265296
tp1
266297
else
267-
capt.println(i"cannot narrow $tp of $ref to $tp1")
298+
capt.println(i"cannot narrow $tp of $ref")
268299
tp
269300
case _ =>
270301
tp
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import language.experimental.captureChecking
2+
trait IO
3+
def test1(): Unit =
4+
val f: IO^ => IO^ = x => x
5+
val g: IO^ => IO^{f*} = f // error
6+
def test2(): Unit =
7+
val f: [R] -> (IO^ => R) -> R = ???
8+
val g: [R] -> (IO^{f*} => R) -> R = f // error
9+
def test3(): Unit =
10+
val f: [R] -> (IO^ -> R) -> R = ???
11+
val g: [R] -> (IO^{f*} -> R) -> R = f // error
12+
def test4(): Unit =
13+
val xs: List[IO^] = ???
14+
val ys: List[IO^{xs*}] = xs // ok
15+
def test5(): Unit =
16+
val f: [R] -> (IO^ -> R) -> IO^ = ???
17+
val g: [R] -> (IO^ -> R) -> IO^{f*} = f // ok
18+
val h: [R] -> (IO^{f*} -> R) -> IO^ = f // error
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import language.experimental.captureChecking
2+
3+
trait File
4+
val useFile: [R] -> (path: String) -> (op: File^ -> R) -> R = ???
5+
def main(): Unit =
6+
val f: [R] -> (path: String) -> (op: File^ -> R) -> R = useFile
7+
val g: [R] -> (path: String) -> (op: File^{f*} -> R) -> R = f // error
8+
val leaked = g[File^{f*}]("test")(f => f) // boom

0 commit comments

Comments
 (0)