Skip to content

Commit 0f4aba8

Browse files
committed
Introduce tracked class parameters
For a tracked class parameter we add a refinement in the constructor type that the class member is the same as the parameter. E.g. ```scala class C { type T } class D(tracked val x: C) { type T = x.T } ``` This will generate the constructor type: ```scala (x1: C): D { val x: x1.type } ``` Without `tracked` the refinement would not be added. This can solve several problems with dependent class types where previously we lost track of type dependencies.
1 parent bfebebd commit 0f4aba8

15 files changed

+63
-38
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+18-10
Original file line numberDiff line numberDiff line change
@@ -437,13 +437,13 @@ object desugar {
437437
private def toDefParam(tparam: TypeDef, keepAnnotations: Boolean): TypeDef = {
438438
var mods = tparam.rawMods
439439
if (!keepAnnotations) mods = mods.withAnnotations(Nil)
440-
tparam.withMods(mods & (EmptyFlags | Sealed) | Param)
440+
tparam.withMods(mods & EmptyFlags | Param)
441441
}
442442
private def toDefParam(vparam: ValDef, keepAnnotations: Boolean, keepDefault: Boolean): ValDef = {
443443
var mods = vparam.rawMods
444444
if (!keepAnnotations) mods = mods.withAnnotations(Nil)
445445
val hasDefault = if keepDefault then HasDefault else EmptyFlags
446-
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault) | Param)
446+
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault | Tracked) | Param)
447447
}
448448

449449
def mkApply(fn: Tree, paramss: List[ParamClause])(using Context): Tree =
@@ -529,7 +529,7 @@ object desugar {
529529
// but not on the constructor parameters. The reverse is true for
530530
// annotations on class _value_ parameters.
531531
val constrTparams = impliedTparams.map(toDefParam(_, keepAnnotations = false))
532-
val constrVparamss =
532+
def defVparamss =
533533
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
534534
if (isCaseClass)
535535
report.error(CaseClassMissingParamList(cdef), namePos)
@@ -540,6 +540,7 @@ object desugar {
540540
ListOfNil
541541
}
542542
else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = true, keepDefault = true))
543+
val constrVparamss = defVparamss
543544
val derivedTparams =
544545
constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) =>
545546
derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations))
@@ -614,6 +615,11 @@ object desugar {
614615
case _ => false
615616
}
616617

618+
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
619+
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
620+
case _ => false
621+
}
622+
617623
def appliedRef(tycon: Tree, tparams: List[TypeDef] = constrTparams, widenHK: Boolean = false) = {
618624
val targs = for (tparam <- tparams) yield {
619625
val targ = refOfDef(tparam)
@@ -630,10 +636,13 @@ object desugar {
630636
appliedTypeTree(tycon, targs)
631637
}
632638

633-
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
634-
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
635-
case _ => false
636-
}
639+
def addParamRefinements(core: Tree, paramss: List[List[ValDef]]): Tree =
640+
val refinements =
641+
for params <- paramss; param <- params; if param.mods.is(Tracked) yield
642+
ValDef(param.name, SingletonTypeTree(TermRefTree().watching(param)), EmptyTree)
643+
.withSpan(param.span)
644+
if refinements.isEmpty then core
645+
else RefinedTypeTree(core, refinements).showing(i"refined result: $result", Printers.desugar)
637646

638647
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
639648
val classTypeRef = appliedRef(classTycon)
@@ -854,18 +863,17 @@ object desugar {
854863
Nil
855864
}
856865
else {
857-
val defParamss = constrVparamss match {
866+
val defParamss = defVparamss match
858867
case Nil :: paramss =>
859868
paramss // drop leading () that got inserted by class
860869
// TODO: drop this once we do not silently insert empty class parameters anymore
861870
case paramss => paramss
862-
}
863871
val finalFlag = if ctx.settings.YcompileScala2Library.value then EmptyFlags else Final
864872
// implicit wrapper is typechecked in same scope as constructor, so
865873
// we can reuse the constructor parameters; no derived params are needed.
866874
DefDef(
867875
className.toTermName, joinParams(constrTparams, defParamss),
868-
classTypeRef, creatorExpr)
876+
addParamRefinements(classTypeRef, defParamss), creatorExpr)
869877
.withMods(companionMods | mods.flags.toTermFlags & (GivenOrImplicit | Inline) | finalFlag)
870878
.withSpan(cdef.span) :: Nil
871879
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

+2
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
231231

232232
case class Infix()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Infix)
233233

234+
case class Tracked()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Tracked)
235+
234236
/** Used under pureFunctions to mark impure function types `A => B` in `FunctionWithMods` */
235237
case class Impure()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Impure)
236238
}

compiler/src/dotty/tools/dotc/core/Flags.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ object Flags {
242242
val (AccessorOrSealed @ _, Accessor @ _, Sealed @ _) = newFlags(11, "<accessor>", "sealed")
243243

244244
/** A mutable var, an open class */
245-
val (MutableOrOpen @ __, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open")
245+
val (MutableOrOpen @ _, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open")
246246

247247
/** Symbol is local to current class (i.e. private[this] or protected[this]
248248
* pre: Private or Protected are also set
@@ -377,6 +377,8 @@ object Flags {
377377
/** Symbol cannot be found as a member during typer */
378378
val (Invisible @ _, _, _) = newFlags(45, "<invisible>")
379379

380+
val (Tracked @ _, _, _) = newFlags(46, "tracked")
381+
380382
// ------------ Flags following this one are not pickled ----------------------------------
381383

382384
/** Symbol is not a member of its owner */
@@ -452,7 +454,7 @@ object Flags {
452454
CommonSourceModifierFlags.toTypeFlags | Abstract | Sealed | Opaque | Open
453455

454456
val TermSourceModifierFlags: FlagSet =
455-
CommonSourceModifierFlags.toTermFlags | Inline | AbsOverride | Lazy
457+
CommonSourceModifierFlags.toTermFlags | Inline | AbsOverride | Lazy | Tracked
456458

457459
/** Flags representing modifiers that can appear in trees */
458460
val ModifierFlags: FlagSet =

compiler/src/dotty/tools/dotc/core/NamerOps.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ object NamerOps:
1515
*/
1616
def effectiveResultType(ctor: Symbol, paramss: List[List[Symbol]])(using Context): Type =
1717
paramss match
18-
case TypeSymbols(tparams) :: _ => ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef))
18+
case TypeSymbols(tparams) :: _ =>
19+
var resType = ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef))
20+
for params <- paramss; param <- params do
21+
if param.is(Tracked) then
22+
resType = RefinedType(resType, param.name, param.termRef)
23+
resType
1924
case _ => ctor.owner.typeRef
2025

2126
/** If isConstructor, make sure it has at least one non-implicit parameter list

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ object StdNames {
623623
val toString_ : N = "toString"
624624
val toTypeConstructor: N = "toTypeConstructor"
625625
val tpe : N = "tpe"
626+
val tracked: N = "tracked"
626627
val transparent : N = "transparent"
627628
val tree : N = "tree"
628629
val true_ : N = "true"

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ class TreePickler(pickler: TastyPickler) {
787787
if (flags.is(Extension)) writeModTag(EXTENSION)
788788
if (flags.is(ParamAccessor)) writeModTag(PARAMsetter)
789789
if (flags.is(SuperParamAlias)) writeModTag(PARAMalias)
790+
if (flags.is(Tracked)) writeModTag(TRACKED)
790791
assert(!(flags.is(Label)))
791792
}
792793
else {

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ class TreeUnpickler(reader: TastyReader,
733733
case INVISIBLE => addFlag(Invisible)
734734
case TRANSPARENT => addFlag(Transparent)
735735
case INFIX => addFlag(Infix)
736+
case TRACKED => addFlag(Tracked)
736737
case PRIVATEqualified =>
737738
readByte()
738739
privateWithin = readWithin

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+7-2
Original file line numberDiff line numberDiff line change
@@ -3100,6 +3100,7 @@ object Parsers {
31003100
case nme.open => Mod.Open()
31013101
case nme.transparent => Mod.Transparent()
31023102
case nme.infix => Mod.Infix()
3103+
case nme.tracked => Mod.Tracked()
31033104
}
31043105
}
31053106

@@ -3166,6 +3167,7 @@ object Parsers {
31663167
* | AccessModifier
31673168
* | override
31683169
* | opaque
3170+
* | tracked
31693171
* LocalModifier ::= abstract | final | sealed | open | implicit | lazy | erased | inline | transparent
31703172
*/
31713173
def modifiers(allowed: BitSet = modifierTokens, start: Modifiers = Modifiers()): Modifiers = {
@@ -3427,7 +3429,8 @@ object Parsers {
34273429
val isParams =
34283430
!impliedMods.is(Given)
34293431
|| startParamTokens.contains(in.token)
3430-
|| isIdent && (in.name == nme.inline || in.lookahead.isColon)
3432+
|| isIdent
3433+
&& (in.name == nme.inline || in.name == nme.tracked || in.lookahead.isColon)
34313434
(mods, isParams)
34323435
(if isParams then commaSeparated(() => param())
34333436
else contextTypes(paramOwner, numLeadParams, impliedMods)) match {
@@ -4005,7 +4008,9 @@ object Parsers {
40054008
paramss.nestedMap: param =>
40064009
if !param.mods.isAllOf(PrivateLocal) then
40074010
syntaxError(em"method parameter ${param.name} may not be `a val`", param.span)
4008-
param.withMods(param.mods &~ (AccessFlags | ParamAccessor | Mutable) | Param)
4011+
if param.mods.is(Tracked) then
4012+
syntaxError(em"method parameter ${param.name} may not be `tracked`", param.span)
4013+
param.withMods(param.mods &~ (AccessFlags | ParamAccessor | Tracked | Mutable) | Param)
40094014
.asInstanceOf[List[ParamClause]]
40104015

40114016
val gdef =

compiler/src/dotty/tools/dotc/parsing/Tokens.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ object Tokens extends TokensCommon {
294294

295295
final val closingParens = BitSet(RPAREN, RBRACKET, RBRACE)
296296

297-
final val softModifierNames = Set(nme.inline, nme.into, nme.opaque, nme.open, nme.transparent, nme.infix)
297+
final val softModifierNames = Set(nme.inline, nme.into, nme.opaque, nme.open, nme.transparent, nme.infix, nme.tracked)
298298

299299
def showTokenDetailed(token: Int): String = debugString(token)
300300

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,11 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
332332
case Select(nu: New, nme.CONSTRUCTOR) if isCheckable(nu) =>
333333
// need to check instantiability here, because the type of the New itself
334334
// might be a type constructor.
335-
ctx.typer.checkClassType(tree.tpe, tree.srcPos, traitReq = false, stablePrefixReq = true)
335+
ctx.typer.checkClassType(tree.tpe, tree.srcPos, traitReq = false, stablePrefixReq = true, refinementOK = true)
336336
if !nu.tpe.isLambdaSub then
337337
// Check the constructor type as well; it could be an illegal singleton type
338338
// which would not be reflected as `tree.tpe`
339-
ctx.typer.checkClassType(nu.tpe, tree.srcPos, traitReq = false, stablePrefixReq = false)
339+
ctx.typer.checkClassType(nu.tpe, tree.srcPos, traitReq = false, stablePrefixReq = false, refinementOK = true)
340340
Checking.checkInstantiable(tree.tpe, nu.tpe, nu.srcPos)
341341
withNoCheckNews(nu :: Nil)(app1)
342342
case _ =>

compiler/src/dotty/tools/dotc/typer/Checking.scala

+3-4
Original file line numberDiff line numberDiff line change
@@ -1021,16 +1021,15 @@ trait Checking {
10211021
* check that class prefix is stable.
10221022
* @return `tp` itself if it is a class or trait ref, ObjectType if not.
10231023
*/
1024-
def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean)(using Context): Type =
1025-
tp.underlyingClassRef(refinementOK = false) match {
1024+
def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean, refinementOK: Boolean = false)(using Context): Type =
1025+
tp.underlyingClassRef(refinementOK) match
10261026
case tref: TypeRef =>
10271027
if (traitReq && !tref.symbol.is(Trait)) report.error(TraitIsExpected(tref.symbol), pos)
10281028
if (stablePrefixReq && ctx.phase <= refchecksPhase) checkStable(tref.prefix, pos, "class prefix")
10291029
tp
10301030
case _ =>
10311031
report.error(NotClassType(tp), pos)
10321032
defn.ObjectType
1033-
}
10341033

10351034
/** If `sym` is an old-style implicit conversion, check that implicit conversions are enabled.
10361035
* @pre sym.is(GivenOrImplicit)
@@ -1601,7 +1600,7 @@ trait NoChecking extends ReChecking {
16011600
override def checkNonCyclic(sym: Symbol, info: TypeBounds, reportErrors: Boolean)(using Context): Type = info
16021601
override def checkNonCyclicInherited(joint: Type, parents: List[Type], decls: Scope, pos: SrcPos)(using Context): Unit = ()
16031602
override def checkStable(tp: Type, pos: SrcPos, kind: String)(using Context): Unit = ()
1604-
override def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean)(using Context): Type = tp
1603+
override def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean, refinementOK: Boolean)(using Context): Type = tp
16051604
override def checkImplicitConversionDefOK(sym: Symbol)(using Context): Unit = ()
16061605
override def checkImplicitConversionUseOK(tree: Tree, expected: Type)(using Context): Unit = ()
16071606
override def checkFeasibleParent(tp: Type, pos: SrcPos, where: => String = "")(using Context): Type = tp

compiler/test/dotc/pos-test-pickling.blacklist

+3
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,6 @@ java-inherited-type1
111111

112112
# recursion limit exceeded
113113
i7445b.scala
114+
115+
# alias types at different levels of dereferencing
116+
parsercombinators-givens.scala

tasty/src/dotty/tools/tasty/TastyFormat.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ object TastyFormat {
485485
final val INVISIBLE = 44
486486
final val EMPTYCLAUSE = 45
487487
final val SPLITCLAUSE = 46
488+
final val TRACKED = 47
488489

489490
// Cat. 2: tag Nat
490491

@@ -662,7 +663,8 @@ object TastyFormat {
662663
| INVISIBLE
663664
| ANNOTATION
664665
| PRIVATEqualified
665-
| PROTECTEDqualified => true
666+
| PROTECTEDqualified
667+
| TRACKED => true
666668
case _ => false
667669
}
668670

tests/pos/parsercombinators-expanded.scala

+6-11
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ object test:
2727
def apply[C, E]: apply[C, E] = new apply[C, E]
2828

2929
class combine[A, B](
30-
val f: Combinator[A],
31-
val s: Combinator[B] { type Context = f.Context}
30+
tracked val f: Combinator[A],
31+
tracked val s: Combinator[B] { type Context = f.Context}
3232
) extends Combinator[Combine[A, B]]:
3333
type Context = f.Context
3434
type Element = (f.Element, s.Element)
@@ -38,10 +38,7 @@ object test:
3838
def combine[A, B](
3939
_f: Combinator[A],
4040
_s: Combinator[B] { type Context = _f.Context}
41-
): combine[A, B] {
42-
type Context = _f.Context
43-
type Element = (_f.Element, _s.Element)
44-
} = new combine[A, B](_f, _s).asInstanceOf
41+
) = new combine[A, B](_f, _s)
4542
// cast is needed since the type of new combine[A, B](_f, _s)
4643
// drops the required refinement.
4744

@@ -56,12 +53,10 @@ object test:
5653
val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
5754
val m = Combine(n, n)
5855

59-
val c = combine[
60-
Apply[mutable.ListBuffer[Int], Int],
61-
Apply[mutable.ListBuffer[Int], Int]
62-
](
56+
val c = combine(
6357
apply[mutable.ListBuffer[Int], Int],
6458
apply[mutable.ListBuffer[Int], Int]
6559
)
66-
val r = c.parse(m)(stream) // type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context`
60+
val r = c.parse(m)(stream) // was type mismatch, now OK
61+
val rc: Option[(Int, Int)] = r
6762
}

tests/pos/parsercombinators-givens.scala

+5-4
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ given apply[C, E]: Combinator[Apply[C, E]] with {
2424
}
2525
}
2626

27-
given combine[A, B, C](using
28-
f: Combinator[A] { type Context = C },
29-
s: Combinator[B] { type Context = C }
27+
given combine[A, B](using
28+
tracked val f: Combinator[A],
29+
tracked val s: Combinator[B] { type Context = f.Context }
3030
): Combinator[Combine[A, B]] with {
3131
type Context = f.Context
3232
type Element = (f.Element, s.Element)
@@ -46,6 +46,7 @@ extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
4646
val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
4747
val m = Combine(n, n)
4848

49-
// val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context`
49+
val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context`
50+
val rc: Option[(Int, Int)] = r
5051
// it would be great if this worked
5152
}

0 commit comments

Comments
 (0)