Skip to content

Commit 3c7011b

Browse files
nicolasstuckiNicolas Stucki
authored and
Nicolas Stucki
committed
Use tree checker for macro expanded trees
Trees are only checked if -Xcheck-macros is enabled. Fixes: - Add missing positions to {ValDef,Bind}.apply - Inline by-name ascribed param - Unbound type variables after implicit search
1 parent 2920a4f commit 3c7011b

File tree

8 files changed

+75
-8
lines changed

8 files changed

+75
-8
lines changed

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class Inliner(val call: tpd.Tree)(using Context):
227227
val binding = {
228228
var newArg = arg.changeOwner(ctx.owner, boundSym)
229229
if bindingFlags.is(Inline) && argIsBottom then
230-
newArg = Typed(newArg, TypeTree(formal)) // type ascribe RHS to avoid type errors in expansion. See i8612.scala
230+
newArg = Typed(newArg, TypeTree(formal.widenExpr)) // type ascribe RHS to avoid type errors in expansion. See i8612.scala
231231
if isByName then DefDef(boundSym, newArg)
232232
else ValDef(boundSym, newArg)
233233
}.withSpan(boundSym.span)
@@ -816,6 +816,7 @@ class Inliner(val call: tpd.Tree)(using Context):
816816
&& StagingContext.level == 0
817817
&& !hasInliningErrors =>
818818
val expanded = expandMacro(res.args.head, tree.srcPos)
819+
transform.TreeChecker.checkMacroGeneratedTree(res, expanded)
819820
typedExpr(expanded) // Inline calls and constant fold code generated by the macro
820821
case res =>
821822
specializeEq(inlineIfNeeded(res, pt, locked))

compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ class MacroAnnotations(thisPhase: DenotTransformer):
8282
case (prefixed, newTree :: suffixed) =>
8383
allTrees ++= prefixed
8484
insertedAfter = suffixed :: insertedAfter
85-
prefixed.foreach(checkAndEnter(_, tree.symbol, annot))
86-
suffixed.foreach(checkAndEnter(_, tree.symbol, annot))
85+
prefixed.foreach(checkAndEnter(_, tree, annot))
86+
suffixed.foreach(checkAndEnter(_, tree, annot))
87+
transform.TreeChecker.checkMacroGeneratedTree(tree, newTree)
8788
newTree
8889
case (Nil, Nil) =>
8990
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
@@ -119,8 +120,10 @@ class MacroAnnotations(thisPhase: DenotTransformer):
119120
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])
120121

121122
/** Check that this tree can be added by the macro annotation and enter it if needed */
122-
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
123+
private def checkAndEnter(newTree: Tree, annotatedTree: Tree, annot: Annotation)(using Context) =
124+
transform.TreeChecker.checkMacroGeneratedTree(annotatedTree, newTree)
123125
val sym = newTree.symbol
126+
val annotated = annotatedTree.symbol
124127
if sym.isClass then
125128
report.error(i"macro annotation returning a `class` is not yet supported. $annot tried to add $sym", annot.tree)
126129
else if sym.isType then

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ object TreeChecker {
219219
class Checker(phasesToCheck: Seq[Phase]) extends ReTyper with Checking {
220220
import ast.tpd._
221221

222-
private val nowDefinedSyms = util.HashSet[Symbol]()
222+
protected val nowDefinedSyms = util.HashSet[Symbol]()
223223
private val patBoundSyms = util.HashSet[Symbol]()
224224
private val everDefinedSyms = MutableSymbolMap[untpd.Tree]()
225225

@@ -724,4 +724,42 @@ object TreeChecker {
724724

725725
override def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = tree
726726
}
727+
728+
/** Tree checker that can be applied to a local tree. */
729+
class LocalChecker(phasesToCheck: Seq[Phase]) extends Checker(phasesToCheck: Seq[Phase]):
730+
override def assertDefined(tree: untpd.Tree)(using Context): Unit =
731+
// Only check definitions nested in the local tree
732+
if nowDefinedSyms.contains(tree.symbol.maybeOwner) then
733+
super.assertDefined(tree)
734+
735+
def checkMacroGeneratedTree(original: tpd.Tree, expansion: tpd.Tree)(using Context): Unit =
736+
if ctx.settings.XcheckMacros.value then
737+
val checkingCtx = ctx
738+
.fresh
739+
.addMode(Mode.ImplicitsEnabled)
740+
.setReporter(new ThrowingReporter(ctx.reporter))
741+
742+
val treeChecker =
743+
new LocalChecker(Nil) // TODO enable previous phase post-conditions
744+
745+
try treeChecker.typed(expansion)(using checkingCtx)
746+
catch
747+
case err: java.lang.AssertionError =>
748+
report.error(
749+
s"""Malformed tree was found while expanding macro with -Xcheck-macros.
750+
|The tree does not conform to the compiler's tree invariants.
751+
|
752+
|Macro was:
753+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
754+
|
755+
|The macro returned:
756+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
757+
|
758+
|Error:
759+
|${err.getMessage}
760+
|
761+
|""",
762+
original
763+
)
764+
727765
}

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
298298

299299
object ValDef extends ValDefModule:
300300
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
301-
tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree))
301+
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
302302
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
303303
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
304304
def unapply(vdef: ValDef): (String, TypeTree, Option[Term]) =
@@ -1474,7 +1474,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
14741474

14751475
object Bind extends BindModule:
14761476
def apply(sym: Symbol, pattern: Tree): Bind =
1477-
tpd.Bind(sym, pattern)
1477+
withDefaultPos(tpd.Bind(sym, pattern))
14781478
def copy(original: Tree)(name: String, pattern: Tree): Bind =
14791479
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))
14801480
def unapply(pattern: Bind): (String, Tree) =
@@ -2395,7 +2395,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
23952395

23962396
object Implicits extends ImplicitsModule:
23972397
def search(tpe: TypeRepr): ImplicitSearchResult =
2398-
ctx.typer.inferImplicitArg(tpe, Position.ofMacroExpansion.span)
2398+
val implicitTree = ctx.typer.inferImplicitArg(tpe, Position.ofMacroExpansion.span)
2399+
import tpd.TreeOps
2400+
implicitTree.foreachSubTree(tree => dotc.typer.Inferencing.fullyDefinedType(tree.tpe, "", tree))
2401+
implicitTree
23992402
end Implicits
24002403

24012404
type ImplicitSearchResult = Tree
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.compiletime.{erasedValue, summonFrom}
2+
3+
import scala.quoted._
4+
5+
inline given summonAfterTypeMatch[T]: Any =
6+
${ summonAfterTypeMatchExpr[T] }
7+
8+
private def summonAfterTypeMatchExpr[T: Type](using Quotes): Expr[Any] =
9+
Expr.summon[Foo[T]].get
10+
11+
trait Foo[T]
12+
13+
given IntFoo[T <: Int]: Foo[T] = ???
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
def test: Unit = summonAfterTypeMatch[Int]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import scala.quoted.*
2+
3+
inline def f[T](inline code: =>T): Any =
4+
${ create[T]('{ () => code }) }
5+
6+
def create[T: Type](code: Expr[() => T])(using Quotes): Expr[Any] =
7+
'{ identity($code) }
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
def test: Unit = f[Unit](???)

0 commit comments

Comments
 (0)