diff --git a/scala/scala.bzl b/scala/scala.bzl index c94e77d65..d7dc1ebf7 100644 --- a/scala/scala.bzl +++ b/scala/scala.bzl @@ -425,19 +425,20 @@ def _path_is_absolute(path): return False -def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble=""): - runfiles_root = "${TEST_SRCDIR}/%s" % ctx.workspace_name - # RUNPATH is defined here: - # https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227 - classpath = ":".join(["${RUNPATH}%s" % (j.short_path) for j in rjars]) - jvm_flags = " ".join([ctx.expand_location(f, ctx.attr.data) for f in jvm_flags]) +def _runfiles_root(ctx): + return "${TEST_SRCDIR}/%s" % ctx.workspace_name + +def _write_java_wrapper(ctx, args="", wrapper_preamble=""): + """This creates a wrapper that sets up the correct path + to stand in for the java command.""" + java_path = str(ctx.attr._java_runtime[java_common.JavaRuntimeInfo].java_executable_runfiles_path) if _path_is_absolute(java_path): javabin = java_path else: + runfiles_root = _runfiles_root(ctx) javabin = "%s/%s" % (runfiles_root, java_path) - template = ctx.attr._java_stub_template.files.to_list()[0] exec_str = "" if wrapper_preamble == "": @@ -457,14 +458,21 @@ def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble args=args, ), ) + return wrapper +def _write_executable(ctx, rjars, main_class, jvm_flags, wrapper): + template = ctx.attr._java_stub_template.files.to_list()[0] + # RUNPATH is defined here: + # https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227 + classpath = ":".join(["${RUNPATH}%s" % (j.short_path) for j in rjars]) + jvm_flags = " ".join([ctx.expand_location(f, ctx.attr.data) for f in jvm_flags]) ctx.template_action( template = template, output = ctx.outputs.executable, substitutions = { "%classpath%": classpath, "%java_start_class%": main_class, - "%javabin%": "JAVABIN=%s/%s" % (runfiles_root, wrapper.short_path), + "%javabin%": "JAVABIN=%s/%s" % (_runfiles_root(ctx), wrapper.short_path), "%jvm_flags%": jvm_flags, "%needs_runfiles%": "", "%runfiles_manifest_only%": "", @@ -739,17 +747,16 @@ def _scala_macro_library_impl(ctx): return _lib(ctx, False) # don't build the ijar for macros # Common code shared by all scala binary implementations. -def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2labels, implicit_junit_deps_needed_for_java_compilation = []): +def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2labels, java_wrapper, implicit_junit_deps_needed_for_java_compilation = []): write_manifest(ctx) outputs = _compile_or_empty(ctx, cjars, [], False, transitive_compile_time_jars, jars2labels, implicit_junit_deps_needed_for_java_compilation) # no need to build an ijar for an executable rjars += outputs.full_jars - _build_deployable(ctx, list(rjars)) - - java_wrapper = ctx.new_file(ctx.label.name + "_wrapper.sh") + rjars_list = list(rjars) + _build_deployable(ctx, rjars_list) runfiles = ctx.runfiles( - files = list(rjars) + [ctx.outputs.executable, java_wrapper] + ctx.files._java_runtime, + files = rjars_list + [ctx.outputs.executable, java_wrapper] + ctx.files._java_runtime, collect_data = True) rule_outputs = struct( @@ -781,12 +788,14 @@ def _scala_binary_impl(ctx): jars = _collect_jars_from_common_ctx(ctx) (cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars) - out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels) - _write_launcher( + wrapper = _write_java_wrapper(ctx, "", "") + out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper) + _write_executable( ctx = ctx, rjars = out.transitive_rjars, main_class = ctx.attr.main_class, jvm_flags = ctx.attr.jvm_flags, + wrapper = wrapper ) return out @@ -795,15 +804,8 @@ def _scala_repl_impl(ctx): jars = _collect_jars_from_common_ctx(ctx, extra_runtime_deps = [ctx.attr._scalacompiler]) (cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars) - out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels) args = " ".join(ctx.attr.scalacopts) - _write_launcher( - ctx = ctx, - rjars = out.transitive_rjars, - main_class = "scala.tools.nsc.MainGenericRunner", - jvm_flags = ["-Dscala.usejavacp=true"] + ctx.attr.jvm_flags, - args = args, - wrapper_preamble = """ + wrapper = _write_java_wrapper(ctx, args, wrapper_preamble = """ # save stty like in bin/scala saved_stty=$(stty -g 2>/dev/null) if [[ ! $? ]]; then @@ -816,7 +818,15 @@ function finish() { fi } trap finish EXIT -""", +""") + + out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper) + _write_executable( + ctx = ctx, + rjars = out.transitive_rjars, + main_class = "scala.tools.nsc.MainGenericRunner", + jvm_flags = ["-Dscala.usejavacp=true"] + ctx.attr.jvm_flags, + wrapper = wrapper ) return out @@ -857,14 +867,15 @@ def _scala_test_impl(ctx): _scala_test_flags(ctx), "-C io.bazel.rules.scala.JUnitXmlReporter ", ]) - out = _scala_binary_common(ctx, cjars, transitive_rjars, transitive_compile_jars, jars_to_labels) # main_class almost has to be "org.scalatest.tools.Runner" due to args.... - _write_launcher( + wrapper = _write_java_wrapper(ctx, args, "") + out = _scala_binary_common(ctx, cjars, transitive_rjars, transitive_compile_jars, jars_to_labels, wrapper) + _write_executable( ctx = ctx, rjars = out.transitive_rjars, main_class = ctx.attr.main_class, jvm_flags = ctx.attr.jvm_flags, - args = args, + wrapper = wrapper ) return out @@ -890,14 +901,16 @@ def _scala_junit_test_impl(ctx): (cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars) implicit_junit_deps_needed_for_java_compilation = [ctx.attr._junit, ctx.attr._hamcrest] - out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, implicit_junit_deps_needed_for_java_compilation) + wrapper = _write_java_wrapper(ctx, "", "") + out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper, implicit_junit_deps_needed_for_java_compilation) test_suite = _gen_test_suite_flags_based_on_prefixes_and_suffixes(ctx, out.scala.outputs.jars) launcherJvmFlags = ["-ea", test_suite.archiveFlag, test_suite.prefixesFlag, test_suite.suffixesFlag, test_suite.printFlag, test_suite.testSuiteFlag] - _write_launcher( + _write_executable( ctx = ctx, rjars = out.transitive_rjars, main_class = "com.google.testing.junit.runner.BazelTestRunner", jvm_flags = launcherJvmFlags + ctx.attr.jvm_flags, + wrapper = wrapper ) return out