Skip to content

Commit 1f7b03c

Browse files
authored
Use tree checker in macros (#16570)
Trees are only checked if `-Xcheck-macros` is enabled. Fixes: - Add missing positions to `{ValDef,Bind}.apply` - Inline by-name ascribed param - Unbound type variables after implicit search - Fixes #15779 - Fixes #16636
2 parents 4277d86 + aaf86b3 commit 1f7b03c

File tree

13 files changed

+228
-88
lines changed

13 files changed

+228
-88
lines changed

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class Inliner(val call: tpd.Tree)(using Context):
227227
val binding = {
228228
var newArg = arg.changeOwner(ctx.owner, boundSym)
229229
if bindingFlags.is(Inline) && argIsBottom then
230-
newArg = Typed(newArg, TypeTree(formal)) // type ascribe RHS to avoid type errors in expansion. See i8612.scala
230+
newArg = Typed(newArg, TypeTree(formal.widenExpr)) // type ascribe RHS to avoid type errors in expansion. See i8612.scala
231231
if isByName then DefDef(boundSym, newArg)
232232
else ValDef(boundSym, newArg)
233233
}.withSpan(boundSym.span)
@@ -816,6 +816,7 @@ class Inliner(val call: tpd.Tree)(using Context):
816816
&& StagingContext.level == 0
817817
&& !hasInliningErrors =>
818818
val expanded = expandMacro(res.args.head, tree.srcPos)
819+
transform.TreeChecker.checkMacroGeneratedTree(res, expanded)
819820
typedExpr(expanded) // Inline calls and constant fold code generated by the macro
820821
case res =>
821822
specializeEq(inlineIfNeeded(res, pt, locked))

compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ class MacroAnnotations:
8282
case (prefixed, newTree :: suffixed) =>
8383
allTrees ++= prefixed
8484
insertedAfter = suffixed :: insertedAfter
85-
prefixed.foreach(checkMacroDef(_, tree.symbol, annot))
86-
suffixed.foreach(checkMacroDef(_, tree.symbol, annot))
85+
prefixed.foreach(checkMacroDef(_, tree, annot))
86+
suffixed.foreach(checkMacroDef(_, tree, annot))
87+
transform.TreeChecker.checkMacroGeneratedTree(tree, newTree)
8788
newTree
8889
case (Nil, Nil) =>
8990
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
@@ -119,8 +120,10 @@ class MacroAnnotations:
119120
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])
120121

121122
/** Check that this tree can be added by the macro annotation */
122-
private def checkMacroDef(newTree: DefTree, annotated: Symbol, annot: Annotation)(using Context) =
123+
private def checkMacroDef(newTree: DefTree, annotatedTree: Tree, annot: Annotation)(using Context) =
124+
transform.TreeChecker.checkMacroGeneratedTree(annotatedTree, newTree)
123125
val sym = newTree.symbol
126+
val annotated = annotatedTree.symbol
124127
if sym.isType && !sym.isClass then
125128
report.error(i"macro annotation cannot return a `type`. $annot tried to add $sym", annot.tree)
126129
else if sym.owner != annotated.owner && !(annotated.owner.isPackageObject && (sym.isClass || sym.is(Module)) && sym.owner == annotated.owner.owner) then

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 116 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ class TreeChecker extends Phase with SymTransformer {
4242
private val seenClasses = collection.mutable.HashMap[String, Symbol]()
4343
private val seenModuleVals = collection.mutable.HashMap[String, Symbol]()
4444

45-
def isValidJVMName(name: Name): Boolean = name.toString.forall(isValidJVMChar)
46-
47-
def isValidJVMMethodName(name: Name): Boolean = name.toString.forall(isValidJVMMethodChar)
48-
4945
val NoSuperClassFlags: FlagSet = Trait | Package
5046

5147
def testDuplicate(sym: Symbol, registry: mutable.Map[String, Symbol], typ: String)(using Context): Unit = {
@@ -109,18 +105,6 @@ class TreeChecker extends Phase with SymTransformer {
109105
else if (ctx.phase.prev.isCheckable)
110106
check(ctx.base.allPhases.toIndexedSeq, ctx)
111107

112-
private def previousPhases(phases: List[Phase])(using Context): List[Phase] = phases match {
113-
case (phase: MegaPhase) :: phases1 =>
114-
val subPhases = phase.miniPhases
115-
val previousSubPhases = previousPhases(subPhases.toList)
116-
if (previousSubPhases.length == subPhases.length) previousSubPhases ::: previousPhases(phases1)
117-
else previousSubPhases
118-
case phase :: phases1 if phase ne ctx.phase =>
119-
phase :: previousPhases(phases1)
120-
case _ =>
121-
Nil
122-
}
123-
124108
def check(phasesToRun: Seq[Phase], ctx: Context): Tree = {
125109
val fusedPhase = ctx.phase.prevMega(using ctx)
126110
report.echo(s"checking ${ctx.compilationUnit} after phase ${fusedPhase}")(using ctx)
@@ -134,7 +118,6 @@ class TreeChecker extends Phase with SymTransformer {
134118

135119
val checkingCtx = ctx
136120
.fresh
137-
.addMode(Mode.ImplicitsEnabled)
138121
.setReporter(new ThrowingReporter(ctx.reporter))
139122

140123
val checker = inContext(ctx) {
@@ -150,9 +133,80 @@ class TreeChecker extends Phase with SymTransformer {
150133
}
151134
}
152135

136+
/**
137+
* Checks that `New` nodes are always wrapped inside `Select` nodes.
138+
*/
139+
def assertSelectWrapsNew(tree: Tree)(using Context): Unit =
140+
(new TreeAccumulator[tpd.Tree] {
141+
override def apply(parent: Tree, tree: Tree)(using Context): Tree = {
142+
tree match {
143+
case tree: New if !parent.isInstanceOf[tpd.Select] =>
144+
assert(assertion = false, i"`New` node must be wrapped in a `Select`:\n parent = ${parent.show}\n child = ${tree.show}")
145+
case _: Annotated =>
146+
// Don't check inside annotations, since they're allowed to contain
147+
// somewhat invalid trees.
148+
case _ =>
149+
foldOver(tree, tree) // replace the parent when folding over the children
150+
}
151+
parent // return the old parent so that my siblings see it
152+
}
153+
})(tpd.EmptyTree, tree)
154+
}
155+
156+
object TreeChecker {
157+
/** - Check that TypeParamRefs and MethodParams refer to an enclosing type.
158+
* - Check that all type variables are instantiated.
159+
*/
160+
def checkNoOrphans(tp0: Type, tree: untpd.Tree = untpd.EmptyTree)(using Context): Type = new TypeMap() {
161+
val definedBinders = new java.util.IdentityHashMap[Type, Any]
162+
def apply(tp: Type): Type = {
163+
tp match {
164+
case tp: BindingType =>
165+
definedBinders.put(tp, tp)
166+
mapOver(tp)
167+
definedBinders.remove(tp)
168+
case tp: ParamRef =>
169+
assert(definedBinders.get(tp.binder) != null, s"orphan param: ${tp.show}, hash of binder = ${System.identityHashCode(tp.binder)}, tree = ${tree.show}, type = $tp0")
170+
case tp: TypeVar =>
171+
assert(tp.isInstantiated, s"Uninstantiated type variable: ${tp.show}, tree = ${tree.show}")
172+
apply(tp.underlying)
173+
case _ =>
174+
mapOver(tp)
175+
}
176+
tp
177+
}
178+
}.apply(tp0)
179+
180+
/** Run some additional checks on the nodes of the trees. Specifically:
181+
*
182+
* - TypeTree can only appear in TypeApply args, New, Typed tpt, Closure
183+
* tpt, SeqLiteral elemtpt, ValDef tpt, DefDef tpt, and TypeDef rhs.
184+
*/
185+
object TreeNodeChecker extends untpd.TreeTraverser:
186+
import untpd._
187+
def traverse(tree: Tree)(using Context) = tree match
188+
case t: TypeTree => assert(assertion = false, i"TypeTree not expected: $t")
189+
case t @ TypeApply(fun, _targs) => traverse(fun)
190+
case t @ New(_tpt) =>
191+
case t @ Typed(expr, _tpt) => traverse(expr)
192+
case t @ Closure(env, meth, _tpt) => traverse(env); traverse(meth)
193+
case t @ SeqLiteral(elems, _elemtpt) => traverse(elems)
194+
case t @ ValDef(_, _tpt, _) => traverse(t.rhs)
195+
case t @ DefDef(_, paramss, _tpt, _) => for params <- paramss do traverse(params); traverse(t.rhs)
196+
case t @ TypeDef(_, _rhs) =>
197+
case t @ Template(constr, parents, self, _) => traverse(constr); traverse(parents); traverse(self); traverse(t.body)
198+
case t => traverseChildren(t)
199+
end traverse
200+
201+
private[TreeChecker] def isValidJVMName(name: Name): Boolean = name.toString.forall(isValidJVMChar)
202+
203+
private[TreeChecker] def isValidJVMMethodName(name: Name): Boolean = name.toString.forall(isValidJVMMethodChar)
204+
205+
153206
class Checker(phasesToCheck: Seq[Phase]) extends ReTyper with Checking {
207+
import ast.tpd._
154208

155-
private val nowDefinedSyms = util.HashSet[Symbol]()
209+
protected val nowDefinedSyms = util.HashSet[Symbol]()
156210
private val patBoundSyms = util.HashSet[Symbol]()
157211
private val everDefinedSyms = MutableSymbolMap[untpd.Tree]()
158212

@@ -658,68 +712,50 @@ class TreeChecker extends Phase with SymTransformer {
658712
override def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = tree
659713
}
660714

661-
/**
662-
* Checks that `New` nodes are always wrapped inside `Select` nodes.
663-
*/
664-
def assertSelectWrapsNew(tree: Tree)(using Context): Unit =
665-
(new TreeAccumulator[tpd.Tree] {
666-
override def apply(parent: Tree, tree: Tree)(using Context): Tree = {
667-
tree match {
668-
case tree: New if !parent.isInstanceOf[tpd.Select] =>
669-
assert(assertion = false, i"`New` node must be wrapped in a `Select`:\n parent = ${parent.show}\n child = ${tree.show}")
670-
case _: Annotated =>
671-
// Don't check inside annotations, since they're allowed to contain
672-
// somewhat invalid trees.
673-
case _ =>
674-
foldOver(tree, tree) // replace the parent when folding over the children
675-
}
676-
parent // return the old parent so that my siblings see it
677-
}
678-
})(tpd.EmptyTree, tree)
679-
}
715+
/** Tree checker that can be applied to a local tree. */
716+
class LocalChecker(phasesToCheck: Seq[Phase]) extends Checker(phasesToCheck: Seq[Phase]):
717+
override def assertDefined(tree: untpd.Tree)(using Context): Unit =
718+
// Only check definitions nested in the local tree
719+
if nowDefinedSyms.contains(tree.symbol.maybeOwner) then
720+
super.assertDefined(tree)
680721

681-
object TreeChecker {
682-
/** - Check that TypeParamRefs and MethodParams refer to an enclosing type.
683-
* - Check that all type variables are instantiated.
684-
*/
685-
def checkNoOrphans(tp0: Type, tree: untpd.Tree = untpd.EmptyTree)(using Context): Type = new TypeMap() {
686-
val definedBinders = new java.util.IdentityHashMap[Type, Any]
687-
def apply(tp: Type): Type = {
688-
tp match {
689-
case tp: BindingType =>
690-
definedBinders.put(tp, tp)
691-
mapOver(tp)
692-
definedBinders.remove(tp)
693-
case tp: ParamRef =>
694-
assert(definedBinders.get(tp.binder) != null, s"orphan param: ${tp.show}, hash of binder = ${System.identityHashCode(tp.binder)}, tree = ${tree.show}, type = $tp0")
695-
case tp: TypeVar =>
696-
assert(tp.isInstantiated, s"Uninstantiated type variable: ${tp.show}, tree = ${tree.show}")
697-
apply(tp.underlying)
698-
case _ =>
699-
mapOver(tp)
700-
}
701-
tp
702-
}
703-
}.apply(tp0)
722+
def checkMacroGeneratedTree(original: tpd.Tree, expansion: tpd.Tree)(using Context): Unit =
723+
if ctx.settings.XcheckMacros.value then
724+
val checkingCtx = ctx
725+
.fresh
726+
.setReporter(new ThrowingReporter(ctx.reporter))
727+
val phases = ctx.base.allPhases.toList
728+
val treeChecker = new LocalChecker(previousPhases(phases))
729+
730+
try treeChecker.typed(expansion)(using checkingCtx)
731+
catch
732+
case err: java.lang.AssertionError =>
733+
report.error(
734+
s"""Malformed tree was found while expanding macro with -Xcheck-macros.
735+
|The tree does not conform to the compiler's tree invariants.
736+
|
737+
|Macro was:
738+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
739+
|
740+
|The macro returned:
741+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
742+
|
743+
|Error:
744+
|${err.getMessage}
745+
|
746+
|""",
747+
original
748+
)
704749

705-
/** Run some additional checks on the nodes of the trees. Specifically:
706-
*
707-
* - TypeTree can only appear in TypeApply args, New, Typed tpt, Closure
708-
* tpt, SeqLiteral elemtpt, ValDef tpt, DefDef tpt, and TypeDef rhs.
709-
*/
710-
object TreeNodeChecker extends untpd.TreeTraverser:
711-
import untpd._
712-
def traverse(tree: Tree)(using Context) = tree match
713-
case t: TypeTree => assert(assertion = false, i"TypeTree not expected: $t")
714-
case t @ TypeApply(fun, _targs) => traverse(fun)
715-
case t @ New(_tpt) =>
716-
case t @ Typed(expr, _tpt) => traverse(expr)
717-
case t @ Closure(env, meth, _tpt) => traverse(env); traverse(meth)
718-
case t @ SeqLiteral(elems, _elemtpt) => traverse(elems)
719-
case t @ ValDef(_, _tpt, _) => traverse(t.rhs)
720-
case t @ DefDef(_, paramss, _tpt, _) => for params <- paramss do traverse(params); traverse(t.rhs)
721-
case t @ TypeDef(_, _rhs) =>
722-
case t @ Template(constr, parents, self, _) => traverse(constr); traverse(parents); traverse(self); traverse(t.body)
723-
case t => traverseChildren(t)
724-
end traverse
750+
private[TreeChecker] def previousPhases(phases: List[Phase])(using Context): List[Phase] = phases match {
751+
case (phase: MegaPhase) :: phases1 =>
752+
val subPhases = phase.miniPhases
753+
val previousSubPhases = previousPhases(subPhases.toList)
754+
if (previousSubPhases.length == subPhases.length) previousSubPhases ::: previousPhases(phases1)
755+
else previousSubPhases
756+
case phase :: phases1 if phase ne ctx.phase =>
757+
phase :: previousPhases(phases1)
758+
case _ =>
759+
Nil
760+
}
725761
}

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
307307

308308
object ValDef extends ValDefModule:
309309
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
310-
tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree))
310+
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
311311
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
312312
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
313313
def unapply(vdef: ValDef): (String, TypeTree, Option[Term]) =
@@ -1483,7 +1483,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
14831483

14841484
object Bind extends BindModule:
14851485
def apply(sym: Symbol, pattern: Tree): Bind =
1486-
tpd.Bind(sym, pattern)
1486+
withDefaultPos(tpd.Bind(sym, pattern))
14871487
def copy(original: Tree)(name: String, pattern: Tree): Bind =
14881488
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))
14891489
def unapply(pattern: Bind): (String, Tree) =
@@ -2404,7 +2404,13 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
24042404

24052405
object Implicits extends ImplicitsModule:
24062406
def search(tpe: TypeRepr): ImplicitSearchResult =
2407-
ctx.typer.inferImplicitArg(tpe, Position.ofMacroExpansion.span)
2407+
import tpd.TreeOps
2408+
val implicitTree = ctx.typer.inferImplicitArg(tpe, Position.ofMacroExpansion.span)
2409+
// Make sure that we do not have any uninstantiated type variables.
2410+
// See tests/pos-macros/i16636.
2411+
// See tests/pos-macros/exprSummonWithTypeVar with -Xcheck-macros.
2412+
dotc.typer.Inferencing.fullyDefinedType(implicitTree.tpe, "", implicitTree)
2413+
implicitTree
24082414
end Implicits
24092415

24102416
type ImplicitSearchResult = Tree

tests/neg-macros/annot-mod-top-method-add-top-method/Macro_1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ class addTopLevelMethodOutsidePackageObject extends MacroAnnotation:
99
import quotes.reflect._
1010
val methType = MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Int])
1111
val methSym = Symbol.newMethod(Symbol.spliceOwner.owner, Symbol.freshName("toLevelMethod"), methType, Flags.EmptyFlags, Symbol.noSymbol)
12-
val methDef = ValDef(methSym, Some(Literal(IntConstant(1))))
12+
val methDef = DefDef(methSym, _ => Some(Literal(IntConstant(1))))
1313
List(methDef, tree)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.compiletime.{erasedValue, summonFrom}
2+
3+
import scala.quoted._
4+
5+
inline given summonAfterTypeMatch[T]: Any =
6+
${ summonAfterTypeMatchExpr[T] }
7+
8+
private def summonAfterTypeMatchExpr[T: Type](using Quotes): Expr[Any] =
9+
Expr.summon[Foo[T]].get
10+
11+
trait Foo[T]
12+
13+
given IntFoo[T <: Int]: Foo[T] = ???
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
def test: Unit = summonAfterTypeMatch[Int]

tests/pos-macros/i15779/Macro_1.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import scala.quoted._
2+
import scala.deriving.Mirror
3+
4+
trait Encoder[-A]
5+
6+
trait PrimitiveEncoder[A] extends Encoder[A]
7+
8+
given intOpt: PrimitiveEncoder[Option[Int]] with {}
9+
10+
given primitiveNotNull[T](using e: Encoder[Option[T]]): PrimitiveEncoder[T] =
11+
new PrimitiveEncoder[T] {}
12+
13+
transparent inline given fromMirror[A]: Any = ${ fromMirrorImpl[A] }
14+
15+
def fromMirrorImpl[A : Type](using q: Quotes): Expr[Any] =
16+
Expr.summon[Mirror.Of[A]].get match
17+
case '{ ${mirror}: Mirror.ProductOf[A] { type MirroredElemTypes = elementTypes } } =>
18+
val encoder = Type.of[elementTypes] match
19+
case '[tpe *: EmptyTuple] =>
20+
Expr.summon[Encoder[tpe]].get
21+
22+
encoder match
23+
case '{ ${encoder}: Encoder[tpe] } => // ok
24+
case _ => ???
25+
26+
encoder match
27+
case '{ ${encoder}: Encoder[tpe] } => // ok
28+
case _ => ???
29+
30+
encoder

tests/pos-macros/i15779/Test_2.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
case class JustInt(i: Int)
2+
3+
val x = fromMirror[JustInt]

tests/pos-macros/i16636/Macro_1.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import scala.quoted.*
2+
3+
trait ReproTransformer[A, B] {
4+
def transform(from: A): B
5+
}
6+
7+
object ReproTransformer {
8+
final class Identity[A, B >: A] extends ReproTransformer[A, B] {
9+
def transform(from: A): B = from
10+
}
11+
12+
given identity[A, B >: A]: Identity[A, B] = Identity[A, B]
13+
14+
inline def getTransformer[A, B]: ReproTransformer[A, B] = ${ getTransformerMacro[A, B] }
15+
16+
def getTransformerMacro[A, B](using quotes: Quotes, A: Type[A], B: Type[B]) = {
17+
import quotes.reflect.*
18+
19+
val transformer = (A -> B) match {
20+
case '[a] -> '[b] =>
21+
val summoned = Expr.summon[ReproTransformer[a, b]].get
22+
// ----------- INTERESTING STUFF STARTS HERE
23+
summoned match {
24+
case '{ $t: ReproTransformer[src, dest] } => t
25+
}
26+
// ----------- INTERESTING STUFF ENDS HERE
27+
}
28+
transformer.asExprOf[ReproTransformer[A, B]]
29+
}
30+
}

0 commit comments

Comments
 (0)