Skip to content

Change handling of curried function types in capture checking #18131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 15, 2023
Merged
72 changes: 42 additions & 30 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,11 @@ sealed abstract class CaptureSet extends Showable:
* any of the elements in the constant capture set `that`
*/
def -- (that: CaptureSet.Const)(using Context): CaptureSet =
val elems1 = elems.filter(!that.accountsFor(_))
if elems1.size == elems.size then this
else if this.isConst then Const(elems1)
else Diff(asVar, that)
if this.isConst then
val elems1 = elems.filter(!that.accountsFor(_))
if elems1.size == elems.size then this else Const(elems1)
else
if that.isAlwaysEmpty then this else Diff(asVar, that)

/** The largest subset (via <:<) of this capture set that does not account for `ref` */
def - (ref: CaptureRef)(using Context): CaptureSet =
Expand Down Expand Up @@ -845,34 +846,45 @@ object CaptureSet:
/** The capture set of the type underlying a CaptureRef */
def ofInfo(ref: CaptureRef)(using Context): CaptureSet = ref match
case ref: TermRef if ref.isRootCapability => ref.singletonCaptureSet
case _ => ofType(ref.underlying)
case _ => ofType(ref.underlying, followResult = true)

/** Capture set of a type */
def ofType(tp: Type)(using Context): CaptureSet =
def recur(tp: Type): CaptureSet = tp.dealias match
case tp: TermRef =>
tp.captureSet
case tp: TermParamRef =>
tp.captureSet
case _: TypeRef =>
if tp.classSymbol.hasAnnotation(defn.CapabilityAnnot) then universal else empty
case _: TypeParamRef =>
empty
case CapturingType(parent, refs) =>
recur(parent) ++ refs
case AppliedType(tycon, args) =>
val cs = recur(tycon)
tycon.typeParams match
case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args)
case _ => cs
case tp: TypeProxy =>
recur(tp.underlying)
case AndType(tp1, tp2) =>
recur(tp1) ** recur(tp2)
case OrType(tp1, tp2) =>
recur(tp1) ++ recur(tp2)
case _ =>
empty
def ofType(tp: Type, followResult: Boolean)(using Context): CaptureSet =
def recur(tp: Type): CaptureSet = trace(i"ofType $tp, ${tp.getClass} $followResult", show = true):
tp.dealias match
case tp: TermRef =>
tp.captureSet
case tp: TermParamRef =>
tp.captureSet
case _: TypeRef =>
if tp.classSymbol.hasAnnotation(defn.CapabilityAnnot) then universal else empty
case _: TypeParamRef =>
empty
case CapturingType(parent, refs) =>
recur(parent) ++ refs
case tpd @ RefinedType(parent, _, rinfo: MethodType)
if followResult && defn.isFunctionType(tpd) =>
ofType(parent, followResult = false) // pick up capture set from parent type
++ (recur(rinfo.resType) // add capture set of result
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
case tpd @ AppliedType(tycon, args) =>
if followResult && defn.isNonRefinedFunction(tpd) then
recur(args.last)
// must be (pure) FunctionN type since ImpureFunctions have already
// been eliminated in selector's dealias. Use capture set of result.
else
val cs = recur(tycon)
tycon.typeParams match
case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args)
case _ => cs
case tp: TypeProxy =>
recur(tp.underlying)
case AndType(tp1, tp2) =>
recur(tp1) ** recur(tp2)
case OrType(tp1, tp2) =>
recur(tp1) ++ recur(tp2)
case _ =>
empty
recur(tp)
.showing(i"capture set of $tp = $result", capt)

Expand Down
Loading