Skip to content

Allow class parents to be refined types. #19256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions compiler/src/dotty/tools/dotc/core/NamerOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package core
import Contexts.*, Symbols.*, Types.*, Flags.*, Scopes.*, Decorators.*, Names.*, NameOps.*
import SymDenotations.{LazyType, SymDenotation}, StdNames.nme
import TypeApplications.EtaExpansion
import collection.mutable

/** Operations that are shared between Namer and TreeUnpickler */
object NamerOps:
Expand All @@ -18,6 +19,26 @@ object NamerOps:
case TypeSymbols(tparams) :: _ => ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef))
case _ => ctor.owner.typeRef

/** Split dependent class refinements off parent type. Add them to `refinements`,
* unless it is null.
*/
extension (tp: Type)
def separateRefinements(cls: ClassSymbol, refinements: mutable.LinkedHashMap[Name, Type] | Null)(using Context): Type =
tp match
case RefinedType(tp1, rname, rinfo) =>
try tp1.separateRefinements(cls, refinements)
finally
if refinements != null then
refinements(rname) = refinements.get(rname) match
case Some(tp) => tp & rinfo
case None => rinfo
case tp @ AnnotatedType(tp1, ann) =>
tp.derivedAnnotatedType(tp1.separateRefinements(cls, refinements), ann)
case tp: RecType =>
tp.parent.substRecThis(tp, cls.thisType).separateRefinements(cls, refinements)
case tp =>
tp

/** If isConstructor, make sure it has at least one non-implicit parameter list
* This is done by adding a () in front of a leading old style implicit parameter,
* or by adding a () as last -- or only -- parameter list if the constructor has
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ class TreeUnpickler(reader: TastyReader,
}
val parentReader = fork
val parents = readParents(withArgs = false)(using parentCtx)
val parentTypes = parents.map(_.tpe.dealias)
val parentTypes = parents.map(_.tpe.dealiasKeepAnnots.separateRefinements(cls, null))
if cls.is(JavaDefined) && parentTypes.exists(_.derivesFrom(defn.JavaAnnotationClass)) then
cls.setFlag(JavaAnnotation)
val self =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/transform/init/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object Util:

def typeRefOf(tp: Type)(using Context): TypeRef = tp.dealias.typeConstructor match
case tref: TypeRef => tref
case RefinedType(parent, _, _) => typeRefOf(parent)
case hklambda: HKTypeLambda => typeRefOf(hklambda.resType)


Expand Down
31 changes: 25 additions & 6 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ class Namer { typer: Typer =>

import untpd.*

val TypedAhead : Property.Key[tpd.Tree] = new Property.Key
val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key
val ExportForwarders: Property.Key[List[tpd.MemberDef]] = new Property.Key
val SymOfTree : Property.Key[Symbol] = new Property.Key
val AttachedDeriver : Property.Key[Deriver] = new Property.Key
val TypedAhead : Property.Key[tpd.Tree] = new Property.Key
val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key
val ExportForwarders : Property.Key[List[tpd.MemberDef]] = new Property.Key
val ParentRefinements: Property.Key[List[Symbol]] = new Property.Key
val SymOfTree : Property.Key[Symbol] = new Property.Key
val AttachedDeriver : Property.Key[Deriver] = new Property.Key
// was `val Deriver`, but that gave shadowing problems with constructor proxies

/** A partial map from unexpanded member and pattern defs and to their expansions.
Expand Down Expand Up @@ -1485,6 +1486,7 @@ class Namer { typer: Typer =>
/** The type signature of a ClassDef with given symbol */
override def completeInCreationContext(denot: SymDenotation): Unit = {
val parents = impl.parents
val parentRefinements = new mutable.LinkedHashMap[Name, Type]

/* The type of a parent constructor. Types constructor arguments
* only if parent type contains uninstantiated type parameters.
Expand Down Expand Up @@ -1536,7 +1538,8 @@ class Namer { typer: Typer =>
val ptype = parentType(parent)(using completerCtx.superCallContext).dealiasKeepAnnots
if (cls.isRefinementClass) ptype
else {
val pt = checkClassType(ptype, parent.srcPos,
val pt = checkClassType(
ptype.separateRefinements(cls, parentRefinements), parent.srcPos,
traitReq = parent ne parents.head, stablePrefixReq = true)
if (pt.derivesFrom(cls)) {
val addendum = parent match {
Expand Down Expand Up @@ -1564,6 +1567,21 @@ class Namer { typer: Typer =>
}
}

/** Enter all parent refinements as public class members, unless a definition
* with the same name already exists in the class.
*/
def enterParentRefinementSyms(refinements: List[(Name, Type)]) =
val refinedSyms = mutable.ListBuffer[Symbol]()
for (name, tp) <- refinements do
if decls.lookupEntry(name) == null then
val flags = tp match
case tp: MethodOrPoly => Method | Synthetic | Deferred
case _ => Synthetic | Deferred
refinedSyms += newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
if refinedSyms.nonEmpty then
typr.println(i"parent refinement symbols: ${refinedSyms.toList}")
original.pushAttachment(ParentRefinements, refinedSyms.toList)

/** If `parents` contains references to traits that have supertraits with implicit parameters
* add those supertraits in linearization order unless they are already covered by other
* parent types. For instance, in
Expand Down Expand Up @@ -1632,6 +1650,7 @@ class Namer { typer: Typer =>
cls.invalidateMemberCaches() // we might have checked for a member when parents were not known yet.
cls.setNoInitsFlags(parentsKind(parents), untpd.bodyKind(rest))
cls.setStableConstructor()
enterParentRefinementSyms(parentRefinements.toList)
processExports(using localCtx)
defn.patchStdLibClass(cls)
addConstructorProxies(cls)
Expand Down
20 changes: 18 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if (templ1.parents.isEmpty &&
isFullyDefined(pt, ForceDegree.flipBottom) &&
isSkolemFree(pt) &&
isEligible(pt.underlyingClassRef(refinementOK = false)))
isEligible(pt.underlyingClassRef(refinementOK = true)))
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
for case parent: RefTree <- templ1.parents do
typedAhead(parent, tree => inferTypeParams(typedType(tree), pt))
Expand Down Expand Up @@ -2766,6 +2766,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
}

/** Add all parent refinement symbols as declarations to this class */
def addParentRefinements(body: List[Tree])(using Context): List[Tree] =
cdef.getAttachment(ParentRefinements) match
case Some(refinedSyms) =>
val refinements = refinedSyms.map: sym =>
( if sym.isType then TypeDef(sym.asType)
else if sym.is(Method) then DefDef(sym.asTerm)
else ValDef(sym.asTerm)
).withSpan(impl.span.startPos)
body ++ refinements
case None =>
body

ensureCorrectSuperClass()
completeAnnotations(cdef, cls)
val constr1 = typed(constr).asInstanceOf[DefDef]
Expand All @@ -2786,7 +2799,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
cdef.withType(UnspecifiedErrorType)
else {
val dummy = localDummy(cls, impl)
val body1 = addAccessorDefs(cls, typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
val body1 =
addParentRefinements(
addAccessorDefs(cls,
typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1))

checkNoDoubleDeclaration(cls)
val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1)
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/i0248-inherit-refined.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
object test {
class A { type T }
type X = A { type T = Int }
class B extends X // error
class B extends X // was error, now OK
type Y = A & B
class C extends Y // error
type Z = A | B
class D extends Z // error
abstract class E extends ({ val x: Int }) // error
abstract class E extends ({ val x: Int }) // was error, now OK
}
7 changes: 7 additions & 0 deletions tests/neg/parent-refinement-access.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- [E164] Declaration Error: tests/neg/parent-refinement-access.scala:4:6 ----------------------------------------------
4 |trait Year2(private[Year2] val value: Int) extends (Gen { val x: Int }) // error
| ^
| error overriding value x in trait Year2 of type Int;
| value x in trait Gen of type Any has weaker access privileges; it should be public
| (Note that value x in trait Year2 of type Int is abstract,
| and is therefore overridden by concrete value x in trait Gen of type Any)
4 changes: 4 additions & 0 deletions tests/neg/parent-refinement-access.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
trait Gen:
private[Gen] val x: Any = ()

trait Year2(private[Year2] val value: Int) extends (Gen { val x: Int }) // error
29 changes: 25 additions & 4 deletions tests/neg/parent-refinement.check
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
-- Error: tests/neg/parent-refinement.scala:5:2 ------------------------------------------------------------------------
5 | with Ordered[Year] { // error
| ^^^^
| end of toplevel definition expected but 'with' found
-- Error: tests/neg/parent-refinement.scala:10:6 -----------------------------------------------------------------------
10 |class Bar extends IdOf[Int], (X { type Value = String }) // error
| ^^^
|class Bar cannot be instantiated since it has a member Value with possibly conflicting bounds Int | String <: ... <: Int & String
-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:14:17 -------------------------------------------------
14 | val x: Value = 0 // error
| ^
| Found: (0 : Int)
| Required: Baz.this.Value
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:20:6 --------------------------------------------------
20 | foo(2) // error
| ^
| Found: (2 : Int)
| Required: Boolean
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:16:22 -------------------------------------------------
16 |val x: IdOf[Int] = Baz() // error
| ^^^^^
| Found: Baz
| Required: IdOf[Int]
|
| longer explanation available when compiling with `-explain`
19 changes: 16 additions & 3 deletions tests/neg/parent-refinement.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@

trait Id { type Value }
trait X { type Value }
type IdOf[T] = Id { type Value = T }

case class Year(value: Int) extends AnyVal
with Id { type Value = Int }
with Ordered[Year] { // error
with (Id { type Value = Int })
with Ordered[Year]

class Bar extends IdOf[Int], (X { type Value = String }) // error

class Baz extends IdOf[Int]:
type Value = String
val x: Value = 0 // error

val x: IdOf[Int] = Baz() // error

}
object Clash extends ({ def foo(x: Int): Int }):
def foo(x: Boolean): Int = 1
foo(2) // error
46 changes: 46 additions & 0 deletions tests/pos/parent-refinement.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
class A
class B extends A
class C extends B

trait Id { type Value }
type IdOf[T] = Id { type Value = T }
trait X { type Value }

case class Year(value: Int) extends IdOf[Int]:
val x: Value = 2

type Between[Lo, Hi] = X { type Value >: Lo <: Hi }

class Foo() extends IdOf[B], Between[C, A]:
val x: Value = B()

trait Bar extends IdOf[Int], (X { type Value = String })

class Baz extends IdOf[Int]:
type Value = String
val x: Value = ""

trait Gen:
type T
val x: T

type IntInst = Gen:
type T = Int
val x: 0

trait IntInstTrait extends IntInst

abstract class IntInstClass extends IntInstTrait, IntInst

object obj1 extends IntInstTrait:
val x = 0

object obj2 extends IntInstClass:
val x = 0

def main =
val x: obj1.T = 2 - obj2.x
val y: obj2.T = 2 - obj1.x



Loading