Skip to content

Shallow capture sets #12875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 28, 2021
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Types._
import Scopes._
import Names.Name
import Denotations.Denotation
import typer.Typer
import typer.{Typer, RefineTypes}
import typer.ImportInfo._
import Decorators._
import io.{AbstractFile, PlainFile, VirtualFile}
Expand Down Expand Up @@ -204,7 +204,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
val profileBefore = profiler.beforePhase(phase)
units = phase.runOn(units)
profiler.afterPhase(phase, profileBefore)
if (ctx.settings.Xprint.value.containsPhase(phase))
if ctx.settings.Xprint.value.containsPhase(phase) && !phase.isInstanceOf[RefineTypes] then
for (unit <- units)
lastPrintedTree =
printTree(lastPrintedTree)(using ctx.fresh.setPhase(phase.next).setCompilationUnit(unit))
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ trait AllScalaSettings extends CommonScalaSettings { self: Settings.SettingGroup
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation")
val YrefineTypes: Setting[Boolean] = BooleanSetting("-Yrefine-types", "Run experimental type refiner (test only)")
val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references")
val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with -Ycc, suppress type abbreviations")

/** Area-specific debug output */
val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.")
Expand Down
99 changes: 40 additions & 59 deletions compiler/src/dotty/tools/dotc/core/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,17 @@ case class CaptureSet private (elems: CaptureSet.Refs) extends Showable:
def isEmpty: Boolean = elems.isEmpty
def nonEmpty: Boolean = !isEmpty

private var myClosure: Refs | Null = null

def closure(using Context): Refs =
if myClosure == null then
var cl = elems
var seen: Refs = SimpleIdentitySet.empty
while
val prev = cl
for ref <- cl do
if !seen.contains(ref) then
seen += ref
cl = cl ++ ref.captureSetOfInfo.elems
prev ne cl
do ()
myClosure = cl
myClosure

def ++ (that: CaptureSet): CaptureSet =
CaptureSet(elems ++ that.elems)
if this.isEmpty then that
else if that.isEmpty then this
else CaptureSet(elems ++ that.elems)

def + (ref: CaptureRef) =
if elems.contains(ref) then this
else CaptureSet(elems + ref)

def intersect (that: CaptureSet): CaptureSet =
CaptureSet(this.elems.intersect(that.elems))

/** {x} <:< this where <:< is subcapturing */
def accountsFor(x: CaptureRef)(using Context) =
Expand All @@ -45,6 +37,15 @@ case class CaptureSet private (elems: CaptureSet.Refs) extends Showable:
def <:< (that: CaptureSet)(using Context): Boolean =
elems.isEmpty || elems.forall(that.accountsFor)

def flatMap(f: CaptureRef => CaptureSet)(using Context): CaptureSet =
(empty /: elems)((cs, ref) => cs ++ f(ref))

def substParams(tl: BindingType, to: List[Type])(using Context) =
flatMap {
case ref: ParamRef if ref.binder eq tl => to(ref.paramNum).captureSet
case ref => ref.singletonCaptureSet
}

override def toString = elems.toString

override def toText(printer: Printer): Text =
Expand Down Expand Up @@ -82,46 +83,26 @@ object CaptureSet:
css.foldLeft(empty)(_ ++ _)

def ofType(tp: Type)(using Context): CaptureSet =
val collect = new TypeAccumulator[Refs]:
var localBinders: SimpleIdentitySet[BindingType] = SimpleIdentitySet.empty
var seenLazyRefs: SimpleIdentitySet[LazyRef] = SimpleIdentitySet.empty
def apply(elems: Refs, tp: Type): Refs = trace(i"capt $elems, $tp", capt, show = true) {
tp match
case tp: NamedType =>
if variance < 0 then elems
else elems ++ tp.captureSet.elems
case tp: ParamRef =>
if variance < 0 || localBinders.contains(tp.binder) then elems
else elems ++ tp.captureSet.elems
case tp: LambdaType =>
localBinders += tp
try apply(elems, tp.resultType)
finally localBinders -= tp
case AndType(tp1, tp2) =>
val elems1 = apply(SimpleIdentitySet.empty, tp1)
val elems2 = apply(SimpleIdentitySet.empty, tp2)
elems ++ elems1.intersect(elems2)
case CapturingType(parent, ref) =>
val elems1 = apply(elems, parent)
if variance >= 0 then elems1 + ref else elems1
case TypeBounds(_, hi) =>
apply(elems, hi)
case tp: ClassInfo =>
elems ++ ofClass(tp, Nil).elems
case tp: LazyRef =>
if seenLazyRefs.contains(tp)
|| tp.evaluating // shapeless gets an assertion error without this test
then elems
else
seenLazyRefs += tp
foldOver(elems, tp)
// case tp: MatchType =>
// val normed = tp.tryNormalize
// if normed.exists then apply(elems, normed) else foldOver(elems, tp)
case _ =>
foldOver(elems, tp)
}

CaptureSet(collect(empty.elems, tp))
def recur(tp: Type): CaptureSet = tp match
case tp: CaptureRef =>
tp.captureSet
case CapturingType(parent, ref) =>
recur(parent) + ref
case AppliedType(tycon, args) =>
val cs = recur(tycon)
tycon.typeParams match
case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args)
case _ => cs
case tp: TypeProxy =>
recur(tp.underlying)
case AndType(tp1, tp2) =>
recur(tp1).intersect(recur(tp2))
case OrType(tp1, tp2) =>
recur(tp1) ++ recur(tp2)
case tp: ClassInfo =>
ofClass(tp, Nil)
case _ =>
empty
recur(tp)
.showing(i"capture set of $tp = $result", capt)

2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,7 @@ class Definitions {
* - the upper bound of a TypeParamRef in the current constraint
*/
def asContextFunctionType(tp: Type)(using Context): Type =
tp.stripTypeVar.dealias match
tp.stripped.dealias match
case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) =>
asContextFunctionType(TypeComparer.bounds(tp1).hiBound)
case tp1 =>
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ class TypeApplications(val self: Type) extends AnyVal {
*/
final def appliedTo(args: List[Type])(using Context): Type = {
record("appliedTo")
val typParams = self.typeParams
val stripped = self.stripTypeVar
val dealiased = stripped.safeDealias
if (args.isEmpty || ctx.erasedTypes) self
Expand Down
27 changes: 24 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3620,7 +3620,7 @@ object Types {
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
case tp: CapturingType =>
val status1 = compute(status, tp.parent, theAcc)
tp.ref match
tp.ref.stripTypeVar match
case tp: TermParamRef if tp.binder eq thisLambdaType => combine(status1, CaptureDeps)
case _ => status1
case _: ThisType | _: BoundType | NoPrefix => status
Expand Down Expand Up @@ -4505,9 +4505,10 @@ object Types {
* @param origin The parameter that's tracked by the type variable.
* @param creatorState The typer state in which the variable was created.
*/
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int) extends CachedProxyType with ValueType {
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int)
extends CachedProxyType, CaptureRef {

private var currentOrigin = initOrigin
private var currentOrigin = initOrigin

def origin: TypeParamRef = currentOrigin

Expand Down Expand Up @@ -4689,6 +4690,26 @@ object Types {
if (inst.exists) inst else origin
}

// Capture ref methods

def canBeTracked(using Context): Boolean = underlying match
case ref: CaptureRef => ref.canBeTracked
case _ => false

override def normalizedRef(using Context): CaptureRef = instanceOpt match
case ref: CaptureRef => ref
case _ => this

override def singletonCaptureSet(using Context) = instanceOpt match
case ref: CaptureRef => ref.singletonCaptureSet
case _ => super.singletonCaptureSet

override def captureSetOfInfo(using Context): CaptureSet = instanceOpt match
case ref: CaptureRef => ref.captureSetOfInfo
case tp => tp.captureSet

// Object members

override def computeHash(bs: Binders): Int = identityHash(bs)
override def equals(that: Any): Boolean = this.eq(that.asInstanceOf[AnyRef])

Expand Down
50 changes: 37 additions & 13 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,29 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
argStr ~ " " ~ arrow(isGiven) ~ " " ~ argText(args.last)
}

def toTextDependentFunction(appType: MethodType): Text =
"("
~ keywordText("erased ").provided(appType.isErasedMethod)
~ paramsText(appType)
~ ") "
~ arrow(appType.isImplicitMethod)
~ " "
~ toText(appType.resultType)
def toTextMethodAsFunction(info: Type): Text = info match
case info: MethodType =>
changePrec(GlobalPrec) {
"("
~ keywordText("erased ").provided(info.isErasedMethod)
~ ( if info.isParamDependent || info.isResultDependent
then paramsText(info)
else argsText(info.paramInfos)
)
~ ") "
~ arrow(info.isImplicitMethod)
~ " "
~ toTextMethodAsFunction(info.resultType)
}
case info: PolyType =>
changePrec(GlobalPrec) {
"["
~ paramsText(info)
~ "] => "
~ toTextMethodAsFunction(info.resultType)
}
case _ =>
toText(info)

def isInfixType(tp: Type): Boolean = tp match
case AppliedType(tycon, args) =>
Expand Down Expand Up @@ -229,8 +244,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
// don't eta contract if the application would be printed specially
toText(tycon)
case tp: RefinedType if defn.isFunctionType(tp) && !printDebug =>
toTextDependentFunction(tp.refinedInfo.asInstanceOf[MethodType])
case tp: RefinedType
if (defn.isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass))
&& !printDebug =>
toTextMethodAsFunction(tp.refinedInfo)
case tp: TypeRef =>
if (tp.symbol.isAnonymousClass && !showUniqueIds)
toText(tp.info)
Expand All @@ -244,6 +261,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
case ErasedValueType(tycon, underlying) =>
"ErasedValueType(" ~ toText(tycon) ~ ", " ~ toText(underlying) ~ ")"
case tp: ClassInfo =>
if tp.cls.derivesFrom(defn.PolyFunctionClass) then
tp.member(nme.apply).info match
case info: PolyType => return toTextMethodAsFunction(info)
case _ =>
toTextParents(tp.parents) ~~ "{...}"
case JavaArrayType(elemtp) =>
toText(elemtp) ~ "[]"
Expand Down Expand Up @@ -506,13 +527,16 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
case RefinedTypeTree(tpt, refines) =>
toTextLocal(tpt) ~ " " ~ blockText(refines)
case AppliedTypeTree(tpt, args) =>
if (tpt.symbol == defn.orType && args.length == 2)
if tpt.symbol == defn.orType && args.length == 2 then
changePrec(OrTypePrec) { toText(args(0)) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(args(1)) } }
else if (tpt.symbol == defn.andType && args.length == 2)
else if tpt.symbol == defn.andType && args.length == 2 then
changePrec(AndTypePrec) { toText(args(0)) ~ " & " ~ atPrec(AndTypePrec + 1) { toText(args(1)) } }
else if tpt.symbol == defn.Predef_retainsType && args.length == 2 then
changePrec(InfixPrec) { toText(args(0)) ~ " retains " ~ toText(args(1)) }
else if defn.isFunctionClass(tpt.symbol)
&& tpt.isInstanceOf[TypeTree] && tree.hasType && !printDebug
then changePrec(GlobalPrec) { toText(tree.typeOpt) }
then
changePrec(GlobalPrec) { toText(tree.typeOpt) }
else args match
case arg :: _ if arg.isTerm =>
toTextLocal(tpt) ~ "(" ~ Text(args.map(argText), ", ") ~ ")"
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,14 @@ class TreeChecker extends Phase with SymTransformer {
val tpe = tree.typeOpt

// Polymorphic apply methods stay structural until Erasure
val isPolyFunctionApply = (tree.name eq nme.apply) && (tree.qualifier.typeOpt <:< defn.PolyFunctionType)
val isPolyFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass)
// Outer selects are pickled specially so don't require a symbol
val isOuterSelect = tree.name.is(OuterSelectName)
val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name)
if !(tree.isType || isPolyFunctionApply || isOuterSelect || isPrimitiveArrayOp) then
val denot = tree.denot
assert(denot.exists, i"Selection $tree with type $tpe does not have a denotation")
assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol")
assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol, ${tree.qualifier.typeOpt}")

val sym = tree.symbol
val symIsFixed = tpe match {
Expand Down
Loading