Skip to content

Commit 74bc575

Browse files
committed
Allow class parents to be refined types.
Refinements of a class parent are added as synthetic members to the inheriting class.
1 parent 55c2002 commit 74bc575

12 files changed

+181
-80
lines changed

compiler/src/dotty/tools/dotc/core/NamerOps.scala

+21
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package core
55
import Contexts.*, Symbols.*, Types.*, Flags.*, Scopes.*, Decorators.*, Names.*, NameOps.*
66
import SymDenotations.{LazyType, SymDenotation}, StdNames.nme
77
import TypeApplications.EtaExpansion
8+
import collection.mutable
89

910
/** Operations that are shared between Namer and TreeUnpickler */
1011
object NamerOps:
@@ -18,6 +19,26 @@ object NamerOps:
1819
case TypeSymbols(tparams) :: _ => ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef))
1920
case _ => ctor.owner.typeRef
2021

22+
/** Split dependent class refinements off parent type. Add them to `refinements`,
23+
* unless it is null.
24+
*/
25+
extension (tp: Type)
26+
def separateRefinements(cls: ClassSymbol, refinements: mutable.LinkedHashMap[Name, Type] | Null)(using Context): Type =
27+
tp match
28+
case RefinedType(tp1, rname, rinfo) =>
29+
try tp1.separateRefinements(cls, refinements)
30+
finally
31+
if refinements != null then
32+
refinements(rname) = refinements.get(rname) match
33+
case Some(tp) => tp & rinfo
34+
case None => rinfo
35+
case tp @ AnnotatedType(tp1, ann) =>
36+
tp.derivedAnnotatedType(tp1.separateRefinements(cls, refinements), ann)
37+
case tp: RecType =>
38+
tp.parent.substRecThis(tp, cls.thisType).separateRefinements(cls, refinements)
39+
case tp =>
40+
tp
41+
2142
/** If isConstructor, make sure it has at least one non-implicit parameter list
2243
* This is done by adding a () in front of a leading old style implicit parameter,
2344
* or by adding a () as last -- or only -- parameter list if the constructor has

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ class TreeUnpickler(reader: TastyReader,
10431043
}
10441044
val parentReader = fork
10451045
val parents = readParents(withArgs = false)(using parentCtx)
1046-
val parentTypes = parents.map(_.tpe.dealias)
1046+
val parentTypes = parents.map(_.tpe.dealiasKeepAnnots.separateRefinements(cls, null))
10471047
val self =
10481048
if (nextByte == SELFDEF) {
10491049
readByte()

compiler/src/dotty/tools/dotc/transform/init/Util.scala

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ object Util:
2020

2121
def typeRefOf(tp: Type)(using Context): TypeRef = tp.dealias.typeConstructor match
2222
case tref: TypeRef => tref
23+
case RefinedType(parent, _, _) => typeRefOf(parent)
2324
case hklambda: HKTypeLambda => typeRefOf(hklambda.resType)
2425

2526

compiler/src/dotty/tools/dotc/typer/Namer.scala

+25-6
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ class Namer { typer: Typer =>
5454

5555
import untpd.*
5656

57-
val TypedAhead : Property.Key[tpd.Tree] = new Property.Key
58-
val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key
59-
val ExportForwarders: Property.Key[List[tpd.MemberDef]] = new Property.Key
60-
val SymOfTree : Property.Key[Symbol] = new Property.Key
61-
val AttachedDeriver : Property.Key[Deriver] = new Property.Key
57+
val TypedAhead : Property.Key[tpd.Tree] = new Property.Key
58+
val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key
59+
val ExportForwarders : Property.Key[List[tpd.MemberDef]] = new Property.Key
60+
val ParentRefinements: Property.Key[List[Symbol]] = new Property.Key
61+
val SymOfTree : Property.Key[Symbol] = new Property.Key
62+
val AttachedDeriver : Property.Key[Deriver] = new Property.Key
6263
// was `val Deriver`, but that gave shadowing problems with constructor proxies
6364

6465
/** A partial map from unexpanded member and pattern defs and to their expansions.
@@ -1485,6 +1486,7 @@ class Namer { typer: Typer =>
14851486
/** The type signature of a ClassDef with given symbol */
14861487
override def completeInCreationContext(denot: SymDenotation): Unit = {
14871488
val parents = impl.parents
1489+
val parentRefinements = new mutable.LinkedHashMap[Name, Type]
14881490

14891491
/* The type of a parent constructor. Types constructor arguments
14901492
* only if parent type contains uninstantiated type parameters.
@@ -1536,7 +1538,8 @@ class Namer { typer: Typer =>
15361538
val ptype = parentType(parent)(using completerCtx.superCallContext).dealiasKeepAnnots
15371539
if (cls.isRefinementClass) ptype
15381540
else {
1539-
val pt = checkClassType(ptype, parent.srcPos,
1541+
val pt = checkClassType(
1542+
ptype.separateRefinements(cls, parentRefinements), parent.srcPos,
15401543
traitReq = parent ne parents.head, stablePrefixReq = true)
15411544
if (pt.derivesFrom(cls)) {
15421545
val addendum = parent match {
@@ -1564,6 +1567,21 @@ class Namer { typer: Typer =>
15641567
}
15651568
}
15661569

1570+
/** Enter all parent refinements as public class members, unless a definition
1571+
* with the same name already exists in the class.
1572+
*/
1573+
def enterParentRefinementSyms(refinements: List[(Name, Type)]) =
1574+
val refinedSyms = mutable.ListBuffer[Symbol]()
1575+
for (name, tp) <- refinements do
1576+
if decls.lookupEntry(name) == null then
1577+
val flags = tp match
1578+
case tp: MethodOrPoly => Method | Synthetic | Deferred
1579+
case _ => Synthetic | Deferred
1580+
refinedSyms += newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
1581+
if refinedSyms.nonEmpty then
1582+
typr.println(i"parent refinement symbols: ${refinedSyms.toList}")
1583+
original.pushAttachment(ParentRefinements, refinedSyms.toList)
1584+
15671585
/** If `parents` contains references to traits that have supertraits with implicit parameters
15681586
* add those supertraits in linearization order unless they are already covered by other
15691587
* parent types. For instance, in
@@ -1632,6 +1650,7 @@ class Namer { typer: Typer =>
16321650
cls.invalidateMemberCaches() // we might have checked for a member when parents were not known yet.
16331651
cls.setNoInitsFlags(parentsKind(parents), untpd.bodyKind(rest))
16341652
cls.setStableConstructor()
1653+
enterParentRefinementSyms(parentRefinements.toList)
16351654
processExports(using localCtx)
16361655
defn.patchStdLibClass(cls)
16371656
addConstructorProxies(cls)

compiler/src/dotty/tools/dotc/typer/Typer.scala

+18-2
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
911911
if (templ1.parents.isEmpty &&
912912
isFullyDefined(pt, ForceDegree.flipBottom) &&
913913
isSkolemFree(pt) &&
914-
isEligible(pt.underlyingClassRef(refinementOK = false)))
914+
isEligible(pt.underlyingClassRef(refinementOK = true)))
915915
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
916916
for case parent: RefTree <- templ1.parents do
917917
typedAhead(parent, tree => inferTypeParams(typedType(tree), pt))
@@ -2759,6 +2759,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
27592759
}
27602760
}
27612761

2762+
/** Add all parent refinement symbols as declarations to this class */
2763+
def addParentRefinements(body: List[Tree])(using Context): List[Tree] =
2764+
cdef.getAttachment(ParentRefinements) match
2765+
case Some(refinedSyms) =>
2766+
val refinements = refinedSyms.map: sym =>
2767+
( if sym.isType then TypeDef(sym.asType)
2768+
else if sym.is(Method) then DefDef(sym.asTerm)
2769+
else ValDef(sym.asTerm)
2770+
).withSpan(impl.span.startPos)
2771+
body ++ refinements
2772+
case None =>
2773+
body
2774+
27622775
ensureCorrectSuperClass()
27632776
completeAnnotations(cdef, cls)
27642777
val constr1 = typed(constr).asInstanceOf[DefDef]
@@ -2779,7 +2792,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
27792792
cdef.withType(UnspecifiedErrorType)
27802793
else {
27812794
val dummy = localDummy(cls, impl)
2782-
val body1 = addAccessorDefs(cls, typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
2795+
val body1 =
2796+
addParentRefinements(
2797+
addAccessorDefs(cls,
2798+
typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1))
27832799

27842800
checkNoDoubleDeclaration(cls)
27852801
val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1)

tests/neg/i0248-inherit-refined.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
object test {
22
class A { type T }
33
type X = A { type T = Int }
4-
class B extends X // error
4+
class B extends X // was error, now OK
55
type Y = A & B
66
class C extends Y // error
77
type Z = A | B
88
class D extends Z // error
9-
abstract class E extends ({ val x: Int }) // error
9+
abstract class E extends ({ val x: Int }) // was error, now OK
1010
}
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-- [E164] Declaration Error: tests/neg/parent-refinement-access.scala:4:6 ----------------------------------------------
2+
4 |trait Year2(private[Year2] val value: Int) extends (Gen { val x: Int }) // error
3+
| ^
4+
| error overriding value x in trait Year2 of type Int;
5+
| value x in trait Gen of type Any has weaker access privileges; it should be public
6+
| (Note that value x in trait Year2 of type Int is abstract,
7+
| and is therefore overridden by concrete value x in trait Gen of type Any)
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
trait Gen:
2+
private[Gen] val x: Any = ()
3+
4+
trait Year2(private[Year2] val value: Int) extends (Gen { val x: Int }) // error

tests/neg/parent-refinement.check

+25-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,25 @@
1-
-- Error: tests/neg/parent-refinement.scala:5:2 ------------------------------------------------------------------------
2-
5 | with Ordered[Year] { // error
3-
| ^^^^
4-
| end of toplevel definition expected but 'with' found
1+
-- Error: tests/neg/parent-refinement.scala:10:6 -----------------------------------------------------------------------
2+
10 |class Bar extends IdOf[Int], (X { type Value = String }) // error
3+
| ^^^
4+
|class Bar cannot be instantiated since it has a member Value with possibly conflicting bounds Int | String <: ... <: Int & String
5+
-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:14:17 -------------------------------------------------
6+
14 | val x: Value = 0 // error
7+
| ^
8+
| Found: (0 : Int)
9+
| Required: Baz.this.Value
10+
|
11+
| longer explanation available when compiling with `-explain`
12+
-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:20:6 --------------------------------------------------
13+
20 | foo(2) // error
14+
| ^
15+
| Found: (2 : Int)
16+
| Required: Boolean
17+
|
18+
| longer explanation available when compiling with `-explain`
19+
-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:16:22 -------------------------------------------------
20+
16 |val x: IdOf[Int] = Baz() // error
21+
| ^^^^^
22+
| Found: Baz
23+
| Required: IdOf[Int]
24+
|
25+
| longer explanation available when compiling with `-explain`

tests/neg/parent-refinement.scala

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11

22
trait Id { type Value }
3+
trait X { type Value }
4+
type IdOf[T] = Id { type Value = T }
5+
36
case class Year(value: Int) extends AnyVal
4-
with Id { type Value = Int }
5-
with Ordered[Year] { // error
7+
with (Id { type Value = Int })
8+
with Ordered[Year]
9+
10+
class Bar extends IdOf[Int], (X { type Value = String }) // error
11+
12+
class Baz extends IdOf[Int]:
13+
type Value = String
14+
val x: Value = 0 // error
15+
16+
val x: IdOf[Int] = Baz() // error
617

7-
}
18+
object Clash extends ({ def foo(x: Int): Int }):
19+
def foo(x: Boolean): Int = 1
20+
foo(2) // error

tests/pos/parent-refinement.scala

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
class A
2+
class B extends A
3+
class C extends B
4+
5+
trait Id { type Value }
6+
type IdOf[T] = Id { type Value = T }
7+
trait X { type Value }
8+
9+
case class Year(value: Int) extends IdOf[Int]:
10+
val x: Value = 2
11+
12+
type Between[Lo, Hi] = X { type Value >: Lo <: Hi }
13+
14+
class Foo() extends IdOf[B], Between[C, A]:
15+
val x: Value = B()
16+
17+
trait Bar extends IdOf[Int], (X { type Value = String })
18+
19+
class Baz extends IdOf[Int]:
20+
type Value = String
21+
val x: Value = ""
22+
23+
trait Gen:
24+
type T
25+
val x: T
26+
27+
type IntInst = Gen:
28+
type T = Int
29+
val x: 0
30+
31+
trait IntInstTrait extends IntInst
32+
33+
abstract class IntInstClass extends IntInstTrait, IntInst
34+
35+
object obj1 extends IntInstTrait:
36+
val x = 0
37+
38+
object obj2 extends IntInstClass:
39+
val x = 0
40+
41+
def main =
42+
val x: obj1.T = 2 - obj2.x
43+
val y: obj2.T = 2 - obj1.x
44+
45+
46+

0 commit comments

Comments
 (0)