Skip to content

Commit d253375

Browse files
authored
Merge pull request #287 from smarter/fix/main-class-detection
Fix #102: Better main class detection
2 parents 970e18e + f10c53c commit d253375

File tree

17 files changed

+140
-18
lines changed

17 files changed

+140
-18
lines changed

internal/compiler-bridge/src/main/scala/xsbt/API.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ final class API(val global: CallbackGlobal) extends Compat with GlobalHelpers {
4646
extractUsedNames.extractAndReport(unit)
4747

4848
val classApis = traverser.allNonLocalClasses
49+
val mainClasses = traverser.mainClasses
4950

5051
classApis.foreach(callback.api(sourceFile, _))
52+
mainClasses.foreach(callback.mainClass(sourceFile, _))
5153
}
5254
}
5355

@@ -56,6 +58,9 @@ final class API(val global: CallbackGlobal) extends Compat with GlobalHelpers {
5658
def allNonLocalClasses: Set[ClassLike] = {
5759
extractApi.allExtractedNonLocalClasses
5860
}
61+
62+
def mainClasses: Set[String] = extractApi.mainClasses
63+
5964
def `class`(c: Symbol): Unit = {
6065
extractApi.extractAllClassesOf(c.owner, c)
6166
}

internal/compiler-bridge/src/main/scala/xsbt/ExtractAPI.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ package xsbt
1010
import java.io.File
1111
import java.util.{ Arrays, Comparator }
1212
import scala.tools.nsc.symtab.Flags
13-
import scala.collection.mutable.{ HashMap, HashSet }
13+
import scala.collection.mutable.{ HashMap, HashSet, ListBuffer }
1414
import xsbti.api._
1515

1616
import scala.tools.nsc.Global
@@ -71,6 +71,7 @@ class ExtractAPI[GlobalType <: Global](
7171
private[this] val emptyStringArray = new Array[String](0)
7272

7373
private[this] val allNonLocalClassesInSrc = new HashSet[xsbti.api.ClassLike]
74+
private[this] val _mainClasses = new HashSet[String]
7475

7576
/**
7677
* Implements a work-around for https://github.com/sbt/sbt/issues/823
@@ -600,6 +601,11 @@ class ExtractAPI[GlobalType <: Global](
600601
allNonLocalClassesInSrc.toSet
601602
}
602603

604+
def mainClasses: Set[String] = {
605+
forceStructures()
606+
_mainClasses.toSet
607+
}
608+
603609
private def classLike(in: Symbol, c: Symbol): ClassLikeDef =
604610
classLikeCache.getOrElseUpdate((in, c), mkClassLike(in, c))
605611
private def mkClassLike(in: Symbol, c: Symbol): ClassLikeDef = {
@@ -641,6 +647,10 @@ class ExtractAPI[GlobalType <: Global](
641647

642648
allNonLocalClassesInSrc += classWithMembers
643649

650+
if (sym.isStatic && defType == DefinitionType.Module && definitions.hasJavaMainMethod(sym)) {
651+
_mainClasses += name
652+
}
653+
644654
val classDef = new xsbti.api.ClassLikeDef(
645655
name,
646656
acc,

internal/compiler-interface/src/main/java/xsbti/AnalysisCallback.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ void generatedNonLocalClass(File source,
110110
*/
111111
void api(File sourceFile, xsbti.api.ClassLike classApi);
112112

113+
/**
114+
* Register a class containing an entry point coming from a given source file.
115+
*
116+
* A class is an entry point if its bytecode contains a method with the
117+
* following signature:
118+
* <pre>
119+
* public static void main(String[] args);
120+
* </pre>
121+
*
122+
* @param sourceFile Source file where <code>className</code> is defined.
123+
* @param className A class containing an entry point.
124+
*/
125+
void mainClass(File sourceFile, String className);
126+
113127
/**
114128
* Register the use of a <code>name</code> from a given source class name.
115129
*
@@ -158,4 +172,4 @@ void problem(String what,
158172
* phase defined by <code>xsbt-analyzer</code> should be added.
159173
*/
160174
boolean enabled();
161-
}
175+
}

internal/compiler-interface/src/main/java/xsbti/compile/analysis/SourceInfo.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,12 @@ public interface SourceInfo {
3030
* @return The compiler reported problems.
3131
*/
3232
public Problem[] getUnreportedProblems();
33+
34+
/**
35+
* Returns the main classes found in this compilation unit.
36+
*
37+
* @return The full name of the main classes, like "foo.bar.Main"
38+
*/
39+
public String[] getMainClasses();
3340
}
3441

internal/compiler-interface/src/test/scala/xsbti/TestCallback.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class TestCallback extends AnalysisCallback {
6464
()
6565
}
6666

67+
def mainClass(source: File, className: String): Unit = ()
68+
6769
override def enabled(): Boolean = true
6870

6971
def problem(category: String,

internal/zinc-apiinfo/src/main/scala/sbt/internal/inc/ClassToAPI.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ object ClassToAPI {
2222
def apply(c: Seq[Class[_]]): Seq[api.ClassLike] = process(c)._1
2323

2424
// (api, public inherited classes)
25-
def process(classes: Seq[Class[_]]): (Seq[api.ClassLike], Set[(Class[_], Class[_])]) = {
25+
def process(
26+
classes: Seq[Class[_]]): (Seq[api.ClassLike], Seq[String], Set[(Class[_], Class[_])]) = {
2627
val cmap = emptyClassMap
2728
classes.foreach(toDefinitions(cmap)) // force recording of class definitions
2829
cmap.lz.foreach(_.get()) // force thunks to ensure all inherited dependencies are recorded
2930
val classApis = cmap.allNonLocalClasses.toSeq
31+
val mainClasses = cmap.mainClasses.toSeq
3032
val inDeps = cmap.inherited.toSet
3133
cmap.clear()
32-
(classApis, inDeps)
34+
(classApis, mainClasses, inDeps)
3335
}
3436

3537
// Avoiding implicit allocation.
@@ -55,7 +57,8 @@ object ClassToAPI {
5557
private[sbt] val memo: mutable.Map[String, Seq[api.ClassLikeDef]],
5658
private[sbt] val inherited: mutable.Set[(Class[_], Class[_])],
5759
private[sbt] val lz: mutable.Buffer[xsbti.api.Lazy[_]],
58-
private[sbt] val allNonLocalClasses: mutable.Set[api.ClassLike]
60+
private[sbt] val allNonLocalClasses: mutable.Set[api.ClassLike],
61+
private[sbt] val mainClasses: mutable.Set[String]
5962
) {
6063
def clear(): Unit = {
6164
memo.clear()
@@ -67,6 +70,7 @@ object ClassToAPI {
6770
new ClassMap(new mutable.HashMap,
6871
new mutable.HashSet,
6972
new mutable.ListBuffer,
73+
new mutable.HashSet,
7074
new mutable.HashSet)
7175

7276
def classCanonicalName(c: Class[_]): String =
@@ -115,6 +119,17 @@ object ClassToAPI {
115119
val defsEmptyMembers = clsDef :: statDef :: Nil
116120
cmap.memo(name) = defsEmptyMembers
117121
cmap.allNonLocalClasses ++= defs
122+
123+
if (c.getMethods.exists(
124+
meth =>
125+
meth.getName == "main" &&
126+
Modifier.isStatic(meth.getModifiers) &&
127+
meth.getParameterTypes.length == 1 &&
128+
meth.getParameterTypes.head == classOf[Array[String]] &&
129+
meth.getReturnType == java.lang.Void.TYPE)) {
130+
cmap.mainClasses += name
131+
}
132+
118133
defsEmptyMembers
119134
}
120135

internal/zinc-apiinfo/src/test/scala/sbt/internal/inc/ClassToAPISpecification.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ class ClassToAPISpecification extends UnitSpec {
8080
def readAPI(callback: AnalysisCallback,
8181
source: File,
8282
classes: Seq[Class[_]]): Set[(String, String)] = {
83-
val (apis, inherits) = ClassToAPI.process(classes)
83+
val (apis, mainClasses, inherits) = ClassToAPI.process(classes)
8484
apis.foreach(callback.api(source, _))
85+
mainClasses.foreach(callback.mainClass(source, _))
8586
inherits.map {
8687
case (from, to) => (from.getName, to.getName)
8788
}

internal/zinc-core/src/main/scala/sbt/internal/inc/Compile.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ private final class AnalysisCallback(
149149
private[this] val usedNames = new HashMap[String, Set[UsedName]]
150150
private[this] val unreporteds = new HashMap[File, ListBuffer[Problem]]
151151
private[this] val reporteds = new HashMap[File, ListBuffer[Problem]]
152+
private[this] val mainClasses = new HashMap[File, ListBuffer[String]]
152153
private[this] val binaryDeps = new HashMap[File, Set[File]]
153154
// source file to set of generated (class file, binary class name); only non local classes are stored here
154155
private[this] val nonLocalClasses = new HashMap[File, Set[(File, String)]]
@@ -285,6 +286,11 @@ private final class AnalysisCallback(
285286
}
286287
}
287288

289+
def mainClass(sourceFile: File, className: String): Unit = {
290+
mainClasses.getOrElseUpdate(sourceFile, ListBuffer.empty) += className
291+
()
292+
}
293+
288294
def usedName(className: String, name: String, useScopes: util.EnumSet[UseScope]) =
289295
add(usedNames, className, UsedName(name, useScopes))
290296

@@ -346,7 +352,9 @@ private final class AnalysisCallback(
346352
val stamp = stampReader.source(src)
347353
val classesInSrc = classNames.getOrElse(src, Set.empty).map(_._1)
348354
val analyzedApis = classesInSrc.map(analyzeClass)
349-
val info = SourceInfos.makeInfo(getOrNil(reporteds, src), getOrNil(unreporteds, src))
355+
val info = SourceInfos.makeInfo(getOrNil(reporteds, src),
356+
getOrNil(unreporteds, src),
357+
getOrNil(mainClasses, src))
350358
val binaries = binaryDeps.getOrElse(src, Nil: Iterable[File])
351359
val localProds = localClasses.getOrElse(src, Nil: Iterable[File]) map { classFile =>
352360
val classFileStamp = stampReader.product(classFile)

internal/zinc-core/src/main/scala/sbt/internal/inc/SourceInfo.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ object SourceInfos {
2626
def empty: SourceInfos = make(Map.empty)
2727
def make(m: Map[File, SourceInfo]): SourceInfos = new MSourceInfos(m)
2828

29-
val emptyInfo: SourceInfo = makeInfo(Nil, Nil)
30-
def makeInfo(reported: Seq[Problem], unreported: Seq[Problem]): SourceInfo =
31-
new UnderlyingSourceInfo(reported, unreported)
29+
val emptyInfo: SourceInfo = makeInfo(Nil, Nil, Nil)
30+
def makeInfo(reported: Seq[Problem],
31+
unreported: Seq[Problem],
32+
mainClasses: Seq[String]): SourceInfo =
33+
new UnderlyingSourceInfo(reported, unreported, mainClasses)
3234
def merge(infos: Traversable[SourceInfos]): SourceInfos = (SourceInfos.empty /: infos)(_ ++ _)
3335
}
3436

@@ -48,8 +50,10 @@ private final class MSourceInfos(val allInfos: Map[File, SourceInfo]) extends So
4850
}
4951

5052
private final class UnderlyingSourceInfo(val reportedProblems: Seq[Problem],
51-
val unreportedProblems: Seq[Problem])
53+
val unreportedProblems: Seq[Problem],
54+
val mainClasses: Seq[String])
5255
extends SourceInfo {
5356
override def getReportedProblems: Array[Problem] = reportedProblems.toArray
5457
override def getUnreportedProblems: Array[Problem] = unreportedProblems.toArray
58+
override def getMainClasses: Array[String] = mainClasses.toArray
5559
}

internal/zinc-persist/src/main/scala/sbt/internal/inc/TextAnalysisFormat.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ class TextAnalysisFormat(override val mappers: AnalysisMappers)
6060
private implicit val analyzedClassFormat: Format[AnalyzedClass] =
6161
AnalyzedClassFormats.analyzedClassFormat
6262
private implicit def infoFormat: Format[SourceInfo] =
63-
wrap[SourceInfo, (Seq[Problem], Seq[Problem])](
64-
si => (si.getReportedProblems, si.getUnreportedProblems), {
65-
case (a, b) => SourceInfos.makeInfo(a, b)
63+
wrap[SourceInfo, (Seq[Problem], Seq[Problem], Seq[String])](
64+
si => (si.getReportedProblems, si.getUnreportedProblems, si.getMainClasses), {
65+
case (a, b, c) => SourceInfos.makeInfo(a, b, c)
6666
})
6767
private implicit def fileHashFormat: Format[FileHash] =
6868
asProduct2((file: File, hash: Int) => new FileHash(file, hash))(h => (h.file, h.hash))

0 commit comments

Comments
 (0)