Skip to content

Commit f96b29e

Browse files
authored
Apply box adaptation when checking overrides (#16479)
In this PR, we allow box adaptation when checking the compatibility of overriding pairs in capture checking phase. For example, it allows the following code to compile: ```scala class IO abstract class A[X] { def foo(x: Unit): X def bar(op: X => Int): Int } class C def test(io: {*} IO) = { class B extends A[{io} C] { // X =:= {io} C def foo(x: Unit): {io} C = ??? def bar(op: ({io} C) => Int): Int = 0 } } ``` The `foo` and `bar` in `B` are both valid overrides, since the types in `B` and `A[{io} C]` is compatible with box adaptation.
2 parents c4d63cc + 9dd4f52 commit f96b29e

21 files changed

+262
-76
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ extension (tp: Type)
166166
case CapturingType(_, _) => true
167167
case _ => false
168168

169+
def isEventuallyCapturingType(using Context): Boolean =
170+
tp match
171+
case EventuallyCapturingType(_, _) => true
172+
case _ => false
173+
169174
/** Is type known to be always pure by its class structure,
170175
* so that adding a capture set to it would not make sense?
171176
*/

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ object CapturingType:
4848
EventuallyCapturingType.unapply(tp)
4949
else None
5050

51+
/** Check whether a type is uncachable when computing `baseType`.
52+
* - Avoid caching all the types during the setup phase, since at that point
53+
* the capture set variables are not fully installed yet.
54+
* - Avoid caching capturing types when IgnoreCaptures mode is set, since the
55+
* capture sets may be thrown away in the computed base type.
56+
*/
57+
def isUncachable(tp: Type)(using Context): Boolean =
58+
ctx.phase == Phases.checkCapturesPhase &&
59+
(Setup.isDuringSetup || ctx.mode.is(Mode.IgnoreCaptures) && tp.isEventuallyCapturingType)
60+
5161
end CapturingType
5262

5363
/** An extractor for types that will be capturing types at phase CheckCaptures. Also

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import config.Printers.{capt, recheckr}
1010
import config.{Config, Feature}
1111
import ast.{tpd, untpd, Trees}
1212
import Trees.*
13-
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents}
13+
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker}
1414
import typer.Checking.{checkBounds, checkAppliedTypesIn}
1515
import util.{SimpleIdentitySet, EqHashMap, SrcPos}
1616
import transform.SymUtils.*
@@ -141,25 +141,12 @@ class CheckCaptures extends Recheck, SymTransformer:
141141

142142
override def run(using Context): Unit =
143143
if Feature.ccEnabled then
144-
checkOverrides.traverse(ctx.compilationUnit.tpdTree)
145144
super.run
146145

147146
override def transformSym(sym: SymDenotation)(using Context): SymDenotation =
148147
if Synthetics.needsTransform(sym) then Synthetics.transformFromCC(sym)
149148
else super.transformSym(sym)
150149

151-
/** Check overrides again, taking capture sets into account.
152-
* TODO: Can we avoid doing overrides checks twice?
153-
* We need to do them here since only at this phase CaptureTypes are relevant
154-
* But maybe we can then elide the check during the RefChecks phase under captureChecking?
155-
*/
156-
def checkOverrides = new TreeTraverser:
157-
def traverse(t: Tree)(using Context) =
158-
t match
159-
case t: Template => checkAllOverrides(ctx.owner.asClass)
160-
case _ =>
161-
traverseChildren(t)
162-
163150
class CaptureChecker(ictx: Context) extends Rechecker(ictx):
164151
import ast.tpd.*
165152

@@ -668,8 +655,11 @@ class CheckCaptures extends Recheck, SymTransformer:
668655
case _ =>
669656
expected
670657

671-
/** Adapt `actual` type to `expected` type by inserting boxing and unboxing conversions */
672-
def adaptBoxed(actual: Type, expected: Type, pos: SrcPos)(using Context): Type =
658+
/** Adapt `actual` type to `expected` type by inserting boxing and unboxing conversions
659+
*
660+
* @param alwaysConst always make capture set variables constant after adaptation
661+
*/
662+
def adaptBoxed(actual: Type, expected: Type, pos: SrcPos, alwaysConst: Boolean = false)(using Context): Type =
673663

674664
/** Adapt function type `actual`, which is `aargs -> ares` (possibly with dependencies)
675665
* to `expected` type.
@@ -746,7 +736,8 @@ class CheckCaptures extends Recheck, SymTransformer:
746736
else
747737
((parent, cs, tp.isBoxed), reconstruct)
748738
case actual =>
749-
((actual, CaptureSet(), false), reconstruct)
739+
val res = if tp.isFromJavaObject then tp else actual
740+
((res, CaptureSet(), false), reconstruct)
750741

751742
def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
752743
if expected.isInstanceOf[WildcardType] then actual
@@ -806,9 +797,9 @@ class CheckCaptures extends Recheck, SymTransformer:
806797
}
807798
if !insertBox then // unboxing
808799
markFree(criticalSet, pos)
809-
recon(CapturingType(parent1, cs1, !actualIsBoxed))
800+
recon(CapturingType(parent1, if alwaysConst then CaptureSet(cs1.elems) else cs1, !actualIsBoxed))
810801
else
811-
recon(CapturingType(parent1, cs1, actualIsBoxed))
802+
recon(CapturingType(parent1, if alwaysConst then CaptureSet(cs1.elems) else cs1, actualIsBoxed))
812803
}
813804

814805
var actualw = actual.widenDealias
@@ -827,12 +818,49 @@ class CheckCaptures extends Recheck, SymTransformer:
827818
else actual
828819
end adaptBoxed
829820

821+
/** Check overrides again, taking capture sets into account.
822+
* TODO: Can we avoid doing overrides checks twice?
823+
* We need to do them here since only at this phase CaptureTypes are relevant
824+
* But maybe we can then elide the check during the RefChecks phase under captureChecking?
825+
*/
826+
def checkOverrides = new TreeTraverser:
827+
class OverridingPairsCheckerCC(clazz: ClassSymbol, self: Type, srcPos: SrcPos)(using Context) extends OverridingPairsChecker(clazz, self) {
828+
/** Check subtype with box adaptation.
829+
* This function is passed to RefChecks to check the compatibility of overriding pairs.
830+
* @param sym symbol of the field definition that is being checked
831+
*/
832+
override def checkSubType(actual: Type, expected: Type)(using Context): Boolean =
833+
val expected1 = alignDependentFunction(addOuterRefs(expected, actual), actual.stripCapturing)
834+
val actual1 =
835+
val saved = curEnv
836+
try
837+
curEnv = Env(clazz, nestedInOwner = true, capturedVars(clazz), isBoxed = false, outer0 = curEnv)
838+
val adapted = adaptBoxed(actual, expected1, srcPos, alwaysConst = true)
839+
actual match
840+
case _: MethodType =>
841+
// We remove the capture set resulted from box adaptation for method types,
842+
// since class methods are always treated as pure, and their captured variables
843+
// are charged to the capture set of the class (which is already done during
844+
// box adaptation).
845+
adapted.stripCapturing
846+
case _ => adapted
847+
finally curEnv = saved
848+
actual1 frozen_<:< expected1
849+
}
850+
851+
def traverse(t: Tree)(using Context) =
852+
t match
853+
case t: Template =>
854+
checkAllOverrides(ctx.owner.asClass, OverridingPairsCheckerCC(_, _, t))
855+
case _ =>
856+
traverseChildren(t)
857+
830858
override def checkUnit(unit: CompilationUnit)(using Context): Unit =
831-
Setup(preRecheckPhase, thisPhase, recheckDef)
832-
.traverse(ctx.compilationUnit.tpdTree)
859+
Setup(preRecheckPhase, thisPhase, recheckDef)(ctx.compilationUnit.tpdTree)
833860
//println(i"SETUP:\n${Recheck.addRecheckedTypes.transform(ctx.compilationUnit.tpdTree)}")
834861
withCaptureSetsExplained {
835862
super.checkUnit(unit)
863+
checkOverrides.traverse(unit.tpdTree)
836864
checkSelfTypes(unit.tpdTree)
837865
postCheck(unit.tpdTree)
838866
if ctx.settings.YccDebug.value then

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import ast.tpd
1111
import transform.Recheck.*
1212
import CaptureSet.IdentityCaptRefMap
1313
import Synthetics.isExcluded
14+
import util.Property
1415

1516
/** A tree traverser that prepares a compilation unit to be capture checked.
1617
* It does the following:
@@ -484,4 +485,14 @@ extends tpd.TreeTraverser:
484485
capt.println(i"update info of ${tree.symbol} from $info to $newInfo")
485486
case _ =>
486487
end traverse
488+
489+
def apply(tree: Tree)(using Context): Unit =
490+
traverse(tree)(using ctx.withProperty(Setup.IsDuringSetupKey, Some(())))
487491
end Setup
492+
493+
object Setup:
494+
val IsDuringSetupKey = new Property.Key[Unit]
495+
496+
def isDuringSetup(using Context): Boolean =
497+
ctx.property(IsDuringSetupKey).isDefined
498+

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import config.Config
2424
import reporting._
2525
import collection.mutable
2626
import transform.TypeUtils._
27-
import cc.{CapturingType, derivedCapturingType}
27+
import cc.{CapturingType, derivedCapturingType, Setup, EventuallyCapturingType, isEventuallyCapturingType}
2828

2929
import scala.annotation.internal.sharable
3030

@@ -2147,7 +2147,7 @@ object SymDenotations {
21472147
Stats.record("basetype cache entries")
21482148
if (!baseTp.exists) Stats.record("basetype cache NoTypes")
21492149
}
2150-
if (!tp.isProvisional)
2150+
if (!tp.isProvisional && !CapturingType.isUncachable(tp))
21512151
btrCache(tp) = baseTp
21522152
else
21532153
btrCache.remove(tp) // Remove any potential sentinel value
@@ -2161,8 +2161,9 @@ object SymDenotations {
21612161
def recur(tp: Type): Type = try {
21622162
tp match {
21632163
case tp: CachedType =>
2164-
val baseTp = btrCache.lookup(tp)
2165-
if (baseTp != null) return ensureAcyclic(baseTp)
2164+
val baseTp: Type | Null = btrCache.lookup(tp)
2165+
if (baseTp != null)
2166+
return ensureAcyclic(baseTp)
21662167
case _ =>
21672168
}
21682169
if (Stats.monitored) {

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,12 +1071,15 @@ object Types {
10711071
* @param relaxedCheck if true type `Null` becomes a subtype of non-primitive value types in TypeComparer.
10721072
* @param matchLoosely if true the types `=> T` and `()T` are seen as overriding each other.
10731073
* @param checkClassInfo if true we check that ClassInfos are within bounds of abstract types
1074+
*
1075+
* @param isSubType a function used for checking subtype relationships.
10741076
*/
1075-
final def overrides(that: Type, relaxedCheck: Boolean, matchLoosely: => Boolean, checkClassInfo: Boolean = true)(using Context): Boolean = {
1077+
final def overrides(that: Type, relaxedCheck: Boolean, matchLoosely: => Boolean, checkClassInfo: Boolean = true,
1078+
isSubType: (Type, Type) => Context ?=> Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean = {
10761079
val overrideCtx = if relaxedCheck then ctx.relaxedOverrideContext else ctx
10771080
inContext(overrideCtx) {
10781081
!checkClassInfo && this.isInstanceOf[ClassInfo]
1079-
|| (this.widenExpr frozen_<:< that.widenExpr)
1082+
|| isSubType(this.widenExpr, that.widenExpr)
10801083
|| matchLoosely && {
10811084
val this1 = this.widenNullaryMethod
10821085
val that1 = that.widenNullaryMethod

compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,13 @@ object OverridingPairs:
200200
/** Let `member` and `other` be members of some common class C with types
201201
* `memberTp` and `otherTp` in C. Are the two symbols considered an overriding
202202
* pair in C? We assume that names already match so we test only the types here.
203-
* @param fallBack A function called if the initial test is false and
204-
* `member` and `other` are term symbols.
203+
* @param fallBack A function called if the initial test is false and
204+
* `member` and `other` are term symbols.
205+
* @param isSubType A function to be used for checking subtype relationships
206+
* between term fields.
205207
*/
206-
def isOverridingPair(member: Symbol, memberTp: Type, other: Symbol, otherTp: Type, fallBack: => Boolean = false)(using Context): Boolean =
208+
def isOverridingPair(member: Symbol, memberTp: Type, other: Symbol, otherTp: Type, fallBack: => Boolean = false,
209+
isSubType: (Type, Type) => Context ?=> Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean =
207210
if member.isType then // intersection of bounds to refined types must be nonempty
208211
memberTp.bounds.hi.hasSameKindAs(otherTp.bounds.hi)
209212
&& (
@@ -222,6 +225,6 @@ object OverridingPairs:
222225
val relaxedOverriding = ctx.explicitNulls && (member.is(JavaDefined) || other.is(JavaDefined))
223226
member.name.is(DefaultGetterName) // default getters are not checked for compatibility
224227
|| memberTp.overrides(otherTp, relaxedOverriding,
225-
member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack)
228+
member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack, isSubType = isSubType)
226229

227230
end OverridingPairs

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,16 @@ object RefChecks {
234234
&& inLinearizationOrder(sym1, sym2, parent)
235235
&& !sym2.is(AbsOverride)
236236

237-
def checkAll(checkOverride: (Symbol, Symbol) => Unit) =
237+
// Checks the subtype relationship tp1 <:< tp2.
238+
// It is passed to the `checkOverride` operation in `checkAll`, to be used for
239+
// compatibility checking.
240+
def checkSubType(tp1: Type, tp2: Type)(using Context): Boolean = tp1 frozen_<:< tp2
241+
242+
private val subtypeChecker: (Type, Type) => Context ?=> Boolean = this.checkSubType
243+
244+
def checkAll(checkOverride: ((Type, Type) => Context ?=> Boolean, Symbol, Symbol) => Unit) =
238245
while hasNext do
239-
checkOverride(overriding, overridden)
246+
checkOverride(subtypeChecker, overriding, overridden)
240247
next()
241248

242249
// The OverridingPairs cursor does assume that concrete overrides abstract
@@ -250,7 +257,7 @@ object RefChecks {
250257
if dcl.is(Deferred) then
251258
for other <- dcl.allOverriddenSymbols do
252259
if !other.is(Deferred) then
253-
checkOverride(dcl, other)
260+
checkOverride(checkSubType, dcl, other)
254261
end checkAll
255262
end OverridingPairsChecker
256263

@@ -287,8 +294,11 @@ object RefChecks {
287294
* TODO check that classes are not overridden
288295
* TODO This still needs to be cleaned up; the current version is a straight port of what was there
289296
* before, but it looks too complicated and method bodies are far too large.
297+
*
298+
* @param makeOverridePairsChecker A function for creating a OverridePairsChecker instance
299+
* from the class symbol and the self type
290300
*/
291-
def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = {
301+
def checkAllOverrides(clazz: ClassSymbol, makeOverridingPairsChecker: ((ClassSymbol, Type) => Context ?=> OverridingPairsChecker) | Null = null)(using Context): Unit = {
292302
val self = clazz.thisType
293303
val upwardsSelf = upwardsThisType(clazz)
294304
var hasErrors = false
@@ -319,10 +329,17 @@ object RefChecks {
319329
def infoStringWithLocation(sym: Symbol) =
320330
err.infoString(sym, self, showLocation = true)
321331

332+
def isInheritedAccessor(mbr: Symbol, other: Symbol): Boolean =
333+
mbr.is(ParamAccessor)
334+
&& {
335+
val next = ParamForwarding.inheritedAccessor(mbr)
336+
next == other || isInheritedAccessor(next, other)
337+
}
338+
322339
/* Check that all conditions for overriding `other` by `member`
323-
* of class `clazz` are met.
324-
*/
325-
def checkOverride(member: Symbol, other: Symbol): Unit =
340+
* of class `clazz` are met.
341+
*/
342+
def checkOverride(checkSubType: (Type, Type) => Context ?=> Boolean, member: Symbol, other: Symbol): Unit =
326343
def memberTp(self: Type) =
327344
if (member.isClass) TypeAlias(member.typeRef.EtaExpand(member.typeParams))
328345
else self.memberInfo(member)
@@ -341,7 +358,8 @@ object RefChecks {
341358
isOverridingPair(member, memberTp, other, otherTp,
342359
fallBack = warnOnMigration(
343360
overrideErrorMsg("no longer has compatible type"),
344-
(if (member.owner == clazz) member else clazz).srcPos, version = `3.0`))
361+
(if (member.owner == clazz) member else clazz).srcPos, version = `3.0`),
362+
isSubType = checkSubType)
345363
catch case ex: MissingType =>
346364
// can happen when called with upwardsSelf as qualifier of memberTp and otherTp,
347365
// because in that case we might access types that are not members of the qualifier.
@@ -353,7 +371,16 @@ object RefChecks {
353371
* Type members are always assumed to match.
354372
*/
355373
def trueMatch: Boolean =
356-
member.isType || memberTp(self).matches(otherTp(self))
374+
member.isType || withMode(Mode.IgnoreCaptures) {
375+
// `matches` does not perform box adaptation so the result here would be
376+
// spurious during capture checking.
377+
//
378+
// Instead of parameterizing `matches` with the function for subtype checking
379+
// with box adaptation, we simply ignore capture annotations here.
380+
// This should be safe since the compatibility under box adaptation is already
381+
// checked.
382+
memberTp(self).matches(otherTp(self))
383+
}
357384

358385
def emitOverrideError(fullmsg: Message) =
359386
if (!(hasErrors && member.is(Synthetic) && member.is(Module))) {
@@ -488,7 +515,7 @@ object RefChecks {
488515
else if (member.is(ModuleVal) && !other.isRealMethod && !other.isOneOf(DeferredOrLazy))
489516
overrideError("may not override a concrete non-lazy value")
490517
else if (member.is(Lazy, butNot = Module) && !other.isRealMethod && !other.is(Lazy) &&
491-
!warnOnMigration(overrideErrorMsg("may not override a non-lazy value"), member.srcPos, version = `3.0`))
518+
!warnOnMigration(overrideErrorMsg("may not override a non-lazy value"), member.srcPos, version = `3.0`))
492519
overrideError("may not override a non-lazy value")
493520
else if (other.is(Lazy) && !other.isRealMethod && !member.is(Lazy))
494521
overrideError("must be declared lazy to override a lazy value")
@@ -521,14 +548,8 @@ object RefChecks {
521548
overrideDeprecation("", member, other, "removed or renamed")
522549
end checkOverride
523550

524-
def isInheritedAccessor(mbr: Symbol, other: Symbol): Boolean =
525-
mbr.is(ParamAccessor)
526-
&& {
527-
val next = ParamForwarding.inheritedAccessor(mbr)
528-
next == other || isInheritedAccessor(next, other)
529-
}
530-
531-
OverridingPairsChecker(clazz, self).checkAll(checkOverride)
551+
val checker = if makeOverridingPairsChecker == null then OverridingPairsChecker(clazz, self) else makeOverridingPairsChecker(clazz, self)
552+
checker.checkAll(checkOverride)
532553
printMixinOverrideErrors()
533554

534555
// Verifying a concrete class has nothing unimplemented.
@@ -572,7 +593,7 @@ object RefChecks {
572593
clazz.nonPrivateMembersNamed(mbr.name)
573594
.filterWithPredicate(
574595
impl => isConcrete(impl.symbol)
575-
&& mbrDenot.matchesLoosely(impl, alwaysCompareTypes = true))
596+
&& withMode(Mode.IgnoreCaptures)(mbrDenot.matchesLoosely(impl, alwaysCompareTypes = true)))
576597
.exists
577598

578599
/** The term symbols in this class and its baseclasses that are
@@ -719,7 +740,7 @@ object RefChecks {
719740
def checkNoAbstractDecls(bc: Symbol): Unit = {
720741
for (decl <- bc.info.decls)
721742
if (decl.is(Deferred)) {
722-
val impl = decl.matchingMember(clazz.thisType)
743+
val impl = withMode(Mode.IgnoreCaptures)(decl.matchingMember(clazz.thisType))
723744
if (impl == NoSymbol || decl.owner.isSubClass(impl.owner))
724745
&& !ignoreDeferred(decl)
725746
then

0 commit comments

Comments
 (0)