Skip to content

Experiment with Capless-like Scheme for Capture Checking #23291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions compiler/src/dotty/tools/dotc/cc/Capability.scala
Original file line number Diff line number Diff line change
Expand Up @@ -473,27 +473,28 @@ object Capabilities:
case info: OrType => viaInfo(info.tp1)(test) && viaInfo(info.tp2)(test)
case _ => false

def trySubpath(y: TermRef): Boolean =
y.prefix.match
case ypre: Capability =>
this.subsumes(ypre)
|| this.match
case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol =>
// To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y`
// are equvalent, which means `x =:= y` in terms of subtyping,
// not just `{x} =:= {y}` in terms of subcapturing.
// It is possible to construct two singleton types `x` and `y`,
// which subsume each other, but are not equal references.
// See `tests/neg-custom-args/captures/path-prefix.scala` for example.
withMode(Mode.IgnoreCaptures):
TypeComparer.isSameRef(xpre, ypre)
case _ =>
false
case _ => false

try (this eq y)
|| maxSubsumes(y, canAddHidden = !vs.isOpen)
|| y.match
case y: TermRef =>
y.prefix.match
case ypre: Capability =>
this.subsumes(ypre)
|| this.match
case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol =>
// To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y`
// are equvalent, which means `x =:= y` in terms of subtyping,
// not just `{x} =:= {y}` in terms of subcapturing.
// It is possible to construct two singleton types `x` and `y`,
// which subsume each other, but are not equal references.
// See `tests/neg-custom-args/captures/path-prefix.scala` for example.
withMode(Mode.IgnoreCaptures):
TypeComparer.isSameRef(xpre, ypre)
case _ =>
false
case _ => false
|| viaInfo(y.info)(subsumingRefs(this, _))
case y: TermRef => trySubpath(y) || viaInfo(y.info)(subsumingRefs(this, _))
case Maybe(y1) => this.stripMaybe.subsumes(y1)
case ReadOnly(y1) => this.stripReadOnly.subsumes(y1)
case y: TypeRef if y.derivesFrom(defn.Caps_CapSet) =>
Expand All @@ -507,6 +508,15 @@ object Capabilities:
this.subsumes(hi)
case _ =>
y.captureSetOfInfo.elems.forall(this.subsumes)
case Reach(y1: TermRef) =>
val sym = y1.symbol
def isUseClassParam: Boolean =
sym.owner match
case classSym: ClassSymbol =>
val paramSym = classSym.primaryConstructor.paramNamed(sym.name)
paramSym.isUseParam
case _ => false
isUseClassParam && trySubpath(y1)
case _ => false
|| this.match
case Reach(x1) => x1.subsumes(y.stripReach)
Expand Down Expand Up @@ -858,4 +868,4 @@ object Capabilities:
case tp1 => tp1
end toResultInResults

end Capabilities
end Capabilities
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ extension (tp: Type)
val tp1 = narrowCaps(tp)
if narrowCaps.change then
capt.println(i"narrow $tp of $ref to $tp1")
//println(i"reach refinement $tp at $ref to $tp1 (${ctx.compilationUnit})")
tp1
else
tp
Expand All @@ -395,6 +396,9 @@ extension (tp: Type)
RefinedType(tp, name,
AnnotatedType(rinfo, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan)))

def dropUseAndConsumeAnnots(using Context): Type =
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)

extension (tp: MethodType)
/** A method marks an existential scope unless it is the prefix of a curried method */
def marksExistentialScope(using Context): Boolean =
Expand Down Expand Up @@ -490,18 +494,24 @@ extension (sym: Symbol)
def hasTrackedParts(using Context): Boolean =
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty

/** `sym` is annotated @use or it is a type parameter with a matching
/** `sym` itself or its info is annotated @use or it is a type parameter with a matching
* @use-annotated term parameter that contains `sym` in its deep capture set.
*/
def isUseParam(using Context): Boolean =
sym.hasAnnotation(defn.UseAnnot)
|| sym.info.hasAnnotation(defn.UseAnnot)
|| sym.is(TypeParam)
&& sym.owner.rawParamss.nestedExists: param =>
param.is(TermParam) && param.hasAnnotation(defn.UseAnnot)
&& param.info.deepCaptureSet.elems.exists:
case c: TypeRef => c.symbol == sym
case _ => false

/** `sym` or its info is annotated with `@consume`. */
def isConsumeParam(using Context): Boolean =
sym.hasAnnotation(defn.ConsumeAnnot)
|| sym.info.hasAnnotation(defn.ConsumeAnnot)

def isUpdateMethod(using Context): Boolean =
sym.isAllOf(Mutable | Method, butNot = Accessor)

Expand Down
28 changes: 12 additions & 16 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -708,19 +708,6 @@ class CheckCaptures extends Recheck, SymTransformer:
selType
}//.showing(i"recheck sel $tree, $qualType = $result")

/** Hook for massaging a function before it is applied. Copies all @use and @consume
* annotations on method parameter symbols to the corresponding paramInfo types.
*/
override def prepareFunction(funtpe: MethodType, meth: Symbol)(using Context): MethodType =
val paramInfosWithUses =
funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
val param = meth.paramNamed(pname)
def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match
case Some(ann) => AnnotatedType(tp, ann)
case _ => tp
copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot)
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)

/** Recheck applications, with special handling of unsafeAssumePure.
* More work is done in `recheckApplication`, `recheckArg` and `instantiate` below.
*/
Expand Down Expand Up @@ -748,7 +735,8 @@ class CheckCaptures extends Recheck, SymTransformer:
val argType = recheck(arg, freshenedFormal)
.showing(i"recheck arg $arg vs $freshenedFormal = $result", capt)
if formal.hasAnnotation(defn.UseAnnot) || formal.hasAnnotation(defn.ConsumeAnnot) then
// The @use and/or @consume annotation is added to `formal` by `prepareFunction`
// The @use and/or @consume annotation is added to `formal` when creating methods types.
// See [[MethodTypeCompanion.adaptParamInfo]].
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
markFree(argType.deepCaptureSet, arg)
if formal.containsCap then
Expand Down Expand Up @@ -789,6 +777,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case appType @ CapturingType(appType1, refs)
if qualType.exists
&& !tree.fun.symbol.isConstructor
&& funType.paramInfos.isEmpty
&& qualCaptures.mightSubcapture(refs)
&& argCaptures.forall(_.mightSubcapture(refs)) =>
val callCaptures = argCaptures.foldLeft(qualCaptures)(_ ++ _)
Expand Down Expand Up @@ -845,10 +834,14 @@ class CheckCaptures extends Recheck, SymTransformer:
initCs ++ FreshCap(Origin.NewCapability(core)).readOnly.singletonCaptureSet
else initCs
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
val paramSym = cls.primaryConstructor.paramNamed(getterName)
val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol
if !getter.is(Private) && getter.hasTrackedParts then
refined = refined.refinedOverride(getterName, argType.unboxed) // Yichen you might want to check this
allCaptures ++= argType.captureSet
if paramSym.isUseParam then
allCaptures ++= argType.deepCaptureSet
else
allCaptures ++= argType.captureSet
(refined, allCaptures)

/** Augment result type of constructor with refinements and captures.
Expand Down Expand Up @@ -1616,7 +1609,10 @@ class CheckCaptures extends Recheck, SymTransformer:
if noWiden(actual, expected) then
actual
else
val improvedVAR = improveCaptures(actual.widen.dealiasKeepAnnots, actual)
// Compute the widened type. Drop `@use` and `@consume` annotations from the type,
// since they obscures the capturing type.
val widened = actual.widen.dealiasKeepAnnots.dropUseAndConsumeAnnots
val improvedVAR = improveCaptures(widened, actual)
val improved = improveReadOnly(improvedVAR, expected)
val adapted = adaptBoxed(
improved.withReachCaptures(actual), expected, tree,
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/cc/SepCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.maybeOwner.enclosingMethodOrClass) then
report.error(em"""Separation failure: $descr non-local $refSym""", pos)
else if refSym.is(TermParam)
&& !refSym.hasAnnotation(defn.ConsumeAnnot)
&& !refSym.isConsumeParam
&& currentOwner.isContainedIn(refSym.owner)
then
badParams += refSym
Expand Down Expand Up @@ -899,7 +899,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if !isUnsafeAssumeSeparate(tree) then trace(i"checking separate $tree"):
checkUse(tree)
tree match
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.hasAnnotation(defn.ConsumeAnnot) =>
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.isConsumeParam =>
traverseChildren(tree)
checkConsumedRefs(
captures(qual).footprint(), qual.nuType,
Expand Down Expand Up @@ -962,4 +962,4 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
consumeInLoopError(ref, pos)
case _ =>
traverseChildren(tree)
end SepCheck
end SepCheck
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ class Definitions {

// Set of annotations that are not printed in types except under -Yprint-debug
@tu lazy val SilentAnnots: Set[Symbol] =
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot, SilentIntoAnnot)
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot, SilentIntoAnnot, UseAnnot, ConsumeAnnot)

// A list of annotations that are commonly used to indicate that a field/method argument or return
// type is not null. These annotations are used by the nullification logic in JavaNullInterop to
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4234,6 +4234,11 @@ object Types extends TypeUtils {
paramType = addAnnotation(paramType, defn.InlineParamAnnot, param)
if param.is(Erased) then
paramType = addAnnotation(paramType, defn.ErasedParamAnnot, param)
// Copy `@use` and `@consume` annotations from parameter symbols to the type.
if param.hasAnnotation(defn.UseAnnot) then
paramType = addAnnotation(paramType, defn.UseAnnot, param)
if param.hasAnnotation(defn.ConsumeAnnot) then
paramType = addAnnotation(paramType, defn.ConsumeAnnot, param)
paramType

def adaptParamInfo(param: Symbol)(using Context): Type =
Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class PlainPrinter(_ctx: Context) extends Printer {

protected def argText(arg: Type, isErased: Boolean = false): Text =
keywordText("erased ").provided(isErased)
~ specialAnnotText(defn.UseAnnot, arg)
~ specialAnnotText(defn.ConsumeAnnot, arg)
~ homogenizeArg(arg).match
case arg: TypeBounds => "?" ~ toText(arg)
case arg => toText(arg)
Expand Down Expand Up @@ -376,10 +378,18 @@ class PlainPrinter(_ctx: Context) extends Printer {
try "(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")"
finally elideCapabilityCaps = saved

/** Print the annotation that are meant to be on the parameter symbol but was moved
* to parameter types. Examples are `@use` and `@consume`. */
protected def specialAnnotText(sym: ClassSymbol, tp: Type): Text =
Str(s"@${sym.name} ").provided(tp.hasAnnotation(sym))

protected def paramsText(lam: LambdaType): Text = {
def paramText(ref: ParamRef) =
val erased = ref.underlying.hasAnnotation(defn.ErasedParamAnnot)
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ hashStr(lam) ~ toTextRHS(ref.underlying, isParameter = true)
keywordText("erased ").provided(erased)
~ specialAnnotText(defn.UseAnnot, ref.underlying)
~ specialAnnotText(defn.ConsumeAnnot, ref.underlying)
~ ParamRefNameString(ref) ~ hashStr(lam) ~ toTextRHS(ref.underlying, isParameter = true)
Text(lam.paramRefs.map(paramText), ", ")
}

Expand Down
6 changes: 3 additions & 3 deletions scala2-library-cc/src/scala/collection/Iterable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,9 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable

def map[B](f: A => B): CC[B]^{this, f} = iterableFactory.from(new View.Map(this, f))

def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} = iterableFactory.from(new View.FlatMap(this, f))
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} = iterableFactory.from(new View.FlatMap(this, f))

def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this} = flatMap(asIterable)
def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this, asIterable*} = flatMap(asIterable)

def collect[B](pf: PartialFunction[A, B]^): CC[B]^{this, pf} =
iterableFactory.from(new View.Collect(this, pf))
Expand Down Expand Up @@ -902,7 +902,7 @@ object IterableOps {
def map[B](f: A => B): CC[B]^{this, f} =
self.iterableFactory.from(new View.Map(filtered, f))

def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} =
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} =
self.iterableFactory.from(new View.FlatMap(filtered, f))

def foreach[U](f: A => U): Unit = filtered.foreach(f)
Expand Down
5 changes: 2 additions & 3 deletions scala2-library-cc/src/scala/collection/IterableOnce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,9 @@ final class IterableOnceExtensionMethods[A](private val it: IterableOnce[A]) ext
}

@deprecated("Use .iterator.flatMap instead or consider requiring an Iterable", "2.13.0")
def flatMap[B](f: A => IterableOnce[B]^): IterableOnce[B]^{f} = it match {
def flatMap[B](@caps.use f: A => IterableOnce[B]^): IterableOnce[B]^{f*} = it match
case it: Iterable[A] => it.flatMap(f)
case _ => it.iterator.flatMap(f)
}

@deprecated("Use .iterator.sameElements instead", "2.13.0")
def sameElements[B >: A](that: IterableOnce[B]): Boolean = it.iterator.sameElements(that)
Expand Down Expand Up @@ -439,7 +438,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
* @return a new $coll resulting from applying the given collection-valued function
* `f` to each element of this $coll and concatenating the results.
*/
def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f}
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*}

/** Converts this $coll of iterable collections into
* a $coll formed by the elements of these iterable
Expand Down
10 changes: 5 additions & 5 deletions scala2-library-cc/src/scala/collection/Iterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,8 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
def next() = f(self.next())
}

def flatMap[B](f: A => IterableOnce[B]^): Iterator[B]^{this, f} = new AbstractIterator[B] {
private[this] var cur: Iterator[B]^{f} = Iterator.empty
def flatMap[B](@caps.use f: A => IterableOnce[B]^): Iterator[B]^{this, f*} = new AbstractIterator[B] {
private[this] var cur: Iterator[B]^{f*} = Iterator.empty
/** Trillium logic boolean: -1 = unknown, 0 = false, 1 = true */
private[this] var _hasNext: Int = -1

Expand Down Expand Up @@ -623,7 +623,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
}
}

def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this} =
def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this, ev*} =
flatMap[B](ev)

def concat[B >: A](xs: => IterableOnce[B]^): Iterator[B]^{this, xs} = new Iterator.ConcatIterator[B](self).concat(xs)
Expand Down Expand Up @@ -982,7 +982,7 @@ object Iterator extends IterableFactory[Iterator] {
/** Creates a target $coll from an existing source collection
*
* @param source Source collection
* @tparam A the type of the collections elements
* @tparam A the type of the collection's elements
* @return a new $coll with the elements of `source`
*/
override def from[A](source: IterableOnce[A]^): Iterator[A]^{source} = source.iterator
Expand All @@ -1003,7 +1003,7 @@ object Iterator extends IterableFactory[Iterator] {

/**
* @return A builder for $Coll objects.
* @tparam A the type of the ${coll}s elements
* @tparam A the type of the ${coll}'s elements
*/
def newBuilder[A]: Builder[A, Iterator[A]] =
new ImmutableBuilder[A, Iterator[A]](empty[A]) {
Expand Down
4 changes: 2 additions & 2 deletions scala2-library-cc/src/scala/collection/Map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ trait MapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C]
* @return a new $coll resulting from applying the given collection-valued function
* `f` to each element of this $coll and concatenating the results.
*/
def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = mapFactory.from(new View.FlatMap(this, f))
def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = mapFactory.from(new View.FlatMap(this, f))

/** Returns a new $coll containing the elements from the left hand operand followed by the elements from the
* right hand operand. The element type of the $coll is the most specific superclass encompassing
Expand Down Expand Up @@ -383,7 +383,7 @@ object MapOps {
def map[K2, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2]^{this, f} =
self.mapFactory.from(new View.Map(filtered, f))

def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2]^{this, f} =
def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2]^{this, f*} =
self.mapFactory.from(new View.FlatMap(filtered, f))

override def withFilter(q: ((K, V)) => Boolean): WithFilter[K, V, IterableCC, CC]^{this, q} =
Expand Down
2 changes: 1 addition & 1 deletion scala2-library-cc/src/scala/collection/SortedMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ object SortedMapOps {
def map[K2 : Ordering, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2] =
self.sortedMapFactory.from(new View.Map(filtered, f))

def flatMap[K2 : Ordering, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
def flatMap[K2 : Ordering, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
self.sortedMapFactory.from(new View.FlatMap(filtered, f))

override def withFilter(q: ((K, V)) => Boolean): WithFilter[K, V, IterableCC, MapCC, CC]^{this, q} =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ trait StrictOptimizedIterableOps[+A, +CC[_], +C]
b.result()
}

override def flatMap[B](f: A => IterableOnce[B]^): CC[B] =
override def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B] =
strictOptimizedFlatMap(iterableFactory.newBuilder, f)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait StrictOptimizedMapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C
override def map[K2, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2] =
strictOptimizedMap(mapFactory.newBuilder, f)

override def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
override def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
strictOptimizedFlatMap(mapFactory.newBuilder, f)

override def concat[V2 >: V](suffix: IterableOnce[(K, V2)]^): CC[K, V2] =
Expand Down
4 changes: 2 additions & 2 deletions scala2-library-cc/src/scala/collection/View.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ object View extends IterableFactory[View] {

/** A view that flatmaps elements of the underlying collection. */
@SerialVersionUID(3L)
class FlatMap[A, B](underlying: SomeIterableOps[A]^, f: A => IterableOnce[B]^) extends AbstractView[B] {
def iterator: Iterator[B]^{underlying, f} = underlying.iterator.flatMap(f)
class FlatMap[A, B](underlying: SomeIterableOps[A]^, @caps.use f: A => IterableOnce[B]^) extends AbstractView[B] {
def iterator: Iterator[B]^{underlying, f*} = underlying.iterator.flatMap(f)
override def knownSize: Int = if (underlying.knownSize == 0) 0 else super.knownSize
override def isEmpty: Boolean = iterator.isEmpty
}
Expand Down
Loading
Loading