Skip to content

Commit 97008df

Browse files
committed
refactor box adaptation for override checking
- make OverridingPairsChecker extensible, with the method for comparing subtype being abstract, and extend it with the box-adapt-enabled subtyping in CC - delay override checking after checking captures - delegate captured references of methods to the class
1 parent 31b3c11 commit 97008df

File tree

7 files changed

+111
-52
lines changed

7 files changed

+111
-52
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import config.Printers.{capt, recheckr}
1010
import config.{Config, Feature}
1111
import ast.{tpd, untpd, Trees}
1212
import Trees.*
13-
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents}
13+
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker}
1414
import typer.Checking.{checkBounds, checkAppliedTypesIn}
1515
import util.{SimpleIdentitySet, EqHashMap, SrcPos}
1616
import transform.SymUtils.*
@@ -824,37 +824,44 @@ class CheckCaptures extends Recheck, SymTransformer:
824824
* But maybe we can then elide the check during the RefChecks phase under captureChecking?
825825
*/
826826
def checkOverrides = new TreeTraverser:
827-
/** Check subtype with box adaptation.
828-
* This function is passed to RefChecks to check the compatibility of overriding pairs.
829-
* @param sym symbol of the field definition that is being checked
830-
*/
831-
def checkSubtype(sym: Symbol)(tree: Tree, actual: Type, expected: Type)(using Context): Boolean =
832-
val expected1 = alignDependentFunction(addOuterRefs(expected, actual), actual.stripCapturing)
833-
val actual1 =
834-
val saved = curEnv
835-
try
836-
curEnv = Env(sym, nestedInOwner = false, CaptureSet.Var(), isBoxed = false, outer0 = curEnv)
837-
adaptBoxed(actual, expected1, tree.srcPos, alwaysConst = true) match
838-
case tp @ CapturingType(parent, refs) =>
839-
CapturingType(parent, CaptureSet(refs.elems))
840-
case tp => tp
841-
finally curEnv = saved
842-
actual1 frozen_<:< expected1
827+
class OverridingPairsCheckerCC(clazz: ClassSymbol, self: Type, srcPos: SrcPos)(using Context) extends OverridingPairsChecker(clazz, self) {
828+
/** Check subtype with box adaptation.
829+
* This function is passed to RefChecks to check the compatibility of overriding pairs.
830+
* @param sym symbol of the field definition that is being checked
831+
*/
832+
override def checkSubType(actual: Type, expected: Type)(using Context): Boolean =
833+
val expected1 = alignDependentFunction(addOuterRefs(expected, actual), actual.stripCapturing)
834+
val actual1 =
835+
val saved = curEnv
836+
try
837+
curEnv = Env(clazz, nestedInOwner = true, capturedVars(clazz), isBoxed = false, outer0 = curEnv)
838+
val adapted = adaptBoxed(actual, expected1, srcPos, alwaysConst = true)
839+
actual match
840+
case _: MethodType =>
841+
// We remove the capture set resulted from box adaptation for method types,
842+
// since class methods are always treated as pure, and their captured variables
843+
// are charged to the capture set of the class (which is already done during
844+
// box adaptation).
845+
adapted.stripCapturing
846+
case _ => adapted
847+
finally curEnv = saved
848+
actual1 frozen_<:< expected1
849+
}
843850

844851
def traverse(t: Tree)(using Context) =
845852
t match
846853
case t: Template =>
847-
checkAllOverrides(ctx.owner.asClass, isSubType = sym => (tp1, tp2) => checkSubtype(sym)(t, tp1, tp2))
854+
checkAllOverrides(ctx.owner.asClass, OverridingPairsCheckerCC(_, _, t))
848855
case _ =>
849856
traverseChildren(t)
850857

851858
override def checkUnit(unit: CompilationUnit)(using Context): Unit =
852-
checkOverrides.traverse(unit.tpdTree)
853859
Setup(preRecheckPhase, thisPhase, recheckDef)
854860
.traverse(ctx.compilationUnit.tpdTree)
855861
//println(i"SETUP:\n${Recheck.addRecheckedTypes.transform(ctx.compilationUnit.tpdTree)}")
856862
withCaptureSetsExplained {
857863
super.checkUnit(unit)
864+
checkOverrides.traverse(unit.tpdTree)
858865
checkSelfTypes(unit.tpdTree)
859866
postCheck(unit.tpdTree)
860867
if ctx.settings.YccDebug.value then

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ object Types {
10751075
* @param isSubType a function used for checking subtype relationships.
10761076
*/
10771077
final def overrides(that: Type, relaxedCheck: Boolean, matchLoosely: => Boolean, checkClassInfo: Boolean = true,
1078-
isSubType: Context ?=> (Type, Type) => Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean = {
1078+
isSubType: (Type, Type) => Context ?=> Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean = {
10791079
val overrideCtx = if relaxedCheck then ctx.relaxedOverrideContext else ctx
10801080
inContext(overrideCtx) {
10811081
!checkClassInfo && this.isInstanceOf[ClassInfo]

compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ object OverridingPairs:
206206
* between term fields.
207207
*/
208208
def isOverridingPair(member: Symbol, memberTp: Type, other: Symbol, otherTp: Type, fallBack: => Boolean = false,
209-
isSubType: Context ?=> (Type, Type) => Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean =
209+
isSubType: (Type, Type) => Context ?=> Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean =
210210
if member.isType then // intersection of bounds to refined types must be nonempty
211211
memberTp.bounds.hi.hasSameKindAs(otherTp.bounds.hi)
212212
&& (

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,14 @@ object RefChecks {
237237
&& inLinearizationOrder(sym1, sym2, parent)
238238
&& !sym2.is(AbsOverride)
239239

240-
def checkAll(checkOverride: (Symbol, Symbol) => Unit) =
240+
// Checks the subtype relationship tp1 <:< tp2.
241+
// It is passed to the `checkOverride` operation in `checkAll`, to be used for
242+
// compatibility checking.
243+
def checkSubType(tp1: Type, tp2: Type)(using Context): Boolean = tp1 frozen_<:< tp2
244+
245+
def checkAll(checkOverride: ((Type, Type) => Context ?=> Boolean, Symbol, Symbol) => Unit) =
241246
while hasNext do
242-
checkOverride(overriding, overridden)
247+
checkOverride(checkSubType, overriding, overridden)
243248
next()
244249

245250
// The OverridingPairs cursor does assume that concrete overrides abstract
@@ -253,7 +258,7 @@ object RefChecks {
253258
if dcl.is(Deferred) then
254259
for other <- dcl.allOverriddenSymbols do
255260
if !other.is(Deferred) then
256-
checkOverride(dcl, other)
261+
checkOverride(checkSubType, dcl, other)
257262
end checkAll
258263
end OverridingPairsChecker
259264

@@ -291,12 +296,10 @@ object RefChecks {
291296
* TODO This still needs to be cleaned up; the current version is a straight port of what was there
292297
* before, but it looks too complicated and method bodies are far too large.
293298
*
294-
* @param isSubType A function used for checking the subtype relationship between
295-
* two types `tp1` and `tp2` when checking the compatibility
296-
* between overriding pairs, with possible adaptations applied
297-
* (e.g. box adaptation in capture checking).
299+
* @param makeOverridePairsChecker A function for creating a OverridePairsChecker instance
300+
* from the class symbol and the self type
298301
*/
299-
def checkAllOverrides(clazz: ClassSymbol, isSubType: Context ?=> Symbol => (Type, Type) => Boolean = _ => (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Unit = {
302+
def checkAllOverrides(clazz: ClassSymbol, makeOverridingPairsChecker: ((ClassSymbol, Type) => Context ?=> OverridingPairsChecker) | Null = null)(using Context): Unit = {
300303
val self = clazz.thisType
301304
val upwardsSelf = upwardsThisType(clazz)
302305
var hasErrors = false
@@ -327,10 +330,17 @@ object RefChecks {
327330
def infoStringWithLocation(sym: Symbol) =
328331
err.infoString(sym, self, showLocation = true)
329332

333+
def isInheritedAccessor(mbr: Symbol, other: Symbol): Boolean =
334+
mbr.is(ParamAccessor)
335+
&& {
336+
val next = ParamForwarding.inheritedAccessor(mbr)
337+
next == other || isInheritedAccessor(next, other)
338+
}
339+
330340
/* Check that all conditions for overriding `other` by `member`
331-
* of class `clazz` are met.
332-
*/
333-
def checkOverride(member: Symbol, other: Symbol): Unit =
341+
* of class `clazz` are met.
342+
*/
343+
def checkOverride(checkSubType: (Type, Type) => Context ?=> Boolean, member: Symbol, other: Symbol): Unit =
334344
def memberTp(self: Type) =
335345
if (member.isClass) TypeAlias(member.typeRef.EtaExpand(member.typeParams))
336346
else self.memberInfo(member)
@@ -350,7 +360,7 @@ object RefChecks {
350360
fallBack = warnOnMigration(
351361
overrideErrorMsg("no longer has compatible type"),
352362
(if (member.owner == clazz) member else clazz).srcPos, version = `3.0`),
353-
isSubType = isSubType(member))
363+
isSubType = checkSubType)
354364
catch case ex: MissingType =>
355365
// can happen when called with upwardsSelf as qualifier of memberTp and otherTp,
356366
// because in that case we might access types that are not members of the qualifier.
@@ -506,7 +516,7 @@ object RefChecks {
506516
else if (member.is(ModuleVal) && !other.isRealMethod && !other.isOneOf(DeferredOrLazy))
507517
overrideError("may not override a concrete non-lazy value")
508518
else if (member.is(Lazy, butNot = Module) && !other.isRealMethod && !other.is(Lazy) &&
509-
!warnOnMigration(overrideErrorMsg("may not override a non-lazy value"), member.srcPos, version = `3.0`))
519+
!warnOnMigration(overrideErrorMsg("may not override a non-lazy value"), member.srcPos, version = `3.0`))
510520
overrideError("may not override a non-lazy value")
511521
else if (other.is(Lazy) && !other.isRealMethod && !member.is(Lazy))
512522
overrideError("must be declared lazy to override a lazy value")
@@ -539,14 +549,8 @@ object RefChecks {
539549
overrideDeprecation("", member, other, "removed or renamed")
540550
end checkOverride
541551

542-
def isInheritedAccessor(mbr: Symbol, other: Symbol): Boolean =
543-
mbr.is(ParamAccessor)
544-
&& {
545-
val next = ParamForwarding.inheritedAccessor(mbr)
546-
next == other || isInheritedAccessor(next, other)
547-
}
548-
549-
OverridingPairsChecker(clazz, self).checkAll(checkOverride)
552+
val checker = if makeOverridingPairsChecker == null then OverridingPairsChecker(clazz, self) else makeOverridingPairsChecker(clazz, self)
553+
checker.checkAll(checkOverride)
550554
printMixinOverrideErrors()
551555

552556
// Verifying a concrete class has nothing unimplemented.

tests/neg-custom-args/captures/lazylist.check

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
-- [E164] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ----------------------------------------
2-
22 | def tail: {*} LazyList[Nothing] = ??? // error overriding
3-
| ^
4-
| error overriding method tail in class LazyList of type -> lazylists.LazyList[Nothing];
5-
| method tail of type -> {*} lazylists.LazyList[Nothing] has incompatible type
6-
|
7-
| longer explanation available when compiling with `-explain`
81
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:17:15 -------------------------------------
92
17 | def tail = xs() // error
103
| ^^^^
@@ -40,3 +33,10 @@
4033
| Required: {cap1, ref3, cap3} lazylists.LazyList[Int]
4134
|
4235
| longer explanation available when compiling with `-explain`
36+
-- [E164] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ----------------------------------------
37+
22 | def tail: {*} LazyList[Nothing] = ??? // error overriding
38+
| ^
39+
| error overriding method tail in class LazyList of type -> lazylists.LazyList[Nothing];
40+
| method tail of type -> {*} lazylists.LazyList[Nothing] has incompatible type
41+
|
42+
| longer explanation available when compiling with `-explain`
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import language.experimental.captureChecking
2+
3+
class IO
4+
class C
5+
6+
object Test1 {
7+
abstract class A[X] { this: {} A[X] =>
8+
def foo(x: X): X
9+
}
10+
11+
def test(io: {*} IO) = {
12+
class B extends A[{io} C] { // X =:= {io} C // error
13+
override def foo(x: {io} C): {io} C = ???
14+
}
15+
}
16+
}
17+
18+
def Test2(io: {*} IO, fs: {io} IO, ct: {*} IO) = {
19+
abstract class A[X] { this: {io} A[X] =>
20+
def foo(x: X): X
21+
}
22+
23+
class B1 extends A[{io} C] {
24+
override def foo(x: {io} C): {io} C = ???
25+
}
26+
27+
class B2 extends A[{ct} C] { // error
28+
override def foo(x: {ct} C): {ct} C = ???
29+
}
30+
31+
class B3 extends A[{fs} C] {
32+
override def foo(x: {fs} C): {fs} C = ???
33+
}
34+
}
35+
36+
def Test3(io: {*} IO, ct: {*} IO) = {
37+
abstract class A[X] { this: {*} A[X] =>
38+
def foo(x: X): X
39+
}
40+
41+
class B1 extends A[{io} C] {
42+
override def foo(x: {io} C): {io} C = ???
43+
}
44+
45+
class B2 extends A[{io, ct} C] {
46+
override def foo(x: {io, ct} C): {io, ct} C = ???
47+
}
48+
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import language.experimental.captureChecking
22

3-
abstract class A[X] {
3+
abstract class A[X] { this: ({} A[X]) =>
44
def foo(x: X): X
55
}
66

77
class IO
88
class C
99

1010
def test(io: {*} IO) = {
11-
class B extends A[{io} C] { // X =:= {io} C
12-
override def foo(x: {io} C): {io} C = ??? // error
11+
class B extends A[{io} C] { // X =:= {io} C // error
12+
override def foo(x: {io} C): {io} C = ???
1313
}
1414
}

0 commit comments

Comments
 (0)