Skip to content

Commit f8bd01c

Browse files
authored
Merge pull request #11830 from dotty-staging/fix-7613
Insert traits with implicit parameters as extra parents of classes
2 parents 46eb23e + 7b7774e commit f8bd01c

File tree

11 files changed

+211
-53
lines changed

11 files changed

+211
-53
lines changed

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

+5
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,10 @@ object TypeUtils {
8080
case self: TypeProxy =>
8181
self.underlying.companionRef
8282
}
83+
84+
/** Is this type a methodic type that takes implicit parameters (both old and new) at some point? */
85+
def takesImplicitParams(using Context): Boolean = self.stripPoly match
86+
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
87+
case _ => false
8388
}
8489
}

compiler/src/dotty/tools/dotc/typer/Namer.scala

+70-1
Original file line numberDiff line numberDiff line change
@@ -1237,12 +1237,81 @@ class Namer { typer: Typer =>
12371237
}
12381238
}
12391239

1240+
/** Ensure that the first type in a list of parent types Ps points to a non-trait class.
1241+
* If that's not already the case, add one. The added class type CT is determined as follows.
1242+
* First, let C be the unique class such that
1243+
* - there is a parent P_i such that P_i derives from C, and
1244+
* - for every class D: If some parent P_j, j <= i derives from D, then C derives from D.
1245+
* Then, let CT be the smallest type which
1246+
* - has C as its class symbol, and
1247+
* - for all parents P_i: If P_i derives from C then P_i <:< CT.
1248+
*/
1249+
def ensureFirstIsClass(parents: List[Type]): List[Type] =
1250+
1251+
def realClassParent(sym: Symbol): ClassSymbol =
1252+
if !sym.isClass then defn.ObjectClass
1253+
else if !sym.is(Trait) then sym.asClass
1254+
else sym.info.parents match
1255+
case parentRef :: _ => realClassParent(parentRef.typeSymbol)
1256+
case nil => defn.ObjectClass
1257+
1258+
def improve(candidate: ClassSymbol, parent: Type): ClassSymbol =
1259+
val pcls = realClassParent(parent.classSymbol)
1260+
if (pcls derivesFrom candidate) pcls else candidate
1261+
1262+
parents match
1263+
case p :: _ if p.classSymbol.isRealClass => parents
1264+
case _ =>
1265+
val pcls = parents.foldLeft(defn.ObjectClass)(improve)
1266+
typr.println(i"ensure first is class $parents%, % --> ${parents map (_ baseType pcls)}%, %")
1267+
val first = TypeComparer.glb(defn.ObjectType :: parents.map(_.baseType(pcls)))
1268+
checkFeasibleParent(first, cls.srcPos, em" in inferred superclass $first") :: parents
1269+
end ensureFirstIsClass
1270+
1271+
/** If `parents` contains references to traits that have supertraits with implicit parameters
1272+
* add those supertraits in linearization order unless they are already covered by other
1273+
* parent types. For instance, in
1274+
*
1275+
* class A
1276+
* trait B(using I) extends A
1277+
* trait C extends B
1278+
* class D extends A, C
1279+
*
1280+
* the class declaration of `D` is augmented to
1281+
*
1282+
* class D extends A, B, C
1283+
*
1284+
* so that an implicit `I` can be passed to `B`. See i7613.scala for more examples.
1285+
*/
1286+
def addUsingTraits(parents: List[Type]): List[Type] =
1287+
lazy val existing = parents.map(_.classSymbol).toSet
1288+
def recur(parents: List[Type]): List[Type] = parents match
1289+
case parent :: parents1 =>
1290+
val psym = parent.classSymbol
1291+
val addedTraits =
1292+
if psym.is(Trait) then
1293+
psym.asClass.baseClasses.tail.iterator
1294+
.takeWhile(_.is(Trait))
1295+
.filter(p =>
1296+
p.primaryConstructor.info.takesImplicitParams
1297+
&& !cls.superClass.isSubClass(p)
1298+
&& !existing.contains(p))
1299+
.toList.reverse
1300+
else Nil
1301+
addedTraits.map(parent.baseType) ::: parent :: recur(parents1)
1302+
case nil =>
1303+
Nil
1304+
if cls.isRealClass then recur(parents) else parents
1305+
end addUsingTraits
1306+
12401307
completeConstructor(denot)
12411308
denot.info = tempInfo
12421309

12431310
val parentTypes = defn.adjustForTuple(cls, cls.typeParams,
12441311
defn.adjustForBoxedUnit(cls,
1245-
ensureFirstIsClass(parents.map(checkedParentType(_)), cls.span)
1312+
addUsingTraits(
1313+
ensureFirstIsClass(parents.map(checkedParentType(_)))
1314+
)
12461315
)
12471316
)
12481317
typr.println(i"completing $denot, parents = $parents%, %, parentTypes = $parentTypes%, %")

compiler/src/dotty/tools/dotc/typer/ReTyper.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ class ReTyper extends Typer with ReChecking {
103103

104104
override def completeAnnotations(mdef: untpd.MemberDef, sym: Symbol)(using Context): Unit = ()
105105

106-
override def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] =
107-
parents
106+
override def ensureConstrCall(cls: ClassSymbol, parent: Tree)(using Context): Tree =
107+
parent
108108

109109
override def handleUnexpectedFunType(tree: untpd.Apply, fun: Tree)(using Context): Tree = fun.tpe match {
110110
case mt: MethodType =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

+41-49
Original file line numberDiff line numberDiff line change
@@ -2219,20 +2219,18 @@ class Typer extends Namer
22192219
* @param psym Its type symbol
22202220
* @param cinfo The info of its constructor
22212221
*/
2222-
def maybeCall(ref: Tree, psym: Symbol, cinfo: Type): Tree = cinfo.stripPoly match {
2222+
def maybeCall(ref: Tree, psym: Symbol): Tree = psym.primaryConstructor.info.stripPoly match
22232223
case cinfo @ MethodType(Nil) if cinfo.resultType.isImplicitMethod =>
22242224
typedExpr(untpd.New(untpd.TypedSplice(ref)(using superCtx), Nil))(using superCtx)
22252225
case cinfo @ MethodType(Nil) if !cinfo.resultType.isInstanceOf[MethodType] =>
22262226
ref
22272227
case cinfo: MethodType =>
2228-
if (!ctx.erasedTypes) { // after constructors arguments are passed in super call.
2228+
if !ctx.erasedTypes then // after constructors arguments are passed in super call.
22292229
typr.println(i"constr type: $cinfo")
22302230
report.error(ParameterizedTypeLacksArguments(psym), ref.srcPos)
2231-
}
22322231
ref
22332232
case _ =>
22342233
ref
2235-
}
22362234

22372235
val seenParents = mutable.Set[Symbol]()
22382236

@@ -2257,14 +2255,35 @@ class Typer extends Namer
22572255
if (tree.isType) {
22582256
checkSimpleKinded(result) // Not needed for constructor calls, as type arguments will be inferred.
22592257
if (psym.is(Trait) && !cls.is(Trait) && !cls.superClass.isSubClass(psym))
2260-
result = maybeCall(result, psym, psym.primaryConstructor.info)
2258+
result = maybeCall(result, psym)
22612259
}
22622260
else checkParentCall(result, cls)
22632261
checkTraitInheritance(psym, cls, tree.srcPos)
22642262
if (cls is Case) checkCaseInheritance(psym, cls, tree.srcPos)
22652263
result
22662264
}
22672265

2266+
/** Augment `ptrees` to have the same class symbols as `parents`. Generate TypeTrees
2267+
* or New trees to fill in any parents for which no tree exists yet.
2268+
*/
2269+
def parentTrees(parents: List[Type], ptrees: List[Tree]): List[Tree] = parents match
2270+
case parent :: parents1 =>
2271+
val psym = parent.classSymbol
2272+
def hasSameParent(ptree: Tree) = ptree.tpe.classSymbol == psym
2273+
ptrees match
2274+
case ptree :: ptrees1 if hasSameParent(ptree) =>
2275+
ptree :: parentTrees(parents1, ptrees1)
2276+
case ptree :: ptrees1 if ptrees1.exists(hasSameParent) =>
2277+
ptree :: parentTrees(parents, ptrees1)
2278+
case _ =>
2279+
var added: Tree = TypeTree(parent).withSpan(cdef.nameSpan.focus)
2280+
if psym.is(Trait) && psym.primaryConstructor.info.takesImplicitParams then
2281+
// classes get a constructor separately using a different context
2282+
added = ensureConstrCall(cls, added)
2283+
added :: parentTrees(parents1, ptrees)
2284+
case _ =>
2285+
ptrees
2286+
22682287
/** Checks if one of the decls is a type with the same name as class type member in selfType */
22692288
def classExistsOnSelf(decls: Scope, self: tpd.ValDef): Boolean = {
22702289
val selfType = self.tpt.tpe
@@ -2285,8 +2304,10 @@ class Typer extends Namer
22852304

22862305
completeAnnotations(cdef, cls)
22872306
val constr1 = typed(constr).asInstanceOf[DefDef]
2288-
val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan)
2289-
val parents1 = ensureConstrCall(cls, parentsWithClass)(using superCtx)
2307+
val parents0 = parentTrees(
2308+
cls.classInfo.declaredParents,
2309+
parents.mapconserve(typedParent).filterConserve(!_.isEmpty))
2310+
val parents1 = ensureConstrCall(cls, parents0)(using superCtx)
22902311
val firstParentTpe = parents1.head.tpe.dealias
22912312
val firstParent = firstParentTpe.typeSymbol
22922313

@@ -2355,52 +2376,23 @@ class Typer extends Namer
23552376
protected def addAccessorDefs(cls: Symbol, body: List[Tree])(using Context): List[Tree] =
23562377
ctx.compilationUnit.inlineAccessors.addAccessorDefs(cls, body)
23572378

2358-
/** Ensure that the first type in a list of parent types Ps points to a non-trait class.
2359-
* If that's not already the case, add one. The added class type CT is determined as follows.
2360-
* First, let C be the unique class such that
2361-
* - there is a parent P_i such that P_i derives from C, and
2362-
* - for every class D: If some parent P_j, j <= i derives from D, then C derives from D.
2363-
* Then, let CT be the smallest type which
2364-
* - has C as its class symbol, and
2365-
* - for all parents P_i: If P_i derives from C then P_i <:< CT.
2379+
/** If this is a real class, make sure its first parent is a
2380+
* constructor call. Cannot simply use a type. Overridden in ReTyper.
23662381
*/
2367-
def ensureFirstIsClass(parents: List[Type], span: Span)(using Context): List[Type] = {
2368-
def realClassParent(cls: Symbol): ClassSymbol =
2369-
if (!cls.isClass) defn.ObjectClass
2370-
else if (!cls.is(Trait)) cls.asClass
2371-
else cls.info.parents match {
2372-
case parentRef :: _ => realClassParent(parentRef.typeSymbol)
2373-
case nil => defn.ObjectClass
2374-
}
2375-
def improve(candidate: ClassSymbol, parent: Type): ClassSymbol = {
2376-
val pcls = realClassParent(parent.classSymbol)
2377-
if (pcls derivesFrom candidate) pcls else candidate
2378-
}
2379-
parents match {
2380-
case p :: _ if p.classSymbol.isRealClass => parents
2381-
case _ =>
2382-
val pcls = parents.foldLeft(defn.ObjectClass)(improve)
2383-
typr.println(i"ensure first is class $parents%, % --> ${parents map (_ baseType pcls)}%, %")
2384-
val first = TypeComparer.glb(defn.ObjectType :: parents.map(_.baseType(pcls)))
2385-
checkFeasibleParent(first, ctx.source.atSpan(span), em" in inferred superclass $first") :: parents
2386-
}
2387-
}
2382+
def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = parents match
2383+
case parents @ (first :: others) =>
2384+
parents.derivedCons(ensureConstrCall(cls, first), others)
2385+
case parents =>
2386+
parents
23882387

2389-
/** Ensure that first parent tree refers to a real class. */
2390-
def ensureFirstTreeIsClass(parents: List[Tree], span: Span)(using Context): List[Tree] = parents match {
2391-
case p :: ps if p.tpe.classSymbol.isRealClass => parents
2392-
case _ => TypeTree(ensureFirstIsClass(parents.tpes, span).head).withSpan(span.focus) :: parents
2393-
}
2394-
2395-
/** If this is a real class, make sure its first parent is a
2388+
/** If this is a real class, make sure its first parent is a
23962389
* constructor call. Cannot simply use a type. Overridden in ReTyper.
23972390
*/
2398-
def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = {
2399-
val firstParent :: otherParents = parents
2400-
if (firstParent.isType && !cls.is(Trait) && !cls.is(JavaDefined))
2401-
typed(untpd.New(untpd.TypedSplice(firstParent), Nil)) :: otherParents
2402-
else parents
2403-
}
2391+
def ensureConstrCall(cls: ClassSymbol, parent: Tree)(using Context): Tree =
2392+
if (parent.isType && !cls.is(Trait) && !cls.is(JavaDefined))
2393+
typed(untpd.New(untpd.TypedSplice(parent), Nil))
2394+
else
2395+
parent
24042396

24052397
def localDummy(cls: ClassSymbol, impl: untpd.Template)(using Context): Symbol =
24062398
newLocalDummy(cls, impl.span)

docs/docs/reference/other-new-features/trait-parameters.md

+30
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,36 @@ The correct way to write `E` is to extend both `Greeting` and
5252
class E extends Greeting("Bob"), FormalGreeting
5353
```
5454

55+
### Traits With Context Parameters
56+
57+
This "explicit extension required" rule is relaxed if the missing trait contains only
58+
[context parameters](../contextual/using-clauses). In that case the trait reference is
59+
implicitly inserted as an additional parent with inferred arguments. For instance,
60+
here's a variant of greetings where the addressee is a context parameter of type
61+
`ImpliedName`:
62+
63+
```scala
64+
case class ImpliedName(name: String):
65+
override def toString = name
66+
67+
trait ImpliedGreeting(using val iname: ImpliedName):
68+
def msg = s"How are you, $iname"
69+
70+
trait ImpliedFormalGreeting extends ImpliedGreeting:
71+
override def msg = s"How do you do, $iname"
72+
73+
class F(using iname: ImpliedName) extends ImpliedFormalGreeting
74+
```
75+
76+
The definition of `F` in the last line is implicitly expanded to
77+
```scala
78+
class F(using iname: ImpliedName) extends
79+
Object,
80+
ImpliedGreeting(using iname),
81+
ImpliedFormalGreeting(using iname)
82+
```
83+
Note the inserted reference to the super trait `ImpliedGreeting`, which was not mentioned explicitly.
84+
5585
## Reference
5686

5787
For more information, see [Scala SIP 25](http://docs.scala-lang.org/sips/pending/trait-parameters.html).

tests/neg/i6060.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class I1(i2: Int) {
22
def apply(i3: Int) = 1
3-
new I1(1)(2) {} // error: too many arguments in parent constructor
3+
new I1(1)(2) {} // error: too many arguments in parent constructor // error
44
}
55

66
class I0(i1: Int) {

tests/neg/i7613.check

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Error: tests/neg/i7613.scala:10:16 ----------------------------------------------------------------------------------
2+
10 | new BazLaws[A] {} // error // error
3+
| ^
4+
| no implicit argument of type Baz[A] was found for parameter x$1 of constructor BazLaws in trait BazLaws
5+
-- Error: tests/neg/i7613.scala:10:2 -----------------------------------------------------------------------------------
6+
10 | new BazLaws[A] {} // error // error
7+
| ^
8+
| no implicit argument of type Bar[A] was found for parameter x$1 of constructor BarLaws in trait BarLaws

tests/neg/i7613.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
trait Foo[A]
2+
trait Bar[A] extends Foo[A]
3+
trait Baz[A] extends Bar[A]
4+
5+
trait FooLaws[A](using Foo[A])
6+
trait BarLaws[A](using Bar[A]) extends FooLaws[A]
7+
trait BazLaws[A](using Baz[A]) extends BarLaws[A]
8+
9+
def instance[A](using Foo[A]): BazLaws[A] =
10+
new BazLaws[A] {} // error // error
11+

tests/pos/reference/trait-parameters.scala

+9
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,13 @@ class E extends Greeting("Bob") with FormalGreeting
1616

1717
// class D2 extends C with Greeting("Bill") // error
1818

19+
case class ImpliedName(name: String):
20+
override def toString = name
1921

22+
trait ImpliedGreeting(using val iname: ImpliedName):
23+
def msg = s"How are you, $iname"
24+
25+
trait ImpliedFormalGreeting extends ImpliedGreeting:
26+
override def msg = s"How do you do, $iname"
27+
28+
class F(using iname: ImpliedName) extends ImpliedFormalGreeting

tests/run/i7613.check

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
D: B1
2+
superD: B1
3+
E: B2
4+
F: B1
5+
F: B2

tests/run/i7613.scala

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
trait Foo[A]
2+
trait Bar[A] extends Foo[A]
3+
trait Baz[A] extends Bar[A]
4+
5+
trait FooLaws[A](using Foo[A])
6+
trait BarLaws[A](using Bar[A]) extends FooLaws[A]
7+
trait BazLaws[A](using Baz[A]) extends BarLaws[A]
8+
9+
def instance1[A](using Baz[A]): BazLaws[A] =
10+
new FooLaws[A] with BarLaws[A] with BazLaws[A] {}
11+
12+
def instance2[A](using Baz[A]): BazLaws[A] =
13+
new BazLaws[A] {}
14+
15+
trait I:
16+
def show(x: String): Unit
17+
class A
18+
trait B1(using I) extends A { summon[I].show("B1") }
19+
trait B2(using I) extends B1 { summon[I].show("B2") }
20+
trait C1 extends B1
21+
trait C2 extends B2
22+
class D(using I) extends A, C1
23+
class E(using I) extends D(using new I { def show(x: String) = println(s"superD: $x")}), C2
24+
class F(using I) extends A, C2
25+
26+
@main def Test =
27+
D(using new I { def show(x: String) = println(s"D: $x")})
28+
E(using new I { def show(x: String) = println(s"E: $x")})
29+
F(using new I { def show(x: String) = println(s"F: $x")})

0 commit comments

Comments
 (0)