From 056fbc50396d15c6e9c3a828c9b355d6caf5444a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franc=CC=A7ois=20Monniot?= Date: Thu, 27 Jul 2023 15:38:13 -0700 Subject: [PATCH] Fix pretty printer to handle using and erased modifier --- .../runtime/impl/printers/SourceCode.scala | 47 ++++++++++++------- tests/run-macros/term-show.check | 21 +++++++++ tests/run-macros/term-show/Macro_1.scala | 7 +++ tests/run-macros/term-show/Test_2.scala | 31 ++++++------ 4 files changed, 73 insertions(+), 33 deletions(-) create mode 100644 tests/run-macros/term-show.check diff --git a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala index d347e17975b2..cd36e31716a7 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala @@ -158,7 +158,7 @@ object SourceCode { for paramClause <- paramss do paramClause match case TermParamClause(params) => - printArgsDefs(params) + printMethdArgsDefs(params) case TypeParamClause(params) => printTargsDefs(stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(params)) } @@ -313,7 +313,7 @@ object SourceCode { this += highlightKeyword("def ") += highlightValDef(name1) for clause <- paramss do clause match - case TermParamClause(params) => printArgsDefs(params) + case TermParamClause(params) => printMethdArgsDefs(params) case TypeParamClause(params) => printTargsDefs(params.zip(params)) if (!isConstructor) { this += ": " @@ -460,7 +460,7 @@ object SourceCode { case tree @ Lambda(params, body) => // must come before `Block` inParens { - printArgsDefs(params) + printLambdaArgsDefs(params) this += (if tree.tpe.isContextFunctionType then " ?=> " else " => ") printTree(body) } @@ -804,29 +804,37 @@ object SourceCode { } } - private def printArgsDefs(args: List[ValDef])(using elideThis: Option[Symbol]): Unit = { + private def printSeparatedParamDefs(list: List[ValDef])(using elideThis: Option[Symbol]): Unit = list match { + case Nil => + case x :: Nil => printParamDef(x) + case x :: xs => + printParamDef(x) + this += ", " + printSeparatedParamDefs(xs) + } + + private def printMethdArgsDefs(args: List[ValDef])(using elideThis: Option[Symbol]): Unit = { val argFlags = args match { case Nil => Flags.EmptyFlags case arg :: _ => arg.symbol.flags } - if (argFlags.is(Flags.Erased | Flags.Given)) { - if (argFlags.is(Flags.Given)) this += " given" - if (argFlags.is(Flags.Erased)) this += " erased" - this += " " - } inParens { if (argFlags.is(Flags.Implicit) && !argFlags.is(Flags.Given)) this += "implicit " + if (argFlags.is(Flags.Given)) this += "using " - def printSeparated(list: List[ValDef]): Unit = list match { - case Nil => - case x :: Nil => printParamDef(x) - case x :: xs => - printParamDef(x) - this += ", " - printSeparated(xs) - } + printSeparatedParamDefs(args) + } + } + + private def printLambdaArgsDefs(args: List[ValDef])(using elideThis: Option[Symbol]): Unit = { + val argFlags = args match { + case Nil => Flags.EmptyFlags + case arg :: _ => arg.symbol.flags + } + inParens { + if (argFlags.is(Flags.Implicit) && !argFlags.is(Flags.Given)) this += "implicit " - printSeparated(args) + printSeparatedParamDefs(args) } } @@ -846,6 +854,9 @@ object SourceCode { private def printParamDef(arg: ValDef)(using elideThis: Option[Symbol]): Unit = { val name = splicedName(arg.symbol).getOrElse(arg.symbol.name) val sym = arg.symbol.owner + + if (arg.symbol.flags.is(Flags.Erased)) this += "erased " + if sym.isDefDef && sym.name == "" then val ClassDef(_, _, _, _, body) = sym.owner.tree: @unchecked body.collectFirst { diff --git a/tests/run-macros/term-show.check b/tests/run-macros/term-show.check new file mode 100644 index 000000000000..91ba0308e3db --- /dev/null +++ b/tests/run-macros/term-show.check @@ -0,0 +1,21 @@ +{ + class C() { + def a: scala.Int = 0 + private[this] def b: scala.Int = 0 + private[this] def c: scala.Int = 0 + private[C] def d: scala.Int = 0 + protected def e: scala.Int = 0 + protected[this] def f: scala.Int = 0 + protected[C] def g: scala.Int = 0 + } + () +} +@scala.annotation.internal.SourceFile("tests/run-macros/term-show/Test_2.scala") trait A() extends java.lang.Object { + def imp(x: scala.Int)(implicit str: scala.Predef.String): scala.Int + def use(`x₂`: scala.Int)(using `str₂`: scala.Predef.String): scala.Int + def era(`x₃`: scala.Int)(erased `str₃`: scala.Predef.String): scala.Int + def f1(x1: scala.Int, erased x2: scala.Int): scala.Int + def f2(erased `x1₂`: scala.Int, erased `x2₂`: scala.Int): scala.Int + def f3(using `x1₃`: scala.Int, erased `x2₃`: scala.Int): scala.Int + def f4(using erased `x1₄`: scala.Int, erased `x2₄`: scala.Int): scala.Int +} diff --git a/tests/run-macros/term-show/Macro_1.scala b/tests/run-macros/term-show/Macro_1.scala index 1517652b359a..8e26c715d3ed 100644 --- a/tests/run-macros/term-show/Macro_1.scala +++ b/tests/run-macros/term-show/Macro_1.scala @@ -5,4 +5,11 @@ object TypeToolbox { private def showImpl(using Quotes)(v: Expr[Any]): Expr[String] = import quotes.reflect.* Expr(v.show) + + inline def showTree(inline className: String): String = ${ showTreeImpl('className) } + private def showTreeImpl(className: Expr[String])(using Quotes) : Expr[String] = + import quotes.reflect.* + val name = className.valueOrAbort + val res = Symbol.requiredClass(name).tree.show + Expr(res) } diff --git a/tests/run-macros/term-show/Test_2.scala b/tests/run-macros/term-show/Test_2.scala index 76c51bdefd63..eebd50576930 100644 --- a/tests/run-macros/term-show/Test_2.scala +++ b/tests/run-macros/term-show/Test_2.scala @@ -1,7 +1,19 @@ +import scala.language.experimental.erasedDefinitions + +trait A: + def imp(x: Int)(implicit str: String): Int + def use(x: Int)(using str: String): Int + def era(x: Int)(erased str: String): Int + + def f1(x1: Int, erased x2: Int): Int + def f2(erased x1: Int, erased x2: Int): Int + def f3(using x1: Int, erased x2: Int): Int + def f4(using erased x1: Int, erased x2: Int): Int + object Test { import TypeToolbox.* def main(args: Array[String]): Unit = { - assert(show { + println(show { class C { def a = 0 private def b = 0 @@ -11,19 +23,8 @@ object Test { protected[this] def f = 0 protected[C] def g = 0 } - } - == - """{ - | class C() { - | def a: scala.Int = 0 - | private[this] def b: scala.Int = 0 - | private[this] def c: scala.Int = 0 - | private[C] def d: scala.Int = 0 - | protected def e: scala.Int = 0 - | protected[this] def f: scala.Int = 0 - | protected[C] def g: scala.Int = 0 - | } - | () - |}""".stripMargin) + }) + + println(showTree("A")) } }