Skip to content

Commit 1bc6cb9

Browse files
committed
Make import contexts visible in macro transforms
Todo: We have very similar and involved operations now sor handling statements in tpd.TreeMapWithPreciseStatContexts and MegaPhase. Can we factor out the common logic? But we should not create any closures doing so, to keep things fast.
1 parent b1aee82 commit 1bc6cb9

File tree

3 files changed

+49
-52
lines changed

3 files changed

+49
-52
lines changed

compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,12 @@ import scala.annotation.tailrec
1515
*
1616
* This incudes implicits defined in scope as well as imported implicits.
1717
*/
18-
class TreeMapWithImplicits extends tpd.TreeMap {
18+
class TreeMapWithImplicits extends tpd.TreeMapWithPreciseStatContexts {
1919
import tpd._
2020

2121
def transformSelf(vd: ValDef)(using Context): ValDef =
2222
cpy.ValDef(vd)(tpt = transform(vd.tpt))
2323

24-
/** Transform statements, while maintaining import contexts and expression contexts
25-
* in the same way as Typer does. The code addresses additional concerns:
26-
* - be tail-recursive where possible
27-
* - don't re-allocate trees where nothing has changed
28-
*/
29-
override def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
30-
31-
@tailrec def traverse(curStats: List[Tree])(using Context): List[Tree] = {
32-
33-
def recur(stats: List[Tree], changed: Tree, rest: List[Tree])(using Context): List[Tree] =
34-
if (stats eq curStats) {
35-
val rest1 = transformStats(rest, exprOwner)
36-
changed match {
37-
case Thicket(trees) => trees ::: rest1
38-
case tree => tree :: rest1
39-
}
40-
}
41-
else stats.head :: recur(stats.tail, changed, rest)
42-
43-
curStats match {
44-
case stat :: rest =>
45-
val statCtx = stat match {
46-
case stat: DefTree => ctx
47-
case _ => ctx.exprContext(stat, exprOwner)
48-
}
49-
val restCtx = stat match {
50-
case stat: Import => ctx.importContext(stat, stat.symbol)
51-
case _ => ctx
52-
}
53-
val stat1 = transform(stat)(using statCtx)
54-
if (stat1 ne stat) recur(stats, stat1, rest)(using restCtx)
55-
else traverse(rest)(using restCtx)
56-
case nil =>
57-
stats
58-
}
59-
}
60-
traverse(stats)
61-
}
62-
6324
private def nestedScopeCtx(defs: List[Tree])(using Context): Context = {
6425
val nestedCtx = ctx.fresh.setNewScope
6526
defs foreach {

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1153,10 +1153,54 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11531153
recur(trees, 0)
11541154
end extension
11551155

1156+
/** A treemap that generates the same contexts as the original typer for statements.
1157+
* This means:
1158+
* - statements that are not definitions get the exprOwner as owner
1159+
* - imports are reflected in the contexts of subsequent statements
1160+
*/
1161+
class TreeMapWithPreciseStatContexts(cpy: TreeCopier = tpd.cpy) extends TreeMap(cpy):
1162+
1163+
/** Transform statements, while maintaining import contexts and expression contexts
1164+
* in the same way as Typer does. The code addresses additional concerns:
1165+
* - be tail-recursive where possible
1166+
* - don't re-allocate trees where nothing has changed
1167+
*/
1168+
override def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
1169+
1170+
@tailrec def traverse(curStats: List[Tree])(using Context): List[Tree] =
1171+
1172+
def recur(stats: List[Tree], changed: Tree, rest: List[Tree])(using Context): List[Tree] =
1173+
if stats eq curStats then
1174+
val rest1 = transformStats(rest, exprOwner)
1175+
changed match
1176+
case Thicket(trees) => trees ::: rest1
1177+
case tree => tree :: rest1
1178+
else
1179+
stats.head :: recur(stats.tail, changed, rest)
1180+
1181+
curStats match
1182+
case stat :: rest =>
1183+
val statCtx = stat match
1184+
case _: DefTree | _: ImportOrExport => ctx
1185+
case _ => ctx.exprContext(stat, exprOwner)
1186+
val restCtx = stat match
1187+
case stat: Import => ctx.importContext(stat, stat.symbol)
1188+
case _ => ctx
1189+
val stat1 = transform(stat)(using statCtx)
1190+
if stat1 ne stat then recur(stats, stat1, rest)(using restCtx)
1191+
else traverse(rest)(using restCtx)
1192+
case nil =>
1193+
stats
1194+
1195+
traverse(stats)
1196+
end transformStats
1197+
1198+
end TreeMapWithPreciseStatContexts
1199+
11561200
/** Map Inlined nodes, NamedArgs, Blocks with no statements and local references to underlying arguments.
11571201
* Also drops Inline and Block with no statements.
11581202
*/
1159-
class MapToUnderlying extends TreeMap {
1203+
private class MapToUnderlying extends TreeMap {
11601204
override def transform(tree: Tree)(using Context): Tree = tree match {
11611205
case tree: Ident if isBinding(tree.symbol) && skipLocal(tree.symbol) =>
11621206
tree.symbol.defTree match {

compiler/src/dotty/tools/dotc/transform/MacroTransform.scala

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,11 @@ abstract class MacroTransform extends Phase {
2929
*/
3030
protected def transformPhase(using Context): Phase = this
3131

32-
class Transformer extends TreeMap(cpy = cpyBetweenPhases) {
32+
class Transformer extends TreeMapWithPreciseStatContexts(cpy = cpyBetweenPhases):
3333

34-
protected def localCtx(tree: Tree)(using Context): FreshContext =
34+
protected def localCtx(tree: Tree)(using Context): FreshContext =
3535
ctx.fresh.setTree(tree).setOwner(localOwner(tree))
3636

37-
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
38-
def transformStat(stat: Tree): Tree = stat match {
39-
case _: Import | _: DefTree => transform(stat)
40-
case _ => transform(stat)(using ctx.exprContext(stat, exprOwner))
41-
}
42-
flatten(trees.mapconserve(transformStat(_)))
43-
}
44-
4537
override def transform(tree: Tree)(using Context): Tree =
4638
try
4739
tree match {
@@ -67,5 +59,5 @@ abstract class MacroTransform extends Phase {
6759

6860
def transformSelf(vd: ValDef)(using Context): ValDef =
6961
cpy.ValDef(vd)(tpt = transform(vd.tpt))
70-
}
62+
end Transformer
7163
}

0 commit comments

Comments
 (0)