Skip to content

Cherry pick and merge #465 into scala 2.12 branch #469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .bazelci/presubmit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
platforms:
ubuntu1404:
build_targets:
- "//test/..."
test_targets:
- "//test/..."
ubuntu1604:
build_targets:
- "//test/..."
test_targets:
- "//test/..."
macos:
build_targets:
- "//test/..."
test_targets:
- "//test/..."
22 changes: 12 additions & 10 deletions jmh/jmh.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,35 @@ load("//scala:scala.bzl", "scala_binary", "scala_library")
def jmh_repositories():
native.maven_jar(
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
artifact = "org.openjdk.jmh:jmh-core:1.17.4",
sha1 = "126d989b196070a8b3653b5389e602a48fe6bb2f",
artifact = "org.openjdk.jmh:jmh-core:1.20",
sha1 = "5f9f9839bda2332e9acd06ce31ad94afa7d6d447",
)
native.bind(
name = 'io_bazel_rules_scala/dependency/jmh/jmh_core',
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_core//jar',
)
native.maven_jar(
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
artifact = "org.openjdk.jmh:jmh-generator-asm:1.17.4",
sha1 = "c85c3d8cfa194872b260e89770d41e2084ce2cb6",
artifact = "org.openjdk.jmh:jmh-generator-asm:1.20",
sha1 = "3c43040e08ae68905657a375e669f11a7352f9db",
)
native.bind(
name = 'io_bazel_rules_scala/dependency/jmh/jmh_generator_asm',
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm//jar',
)
native.maven_jar(
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
artifact = "org.openjdk.jmh:jmh-generator-reflection:1.17.4",
sha1 = "f75a7823c9fcf03feed6d74aa44ea61fc70a8439",
artifact = "org.openjdk.jmh:jmh-generator-reflection:1.20",
sha1 = "f2154437b42426a48d5dac0b3df59002f86aed26",
)
native.bind(
name = 'io_bazel_rules_scala/dependency/jmh/jmh_generator_reflection',
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection//jar',
)
native.maven_jar(
name = "io_bazel_rules_scala_org_ows2_asm_asm",
artifact = "org.ow2.asm:asm:5.0.4",
sha1 = "0da08b8cce7bbf903602a25a3a163ae252435795",
artifact = "org.ow2.asm:asm:6.1.1",
sha1 = "264754515362d92acd39e8d40395f6b8dee7bc08",
)
native.bind(
name = "io_bazel_rules_scala/dependency/jmh/org_ows2_asm_asm",
Expand Down Expand Up @@ -78,14 +78,15 @@ def _scala_generate_benchmark(ctx):
outputs = [ctx.outputs.src_jar, ctx.outputs.resource_jar],
inputs = [class_jar] + classpath,
executable = ctx.executable._generator,
arguments = [f.path for f in [class_jar, ctx.outputs.src_jar, ctx.outputs.resource_jar] + classpath],
arguments = [ctx.attr.generator_type] + [f.path for f in [class_jar, ctx.outputs.src_jar, ctx.outputs.resource_jar] + classpath],
progress_message = "Generating benchmark code for %s" % ctx.label,
)

scala_generate_benchmark = rule(
implementation = _scala_generate_benchmark,
attrs = {
"src": attr.label(allow_single_file=True, mandatory=True),
"generator_type": attr.string(default='reflection', mandatory=False),
"_generator": attr.label(executable=True, cfg="host", default=Label("//src/scala/io/bazel/rules_scala/jmh_support:benchmark_generator"))
},
outputs = {
Expand All @@ -98,6 +99,7 @@ def scala_benchmark_jmh(**kw):
name = kw["name"]
deps = kw.get("deps", [])
srcs = kw["srcs"]
generator_type = kw.get("generator_type", "reflection")
lib = "%s_generator" % name
scalacopts = kw.get("scalacopts", [])
main_class = kw.get("main_class", "org.openjdk.jmh.Main")
Expand All @@ -115,7 +117,7 @@ def scala_benchmark_jmh(**kw):
)

codegen = name + "_codegen"
scala_generate_benchmark(name=codegen, src=lib)
scala_generate_benchmark(name=codegen, src=lib, generator_type=generator_type)
compiled_lib = name + "_compiled_benchmark_lib"
scala_library(
name = compiled_lib,
Expand Down
96 changes: 74 additions & 22 deletions src/scala/io/bazel/rules_scala/jmh_support/BenchmarkGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import java.net.URLClassLoader

import scala.annotation.tailrec
import scala.collection.JavaConverters._

import org.openjdk.jmh.generators.core.{ BenchmarkGenerator => JMHGenerator, FileSystemDestination }
import org.openjdk.jmh.generators.core.{FileSystemDestination, GeneratorSource, BenchmarkGenerator => JMHGenerator}
import org.openjdk.jmh.generators.asm.ASMGeneratorSource
import org.openjdk.jmh.runner.{ Runner, RunnerException }
import org.openjdk.jmh.runner.options.{ Options, OptionsBuilder }

import org.openjdk.jmh.generators.reflection.RFGeneratorSource
import org.openjdk.jmh.runner.{Runner, RunnerException}
import org.openjdk.jmh.runner.options.{Options, OptionsBuilder}
import java.net.URI

import scala.collection.JavaConverters._
import java.nio.file.{Files, FileSystems, Path}
import java.nio.file.{FileSystems, Files, Path, Paths}

import io.bazel.rulesscala.jar.JarCreator

Expand All @@ -27,7 +27,14 @@ import io.bazel.rulesscala.jar.JarCreator
*/
object BenchmarkGenerator {

case class BenchmarkGeneratorArgs(
private sealed trait GeneratorType

private case object ReflectionGenerator extends GeneratorType

private case object AsmGenerator extends GeneratorType

private case class BenchmarkGeneratorArgs(
generatorType: GeneratorType,
inputJar: Path,
resultSourceJar: Path,
resultResourceJar: Path,
Expand All @@ -37,6 +44,7 @@ object BenchmarkGenerator {
def main(argv: Array[String]): Unit = {
val args = parseArgs(argv)
generateJmhBenchmark(
args.generatorType,
args.resultSourceJar,
args.resultResourceJar,
args.inputJar,
Expand All @@ -47,17 +55,18 @@ object BenchmarkGenerator {
private def parseArgs(argv: Array[String]): BenchmarkGeneratorArgs = {
if (argv.length < 3) {
System.err.println(
"Usage: BenchmarkGenerator INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
"Usage: BenchmarkGenerator GENERATOR_TYPE INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
)
System.exit(1)
}
val fs = FileSystems.getDefault

BenchmarkGeneratorArgs(
fs.getPath(argv(0)),
parseGeneratorType(argv(0)),
fs.getPath(argv(1)),
fs.getPath(argv(2)),
argv.slice(3, argv.length).map { s => fs.getPath(s) }.toList
fs.getPath(argv(3)),
argv.slice(4, argv.length).map { s => fs.getPath(s) }.toList
)
}

Expand Down Expand Up @@ -88,13 +97,13 @@ object BenchmarkGenerator {
}

// Courtesy of Doug Tangren (https://groups.google.com/forum/#!topic/simple-build-tool/CYeLHcJjHyA)
private def withClassLoader[A](cp: Seq[Path])(f: => A): A = {
private def withClassLoader[A](cp: Seq[Path])(f: ClassLoader => A): A = {
val originalLoader = Thread.currentThread.getContextClassLoader
val jmhLoader = classOf[JMHGenerator].getClassLoader
val classLoader = new URLClassLoader(cp.map(_.toUri.toURL).toArray, jmhLoader)
try {
Thread.currentThread.setContextClassLoader(classLoader)
f
f(classLoader)
} finally {
Thread.currentThread.setContextClassLoader(originalLoader)
}
Expand All @@ -119,6 +128,7 @@ object BenchmarkGenerator {
}

private def generateJmhBenchmark(
generatorType: GeneratorType,
sourceJarOut: Path,
resourceJarOut: Path,
benchmarkJarPath: Path,
Expand All @@ -131,17 +141,26 @@ object BenchmarkGenerator {
tmpResourceDir.toFile.mkdir()
tmpSourceDir.toFile.mkdir()

withClassLoader(benchmarkJarPath :: classpath) {
val source = new ASMGeneratorSource
val destination = new FileSystemDestination(tmpResourceDir.toFile, tmpSourceDir.toFile)
val generator = new JMHGenerator

collectClassesFromJar(benchmarkJarPath).foreach { path =>
// this would fail due to https://github.com/bazelbuild/rules_scala/issues/295
// let's throw a useful message instead
sys.error("jmh in rules_scala doesn't work with Java 8 bytecode: https://github.com/bazelbuild/rules_scala/issues/295")
source.processClass(Files.newInputStream(path))
withClassLoader(benchmarkJarPath :: classpath) { isolatedClassLoader =>

val source: GeneratorSource = generatorType match {
case AsmGenerator =>
val generatorSource = new ASMGeneratorSource
generatorSource.processClasses(collectClassesFromJar(benchmarkJarPath).map(_.toFile).asJavaCollection)
generatorSource

case ReflectionGenerator =>
val generatorSource = new RFGeneratorSource
generatorSource.processClasses(
collectClassesFromJar(benchmarkJarPath)
.flatMap(classByPath(_, isolatedClassLoader))
.asJavaCollection
)
generatorSource
}

val generator = new JMHGenerator
val destination = new FileSystemDestination(tmpResourceDir.toFile, tmpSourceDir.toFile)
generator.generate(source, destination)
generator.complete(source, destination)
if (destination.hasErrors) {
Expand All @@ -156,6 +175,39 @@ object BenchmarkGenerator {
}
}

private def classByPath(path: Path, cl: ClassLoader): Option[Class[_]] = {
val separator = path.getFileSystem.getSeparator
var s = path.toString
.stripPrefix(separator)
.stripSuffix(".class")
.replace(separator, ".")

var index = -1
do {
s = s.substring(index + 1)
try {
return Some(Class.forName(s, false, cl))
} catch {
case _: ClassNotFoundException =>
// ignore and try next one
index = s.indexOf('.')
}
} while (index != -1)

log(s"Failed to find class for path $path")
None
}

private def parseGeneratorType(s: String): GeneratorType = {
if ("asm".equalsIgnoreCase(s)) {
AsmGenerator
} else if ("reflection".equalsIgnoreCase(s)) {
ReflectionGenerator
} else {
throw new IllegalArgumentException(s"unknown generator_type: $s")
}
}

private def log(str: String): Unit = {
System.err.println(s"JMH benchmark generation: $str")
}
Expand Down
11 changes: 5 additions & 6 deletions test/jmh/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ scala_library(
visibility = ["//visibility:public"],
)

# Disable the jmh test due to https://github.com/bazelbuild/rules_scala/issues/295
# scala_benchmark_jmh(
# name = "test_benchmark",
# srcs = ["TestBenchmark.scala"],
# deps = [":add_numbers"],
# )
scala_benchmark_jmh(
name = "test_benchmark",
srcs = ["TestBenchmark.scala"],
deps = [":add_numbers"],
)