Skip to content

Commit 0a603dd

Browse files
committed
Allow cross-stage persistence for types
... provided they are backed by type tags, i.e. implicit values of type `quoted.Type`.
1 parent dc6b2c9 commit 0a603dd

File tree

3 files changed

+220
-84
lines changed

3 files changed

+220
-84
lines changed

compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala

Lines changed: 201 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@ import core._
55
import Decorators._, Flags._, Types._, Contexts._, Symbols._, Constants._
66
import Flags._
77
import ast.Trees._
8+
import util.Positions._
9+
import StdNames._
10+
import ast.untpd
811
import MegaPhase.MiniPhase
12+
import typer.Implicits._
13+
import NameKinds.OuterSelectName
914
import scala.collection.mutable
1015

1116
/** Translates quoted terms and types to `unpickle` method calls.
1217
* Checks that the phase consistency principle (PCP) holds.
1318
*/
14-
class ReifyQuotes extends MacroTransform {
19+
class ReifyQuotes extends MacroTransformWithImplicits {
1520
import ast.tpd._
1621

1722
override def phaseName: String = "reifyQuotes"
@@ -21,9 +26,9 @@ class ReifyQuotes extends MacroTransform {
2126

2227
protected def newTransformer(implicit ctx: Context): Transformer = new Reifier
2328

24-
/** is tree splice operation? */
25-
def isSplice(tree: Select)(implicit ctx: Context) =
26-
tree.symbol == defn.QuotedExpr_~ || tree.symbol == defn.QuotedType_~
29+
/** Is symbol a splice operation? */
30+
def isSplice(sym: Symbol)(implicit ctx: Context) =
31+
sym == defn.QuotedExpr_~ || sym == defn.QuotedType_~
2732

2833
/** Serialize `tree`. Embedded splices are represented as nodes of the form
2934
*
@@ -36,13 +41,57 @@ class ReifyQuotes extends MacroTransform {
3641
def pickleTree(tree: Tree, isType: Boolean)(implicit ctx: Context): String =
3742
tree.show // TODO: replace with TASTY
3843

39-
private class Reifier extends Transformer {
44+
private class Reifier extends ImplicitsTransformer {
45+
46+
/** A class for collecting the splices of some quoted expression */
47+
private class Splices {
48+
49+
/** A listbuffer collecting splices */
50+
val buf = new mutable.ListBuffer[Tree]
51+
52+
/** A map from type ref T to "expression of type quoted.Type[T]".
53+
* These will be turned into splices using `addTags`
54+
*/
55+
val typeTagOfRef = new mutable.LinkedHashMap[TypeRef, Tree]()
56+
57+
/** Assuming typeTagOfRef = `Type1 -> tag1, ..., TypeN -> tagN`, the expression
58+
*
59+
* { type <Type1> = <tag1>.unary_~
60+
* ...
61+
* type <TypeN> = <tagN>.unary.~
62+
* <expr>
63+
* }
64+
*
65+
* where all references to `TypeI` in `expr` are rewired to point to the locally
66+
* defined versions. As a side effect, append the expressions `tag1, ..., `tagN`
67+
* as splices to `buf`.
68+
*/
69+
def addTags(expr: Tree)(implicit ctx: Context): Tree =
70+
if (typeTagOfRef.isEmpty) expr
71+
else {
72+
val assocs = typeTagOfRef.toList
73+
val typeDefs = for ((tp, tag) <- assocs) yield {
74+
val original = tp.symbol.asType
75+
val rhs = tag.select(tpnme.UNARY_~)
76+
val alias = ctx.typeAssigner.assignType(untpd.TypeBoundsTree(rhs, rhs), rhs, rhs)
77+
val local = original.copy(
78+
owner = ctx.owner,
79+
flags = Synthetic,
80+
info = TypeAlias(tag.tpe.select(tpnme.UNARY_~)))
81+
ctx.typeAssigner.assignType(untpd.TypeDef(original.name, alias), local)
82+
}
83+
val (trefs, tags) = assocs.unzip
84+
tags ++=: buf
85+
typeTagOfRef.clear()
86+
Block(typeDefs, expr).subst(trefs.map(_.symbol), typeDefs.map(_.symbol))
87+
}
88+
}
4089

4190
/** The current staging level */
4291
private var currentLevel = 0
4392

4493
/** The splices encountered so far, indexed by staging level */
45-
private val splicesAtLevel = mutable.ArrayBuffer(new mutable.ListBuffer[Tree])
94+
private val splicesAtLevel = mutable.ArrayBuffer(new Splices)
4695

4796
// Invariant: -1 <= currentLevel <= splicesAtLevel.length
4897

@@ -54,34 +103,109 @@ class ReifyQuotes extends MacroTransform {
54103

55104
/** Enter staging level of symbol defined by `tree`, if applicable. */
56105
def markDef(tree: Tree)(implicit ctx: Context) = tree match {
57-
case tree: MemberDef if !levelOf.contains(tree.symbol) =>
106+
case tree: DefTree if !levelOf.contains(tree.symbol) =>
58107
levelOf(tree.symbol) = currentLevel
59108
enteredSyms = tree.symbol :: enteredSyms
60109
case _ =>
61110
}
62111

63-
/** If reference is to a locally defined symbol, check that its staging level
64-
* matches the current level.
112+
/** If `tree` refers to a locally defined symbol (either directly, or in a pickled type),
113+
* check that its staging level matches the current level. References to types
114+
* that are phase-incorrect can still be healed as follows.
115+
*
116+
* If `T` is a reference to a type at the wrong level, and there is an implicit value `tag`
117+
* of type `quoted.Type[T]`, transform `tag` yielding `tag1` and add the binding `T -> tag1`
118+
* to the `typeTagOfRef` map of the current `Splices` structure. These entries will be turned
119+
* info additional type definitions in method `addTags`.
65120
*/
66-
def checkLevel(tree: Tree)(implicit ctx: Context): Unit = {
67-
68-
def check(sym: Symbol, show: Symbol => String): Unit =
69-
if (!sym.isStaticOwner &&
70-
!ctx.owner.ownersIterator.exists(_.isInlineMethod) &&
71-
levelOf.getOrElse(sym, currentLevel) != currentLevel)
72-
ctx.error(em"""access to ${show(sym)} from wrong staging level:
73-
| - the definition is at level ${levelOf(sym)},
74-
| - but the access is at level $currentLevel.""", tree.pos)
75-
76-
def showThis(sym: Symbol) =
77-
if (sym.is(ModuleClass)) sym.sourceModule.show
78-
else i"${sym.name}.this"
79-
80-
val sym = tree.symbol
81-
if (sym.exists)
82-
if (tree.isInstanceOf[This]) check(sym, showThis)
83-
else if (sym.owner.isType) check(sym.owner, showThis)
84-
else check(sym, _.show)
121+
private def checkLevel(tree: Tree)(implicit ctx: Context): Tree = {
122+
123+
/** Try to heal phase-inconsistent reference to type `T` using a local type definition.
124+
* @return None if successful
125+
* @return Some(msg) if unsuccessful where `msg` is a potentially empty error message
126+
* to be added to the "inconsistent phase" message.
127+
*/
128+
def heal(tp: Type): Option[String] = tp match {
129+
case tp: TypeRef =>
130+
val reqType = defn.QuotedTypeType.appliedTo(tp)
131+
val tag = ctx.typer.inferImplicitArg(reqType, tree.pos)
132+
tag.tpe match {
133+
case fail: SearchFailureType =>
134+
Some(i"""
135+
|
136+
| The access would be accepted with the right type tag, but
137+
| ${ctx.typer.missingArgMsg(tag, reqType, "")}""")
138+
case _ =>
139+
splicesAtLevel(currentLevel).typeTagOfRef(tp) = {
140+
currentLevel -= 1
141+
try transform(tag) finally currentLevel += 1
142+
}
143+
None
144+
}
145+
case _ =>
146+
Some("")
147+
}
148+
149+
/** Check reference to `sym` for phase consistency, where `tp` is the underlying type
150+
* by which we refer to `sym`.
151+
*/
152+
def check(sym: Symbol, tp: Type): Unit = {
153+
val isThis = tp.isInstanceOf[ThisType]
154+
def symStr =
155+
if (!isThis) sym.show
156+
else if (sym.is(ModuleClass)) sym.sourceModule.show
157+
else i"${sym.name}.this"
158+
if (!isThis && sym.maybeOwner.isType)
159+
check(sym.owner, sym.owner.thisType)
160+
else if (sym.exists && !sym.isStaticOwner &&
161+
!ctx.owner.ownersIterator.exists(_.isInlineMethod) &&
162+
levelOf.getOrElse(sym, currentLevel) != currentLevel)
163+
heal(tp) match {
164+
case Some(errMsg) =>
165+
ctx.error(em"""access to $symStr from wrong staging level:
166+
| - the definition is at level ${levelOf(sym)},
167+
| - but the access is at level $currentLevel.$errMsg""", tree.pos)
168+
case None =>
169+
}
170+
}
171+
172+
/** Check all named types and this types in a given type for phase consistency */
173+
object checkType extends TypeAccumulator[Unit] {
174+
/** Check that all NamedType and ThisType parts of `tp` are level-correct.
175+
* If they are not, try to heal with a local binding to a typetag splice
176+
*/
177+
def apply(tp: Type): Unit = apply((), tp)
178+
def apply(acc: Unit, tp: Type): Unit = reporting.trace(i"check type level $tp at $currentLevel") {
179+
tp match {
180+
case tp: NamedType if isSplice(tp.symbol) =>
181+
currentLevel -= 1
182+
try foldOver(acc, tp) finally currentLevel += 1
183+
case tp: NamedType =>
184+
check(tp.symbol, tp)
185+
foldOver(acc, tp)
186+
case tp: ThisType =>
187+
check(tp.cls, tp)
188+
foldOver(acc, tp)
189+
case _ =>
190+
foldOver(acc, tp)
191+
}
192+
}
193+
}
194+
195+
tree match {
196+
case (_: Ident) | (_: This) =>
197+
check(tree.symbol, tree.tpe)
198+
case (_: UnApply) | (_: TypeTree) =>
199+
checkType(tree.tpe)
200+
case Select(qual, OuterSelectName(_, levels)) =>
201+
checkType(tree.tpe.widen)
202+
case _: Bind =>
203+
checkType(tree.symbol.info)
204+
case _: Template =>
205+
checkType(tree.symbol.owner.asClass.givenSelfType)
206+
case _ =>
207+
}
208+
tree
85209
}
86210

87211
/** Turn `body` of quote into a call of `scala.quoted.Unpickler.unpickleType` or
@@ -91,66 +215,64 @@ class ReifyQuotes extends MacroTransform {
91215
* - the serialized `body`, as returned from `pickleTree`
92216
* - all splices found in `body`
93217
*/
94-
private def reifyCall(body: Tree, isType: Boolean)(implicit ctx: Context) =
95-
ref(if (isType) defn.Unpickler_unpickleType else defn.Unpickler_unpickleExpr)
96-
.appliedToType(if (isType) body.tpe else body.tpe.widen)
97-
.appliedTo(
98-
Literal(Constant(pickleTree(body, isType))),
99-
SeqLiteral(splicesAtLevel(currentLevel).toList, TypeTree(defn.QuotedType)))
100-
101-
/** Perform operation `op` in quoted context */
102-
private def inQuote(op: => Tree)(implicit ctx: Context) = {
218+
private def reify(body: Tree, isType: Boolean)(implicit ctx: Context) = {
103219
currentLevel += 1
104220
if (currentLevel == splicesAtLevel.length) splicesAtLevel += null
221+
val splices = new Splices
105222
val savedSplices = splicesAtLevel(currentLevel)
106-
splicesAtLevel(currentLevel) = new mutable.ListBuffer[Tree]
107-
try op
223+
splicesAtLevel(currentLevel) = splices
224+
try {
225+
val body1 = splices.addTags(transform(body))
226+
ref(if (isType) defn.Unpickler_unpickleType else defn.Unpickler_unpickleExpr)
227+
.appliedToType(if (isType) body1.tpe else body1.tpe.widen)
228+
.appliedTo(
229+
Literal(Constant(pickleTree(body1, isType))),
230+
SeqLiteral(splices.buf.toList, TypeTree(defn.QuotedType)))
231+
}
108232
finally {
109233
splicesAtLevel(currentLevel) = savedSplices
110234
currentLevel -= 1
111235
}
112236
}
113237

114-
override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
115-
case Apply(fn, arg :: Nil) if fn.symbol == defn.quoteMethod =>
116-
inQuote(reifyCall(transform(arg), isType = false))
117-
case TypeApply(fn, arg :: Nil) if fn.symbol == defn.typeQuoteMethod =>
118-
inQuote(reifyCall(transform(arg), isType = true))
119-
case tree @ Select(body, name) if isSplice(tree) =>
120-
currentLevel -= 1
121-
val body1 = try transform(body) finally currentLevel += 1
122-
if (currentLevel > 0) {
123-
splicesAtLevel(currentLevel) += body1
124-
tree
125-
}
126-
else {
127-
if (currentLevel < 0)
128-
ctx.error(i"splice ~ not allowed under toplevel splice", tree.pos)
129-
cpy.Select(tree)(body1, name)
238+
override def transform(tree: Tree)(implicit ctx: Context): Tree =
239+
reporting.trace(i"reify $tree at $currentLevel", show = true) {
240+
tree match {
241+
case Apply(fn, arg :: Nil) if fn.symbol == defn.quoteMethod =>
242+
reify(arg, isType = false)
243+
case TypeApply(fn, arg :: Nil) if fn.symbol == defn.typeQuoteMethod =>
244+
reify(arg, isType = true)
245+
case tree @ Select(body, name) if isSplice(tree.symbol) =>
246+
currentLevel -= 1
247+
val body1 = try transform(body) finally currentLevel += 1
248+
if (currentLevel > 0) {
249+
splicesAtLevel(currentLevel).buf += body1
250+
tree
251+
}
252+
else {
253+
if (currentLevel < 0)
254+
ctx.error(i"splice ~ not allowed under toplevel splice", tree.pos)
255+
cpy.Select(tree)(body1, name)
256+
}
257+
case Block(stats, _) =>
258+
val last = enteredSyms
259+
stats.foreach(markDef)
260+
try super.transform(tree)
261+
finally
262+
while (enteredSyms ne last) {
263+
levelOf -= enteredSyms.head
264+
enteredSyms = enteredSyms.tail
265+
}
266+
case Inlined(call, bindings, expansion @ Select(body, name)) if isSplice(expansion.symbol) =>
267+
// To maintain phase consistency, convert inlined expressions of the form
268+
// `{ bindings; ~expansion }` to `~{ bindings; expansion }`
269+
cpy.Select(expansion)(cpy.Inlined(tree)(call, bindings, body), name)
270+
case _: Import =>
271+
tree
272+
case _ =>
273+
markDef(tree)
274+
checkLevel(super.transform(tree))
130275
}
131-
case (_: Ident) | (_: This) =>
132-
checkLevel(tree)
133-
super.transform(tree)
134-
case _: MemberDef =>
135-
markDef(tree)
136-
super.transform(tree)
137-
case Block(stats, _) =>
138-
val last = enteredSyms
139-
stats.foreach(markDef)
140-
try super.transform(tree)
141-
finally
142-
while (enteredSyms ne last) {
143-
levelOf -= enteredSyms.head
144-
enteredSyms = enteredSyms.tail
145-
}
146-
case Inlined(call, bindings, expansion @ Select(body, name)) if isSplice(expansion) =>
147-
// To maintain phase consistency, convert inlined expressions of the form
148-
// `{ bindings; ~expansion }` to `~{ bindings; expansion }`
149-
cpy.Select(expansion)(cpy.Inlined(tree)(call, bindings, body), name)
150-
case _: Import =>
151-
tree
152-
case _ =>
153-
super.transform(tree)
154-
}
276+
}
155277
}
156278
}

tests/neg/quoteTest.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ class Test {
1313
'(x + 1) // error: wrong staging level
1414

1515
'((y: Expr[Int]) => ~y ) // error: wrong staging level
16+
17+
def f[T](t: Type[T], x: Expr[T]) = '{
18+
val z2 = ~x // error: wrong staging level for type T
19+
}
20+
1621
}

tests/pos/quoteTest.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,23 @@ import scala.quoted._
22

33
object Test {
44

5-
def f[T](t: Type[T], x: Expr[T]) = '{
5+
def f[T](x: Expr[T])(t0: Type[T]) = {
6+
implicit val t: Type[T] = t0
7+
'{
8+
val y: t.unary_~ = x.unary_~
9+
val z = ~x
10+
}
11+
}
12+
13+
def f2[T](x: Expr[T])(implicit t: Type[T]) = '{
614
val y: t.unary_~ = x.unary_~
7-
val z: ~t = ~x
15+
val z = ~x
816
}
917

10-
f('[Int], '(2))
11-
f('[Boolean], '{ true })
18+
f('(2))('[Int])
19+
f('{ true })('[Boolean])
1220

1321
def g(es: Expr[String], t: Type[String]) =
14-
f('[List[~t]], '{ (~es + "!") :: Nil })
22+
f('{ (~es + "!") :: Nil })('[List[~t]])
1523
}
24+

0 commit comments

Comments
 (0)