Skip to content

Commit f6389d0

Browse files
committed
Remove curried function types abbreviations
Remove automatic insertion of captured in curried function types from left to right. They were sometimes confusing and with deep capture sets are counter-productive now.
1 parent 8030851 commit f6389d0

File tree

7 files changed

+67
-164
lines changed

7 files changed

+67
-164
lines changed

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 30 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -254,19 +254,36 @@ extends tpd.TreeTraverser:
254254
val tp1 = mapInferred(tp)
255255
if boxed then box(tp1) else tp1
256256

257-
/** Expand some aliases of function types to the underlying functions.
258-
* Right now, these are only $throws aliases, but this could be generalized.
259-
*/
260-
private def expandThrowsAlias(tp: Type)(using Context) = tp match
261-
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
262-
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
263-
defn.FunctionOf(
264-
AnnotatedType(
257+
/** Recognizer for `res $throws exc`, returning `(res, exc)` in case of success */
258+
object throwsAlias:
259+
def unapply(tp: Type)(using Context): Option[(Type, Type)] = tp match
260+
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
261+
Some((res, exc))
262+
case _ =>
263+
None
264+
265+
/** Expand $throws aliases. This is hard-coded here since $throws aliases in stdlib
266+
* are defined with `?=>` rather than `?->`.
267+
* We also have to add a capture set to the last expanded throws alias. I.e.
268+
* T $throws E1 $throws E2
269+
* expands to
270+
* (erased x$0: CanThrow[E1]) ?-> (erased x$1: CanThrow[E1]) ?->{x$0} T
271+
*/
272+
private def expandThrowsAlias(tp: Type, encl: List[MethodType] = Nil)(using Context): Type = tp match
273+
case throwsAlias(res, exc) =>
274+
val paramType = AnnotatedType(
265275
defn.CanThrowClass.typeRef.appliedTo(exc),
266-
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil,
267-
res,
268-
isContextual = true
269-
)
276+
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span))
277+
val isLast = throwsAlias.unapply(res).isEmpty
278+
val paramName = nme.syntheticParamName(encl.length)
279+
val mt = ContextualMethodType(paramName :: Nil)(
280+
_ => paramType :: Nil,
281+
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
282+
val fntpe = RefinedType(defn.ErasedFunctionClass.typeRef, nme.apply, mt)
283+
if !encl.isEmpty && isLast then
284+
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
285+
CapturingType(fntpe, cs, boxed = false)
286+
else fntpe
270287
case _ => tp
271288

272289
private def expandThrowsAliases(using Context) = new TypeMap:
@@ -283,70 +300,10 @@ extends tpd.TreeTraverser:
283300
case _ =>
284301
mapOver(t)
285302

286-
/** Fill in capture sets of curried function types from left to right, using
287-
* a combination of the following two rules:
288-
*
289-
* 1. Expand `{c} (x: A) -> (y: B) -> C`
290-
* to `{c} (x: A) -> {c} (y: B) -> C`
291-
* 2. Expand `(x: A) -> (y: B) -> C` where `x` is tracked
292-
* to `(x: A) -> {x} (y: B) -> C`
293-
*
294-
* TODO: Should we also propagate capture sets to the left?
295-
*/
296-
private def expandAbbreviations(using Context) = new TypeMap:
297-
298-
/** Propagate `outerCs` as well as all tracked parameters as capture set to the result type
299-
* of the dependent function type `tp`.
300-
*/
301-
def propagateDepFunctionResult(tp: Type, outerCs: CaptureSet): Type = tp match
302-
case RefinedType(parent, nme.apply, rinfo: MethodType) =>
303-
val localCs = CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)
304-
val rinfo1 = rinfo.derivedLambdaType(
305-
resType = propagateEnclosing(rinfo.resType, CaptureSet.empty, outerCs ++ localCs))
306-
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
307-
else tp
308-
309-
/** If `tp` is a function type:
310-
* - add `outerCs` as its capture set,
311-
* - propagate `currentCs`, `outerCs`, and all tracked parameters of `tp` to the right.
312-
*/
313-
def propagateEnclosing(tp: Type, currentCs: CaptureSet, outerCs: CaptureSet): Type = tp match
314-
case tp @ AppliedType(tycon, args) if defn.isFunctionClass(tycon.typeSymbol) =>
315-
val tycon1 = this(tycon)
316-
val args1 = args.init.mapConserve(this)
317-
val tp1 =
318-
if args1.exists(!_.captureSet.isAlwaysEmpty) then
319-
val propagated = propagateDepFunctionResult(
320-
depFun(tycon, args1, args.last), currentCs ++ outerCs)
321-
propagated match
322-
case RefinedType(_, _, mt: MethodType) =>
323-
if mt.isCaptureDependent then propagated
324-
else
325-
// No need to introduce dependent type, switch back to generic function type
326-
tp.derivedAppliedType(tycon1, args1 :+ mt.resType)
327-
else
328-
val resType1 = propagateEnclosing(
329-
args.last, CaptureSet.empty, currentCs ++ outerCs)
330-
tp.derivedAppliedType(tycon1, args1 :+ resType1)
331-
tp1.capturing(outerCs)
332-
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
333-
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
334-
.capturing(outerCs)
335-
case _ =>
336-
mapOver(tp)
337-
338-
def apply(tp: Type): Type = tp match
339-
case CapturingType(parent, cs) =>
340-
tp.derivedCapturingType(propagateEnclosing(parent, cs, CaptureSet.empty), cs)
341-
case _ =>
342-
propagateEnclosing(tp, CaptureSet.empty, CaptureSet.empty)
343-
end expandAbbreviations
344-
345303
private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type =
346304
val tp1 = expandThrowsAliases(if boxed then box(tp) else tp)
347305
if tp1 ne tp then capt.println(i"expanded: $tp --> $tp1")
348-
if ctx.settings.YccNoAbbrev.value then tp1
349-
else expandAbbreviations(tp1)
306+
tp1
350307

351308
/** Transform type of type tree, and remember the transformed type as the type the tree */
352309
private def transformTT(tree: TypeTree, boxed: Boolean, exact: Boolean)(using Context): Unit =

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ private sealed trait YSettings:
387387
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation.")
388388
val YrecheckTest: Setting[Boolean] = BooleanSetting("-Yrecheck-test", "Run basic rechecking (internal test only).")
389389
val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Used in conjunction with captureChecking language import, debug info for captured references.")
390-
val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with captureChecking language import, suppress type abbreviations.")
391390

392391
/** Area-specific debug output */
393392
val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.")

tests/neg-custom-args/captures/curried-simplified.check

Lines changed: 0 additions & 42 deletions
This file was deleted.

tests/neg-custom-args/captures/curried-simplified.scala

Lines changed: 0 additions & 21 deletions
This file was deleted.
Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1-
import java.io.*
2-
import annotation.capability
1+
object Test:
2+
def map2(xs: List[Int])(f: Int => Int): List[Int] = xs.map(f)
3+
val f1 = map2
4+
val fc1: List[Int] -> (Int => Int) -> List[Int] = f1
5+
6+
def map3(f: Int => Int)(xs: List[Int]): List[Int] = xs.map(f)
7+
private val f2 = map3
8+
val fc2: (f: Int => Int) -> List[Int] ->{f} List[Int] = f2
9+
10+
val f3 = (f: Int => Int) =>
11+
println(f(3))
12+
(xs: List[Int]) => xs.map(_ + 1)
13+
val f3c: (Int => Int) -> List[Int] -> List[Int] = f3
14+
15+
class LL[A]:
16+
def drop(n: Int): LL[A]^{this} = ???
317

18+
def test(ct: CanThrow[Exception]) =
19+
def xs: LL[Int]^{ct} = ???
20+
val ys = xs.drop(_)
21+
val ysc: Int -> LL[Int]^{ct} = ys
22+
23+
import java.io.*
424
def Test4(g: OutputStream^) =
525
val xs: List[Int] = ???
626
val later = (f: OutputStream^) => (y: Int) => xs.foreach(x => f.write(x + y))
727
val _: (f: OutputStream^) ->{} Int ->{f} Unit = later
828

929
val later2 = () => (y: Int) => xs.foreach(x => g.write(x + y))
1030
val _: () ->{} Int ->{g} Unit = later2
31+

tests/pos-custom-args/captures/curried-shorthands.scala

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/pos-custom-args/captures/i13816.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ import language.experimental.saferExceptions
22

33
class Ex1 extends Exception("Ex1")
44
class Ex2 extends Exception("Ex2")
5+
class Ex3 extends Exception("Ex3")
56

67
def foo0(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit =
78
if i > 0 then throw new Ex1 else throw new Ex2
89

9-
def foo01(i: Int): CanThrow[Ex1] ?-> CanThrow[Ex2] ?-> Unit =
10+
/* Does not work yet since annotated CFTs are not recognized properly in typer
11+
12+
def foo01(i: Int): (ct: CanThrow[Ex1]) ?-> CanThrow[Ex2] ?->{ct} Unit =
1013
if i > 0 then throw new Ex1 else throw new Ex2
14+
*/
1115

1216
def foo1(i: Int): Unit throws Ex1 throws Ex2 =
1317
if i > 0 then throw new Ex1 else throw new Ex1
@@ -33,6 +37,15 @@ def foo7(i: Int)(using CanThrow[Ex1]): Unit throws Ex1 | Ex2 =
3337
def foo8(i: Int)(using CanThrow[Ex2]): Unit throws Ex2 | Ex1 =
3438
if i > 0 then throw new Ex1 else throw new Ex2
3539

40+
/** Does not work yet since the type of the rhs is not hygienic
41+
42+
def foo9(i: Int): Unit throws Ex1 | Ex2 | Ex3 =
43+
if i > 0 then throw new Ex1
44+
else if i < 0 then throw new Ex2
45+
else throw new Ex3
46+
47+
*/
48+
3649
def test(): Unit =
3750
try
3851
foo1(1)

0 commit comments

Comments
 (0)