Skip to content

Commit 70c85a6

Browse files
Support erased arguments in splicer
1 parent bbc0df7 commit 70c85a6

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,28 @@ object Splicer {
321321
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
322322
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
323323

324+
private final def removeEraisedArguments(args: List[Tree], fnTpe: Type): List[Tree] = {
325+
var result = args
326+
var index = 0
327+
def loop(tp: Type): Unit = tp match {
328+
case tp: TermRef => loop(tp.underlying)
329+
case tp: PolyType => loop(tp.resType)
330+
case tp: MethodType if tp.isErasedMethod =>
331+
tp.paramInfos.foreach { _ =>
332+
result = result.updated(index, null)
333+
index += 1
334+
}
335+
loop(tp.resType)
336+
case tp: MethodType =>
337+
index += tp.paramInfos.size
338+
loop(tp.resType)
339+
case _ => ()
340+
}
341+
loop(fnTpe)
342+
assert(index == args.size)
343+
result.filterNot(null.eq)
344+
}
345+
324346
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
325347
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
326348
val quoted1 = quoted match {
@@ -351,10 +373,12 @@ object Splicer {
351373
interpretModuleAccess(fn.symbol)
352374
} else if (fn.symbol.isStatic) {
353375
val module = fn.symbol.owner
354-
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
376+
def interpretedArgs = removeEraisedArguments(args, fn.tpe).map(arg => interpretTree(arg))
377+
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
355378
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
356379
val module = fn.qualifier.symbol.moduleClass
357-
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
380+
def interpretedArgs = removeEraisedArguments(args, fn.tpe).map(arg => interpretTree(arg))
381+
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
358382
} else if (env.contains(fn.name)) {
359383
env(fn.name)
360384
} else if (tree.symbol.is(InlineProxy)) {
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.quoted._
2+
3+
object Macro {
4+
inline def foo1(i: Int) = $ { case1('{ i }) }
5+
inline def foo2(i: Int) = $ { case2(1)('{ i }) }
6+
inline def foo3(i: Int) = $ { case3('{ i })(1) }
7+
inline def foo4(i: Int) = $ { case4(1)('{ i }, '{ i }) }
8+
inline def foo5(i: Int) = $ { case5('{ i }, '{ i })(1) }
9+
inline def foo6(i: Int) = $ { case6(1)('{ i })('{ i }) }
10+
inline def foo7(i: Int) = $ { case7('{ i })(1)('{ i }) }
11+
inline def foo8(i: Int) = $ { case8('{ i })('{ i })(1) }
12+
13+
def case1 erased (i: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
14+
def case2 (i: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
15+
def case3 erased (i: Expr[Int]) (j: Int) given (QuoteContext): Expr[Int] = '{ 0 }
16+
def case4 (h: Int) erased (i: Expr[Int], j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
17+
def case5 erased (i: Expr[Int], j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 }
18+
def case6 (h: Int) erased (i: Expr[Int]) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
19+
def case7 erased (i: Expr[Int]) (h: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
20+
def case8 erased (i: Expr[Int]) erased (j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 }
21+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Test {
2+
assert(Macro.foo1(1) == 0)
3+
assert(Macro.foo2(1) == 0)
4+
assert(Macro.foo3(1) == 0)
5+
assert(Macro.foo4(1) == 0)
6+
assert(Macro.foo5(1) == 0)
7+
assert(Macro.foo6(1) == 0)
8+
assert(Macro.foo7(1) == 0)
9+
assert(Macro.foo8(1) == 0)
10+
}

0 commit comments

Comments
 (0)