Skip to content

Commit d663e36

Browse files
committed
Update implementation of newMain annotation
1 parent 46e98dd commit d663e36

30 files changed

+502
-464
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

+14-8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ import NameKinds.DefaultGetterName
1111
import Annotations.Annotation
1212

1313
object MainProxies {
14+
15+
/** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16+
def proxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
17+
mainAnnotationProxies(stats) ++ mainProxies(stats)
18+
}
19+
1420
/** Generate proxy classes for @main functions.
1521
* A function like
1622
*
@@ -29,7 +35,7 @@ object MainProxies {
2935
* catch case err: ParseError => showError(err)
3036
* }
3137
*/
32-
def mainProxiesOld(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
38+
private def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
3339
import tpd._
3440
def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap {
3541
case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) =>
@@ -39,11 +45,11 @@ object MainProxies {
3945
case _ =>
4046
Nil
4147
}
42-
mainMethods(stats).flatMap(mainProxyOld)
48+
mainMethods(stats).flatMap(mainProxy)
4349
}
4450

4551
import untpd._
46-
def mainProxyOld(mainFun: Symbol)(using Context): List[TypeDef] = {
52+
private def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = {
4753
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span
4854
def pos = mainFun.sourcePos
4955
val argsRef = Ident(nme.args)
@@ -165,7 +171,7 @@ object MainProxies {
165171
* }
166172
* }
167173
*/
168-
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
174+
private def mainAnnotationProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
169175
import tpd._
170176

171177
/**
@@ -188,12 +194,12 @@ object MainProxies {
188194
def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap {
189195
case stat: DefDef =>
190196
val sym = stat.symbol
191-
sym.annotations.filter(_.matches(defn.MainAnnot)) match {
197+
sym.annotations.filter(_.matches(defn.MainAnnotationClass)) match {
192198
case Nil =>
193199
Nil
194200
case _ :: Nil =>
195201
val paramAnnotations = stat.paramss.flatMap(_.map(
196-
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotParameterAnnotation))
202+
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation))
197203
))
198204
(sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
199205
case mainAnnot :: others =>
@@ -207,7 +213,7 @@ object MainProxies {
207213
}
208214

209215
// Assuming that the top-level object was already generated, all main methods will have a scope
210-
mainMethods(EmptyTree, stats).flatMap(mainProxy)
216+
mainMethods(EmptyTree, stats).flatMap(mainAnnotationProxy)
211217
}
212218

213219
private def mainAnnotationProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): Option[TypeDef] = {
@@ -359,7 +365,7 @@ object MainProxies {
359365
case tree => super.transform(tree)
360366
}
361367
val annots = mainFun.annotations
362-
.filterNot(_.matches(defn.MainAnnot))
368+
.filterNot(_.matches(defn.MainAnnotationClass))
363369
.map(annot => insertTypeSplices.transform(annot.tree))
364370
val mainMeth = DefDef(nme.main, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body)
365371
.withFlags(JavaStatic)

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,7 @@ class Definitions {
920920
@tu lazy val InlineParamAnnot: ClassSymbol = requiredClass("scala.annotation.internal.InlineParam")
921921
@tu lazy val ErasedParamAnnot: ClassSymbol = requiredClass("scala.annotation.internal.ErasedParam")
922922
@tu lazy val InvariantBetweenAnnot: ClassSymbol = requiredClass("scala.annotation.internal.InvariantBetween")
923+
@tu lazy val MainAnnot: ClassSymbol = requiredClass("scala.main")
923924
@tu lazy val MigrationAnnot: ClassSymbol = requiredClass("scala.annotation.migration")
924925
@tu lazy val NowarnAnnot: ClassSymbol = requiredClass("scala.annotation.nowarn")
925926
@tu lazy val TransparentTraitAnnot: ClassSymbol = requiredClass("scala.annotation.transparentTrait")

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -2627,7 +2627,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
26272627
pkg.moduleClass.info.decls.lookup(topLevelClassName).ensureCompleted()
26282628
var stats1 = typedStats(tree.stats, pkg.moduleClass)._1
26292629
if (!ctx.isAfterTyper)
2630-
stats1 = stats1 ++ typedBlockStats(MainProxies.mainProxies(stats1))._1
2630+
stats1 = stats1 ++ typedBlockStats(MainProxies.proxies(stats1))._1
26312631
cpy.PackageDef(tree)(pid1, stats1).withType(pkg.termRef)
26322632
}
26332633
case _ =>

0 commit comments

Comments
 (0)