Skip to content

Refactor write_launcher #449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 31, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 42 additions & 29 deletions scala/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,20 @@ def _path_is_absolute(path):

return False

def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble=""):
runfiles_root = "${TEST_SRCDIR}/%s" % ctx.workspace_name
# RUNPATH is defined here:
# https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227
classpath = ":".join(["${RUNPATH}%s" % (j.short_path) for j in rjars])
jvm_flags = " ".join([ctx.expand_location(f, ctx.attr.data) for f in jvm_flags])
def _runfiles_root(ctx):
return "${TEST_SRCDIR}/%s" % ctx.workspace_name

def _write_java_wrapper(ctx, args="", wrapper_preamble=""):
"""This creates a wrapper that sets up the correct path
to stand in for the java command."""

java_path = str(ctx.attr._java_runtime[java_common.JavaRuntimeInfo].java_executable_runfiles_path)
if _path_is_absolute(java_path):
javabin = java_path
else:
runfiles_root = _runfiles_root(ctx)
javabin = "%s/%s" % (runfiles_root, java_path)

template = ctx.attr._java_stub_template.files.to_list()[0]

exec_str = ""
if wrapper_preamble == "":
Expand All @@ -457,14 +458,21 @@ def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble
args=args,
),
)
return wrapper

def _write_executable(ctx, rjars, main_class, jvm_flags, wrapper):
template = ctx.attr._java_stub_template.files.to_list()[0]
# RUNPATH is defined here:
# https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227
classpath = ":".join(["${RUNPATH}%s" % (j.short_path) for j in rjars])
jvm_flags = " ".join([ctx.expand_location(f, ctx.attr.data) for f in jvm_flags])
ctx.template_action(
template = template,
output = ctx.outputs.executable,
substitutions = {
"%classpath%": classpath,
"%java_start_class%": main_class,
"%javabin%": "JAVABIN=%s/%s" % (runfiles_root, wrapper.short_path),
"%javabin%": "JAVABIN=%s/%s" % (_runfiles_root(ctx), wrapper.short_path),
"%jvm_flags%": jvm_flags,
"%needs_runfiles%": "",
"%runfiles_manifest_only%": "",
Expand Down Expand Up @@ -739,17 +747,16 @@ def _scala_macro_library_impl(ctx):
return _lib(ctx, False) # don't build the ijar for macros

# Common code shared by all scala binary implementations.
def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2labels, implicit_junit_deps_needed_for_java_compilation = []):
def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2labels, java_wrapper, implicit_junit_deps_needed_for_java_compilation = []):
write_manifest(ctx)
outputs = _compile_or_empty(ctx, cjars, [], False, transitive_compile_time_jars, jars2labels, implicit_junit_deps_needed_for_java_compilation) # no need to build an ijar for an executable
rjars += outputs.full_jars

_build_deployable(ctx, list(rjars))

java_wrapper = ctx.new_file(ctx.label.name + "_wrapper.sh")
rjars_list = list(rjars)
_build_deployable(ctx, rjars_list)

runfiles = ctx.runfiles(
files = list(rjars) + [ctx.outputs.executable, java_wrapper] + ctx.files._java_runtime,
files = rjars_list + [ctx.outputs.executable, java_wrapper] + ctx.files._java_runtime,
collect_data = True)

rule_outputs = struct(
Expand Down Expand Up @@ -781,12 +788,14 @@ def _scala_binary_impl(ctx):
jars = _collect_jars_from_common_ctx(ctx)
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)

out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels)
_write_launcher(
wrapper = _write_java_wrapper(ctx, "", "")
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper)
_write_executable(
ctx = ctx,
rjars = out.transitive_rjars,
main_class = ctx.attr.main_class,
jvm_flags = ctx.attr.jvm_flags,
wrapper = wrapper
)
return out

Expand All @@ -795,15 +804,8 @@ def _scala_repl_impl(ctx):
jars = _collect_jars_from_common_ctx(ctx, extra_runtime_deps = [ctx.attr._scalacompiler])
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)

out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels)
args = " ".join(ctx.attr.scalacopts)
_write_launcher(
ctx = ctx,
rjars = out.transitive_rjars,
main_class = "scala.tools.nsc.MainGenericRunner",
jvm_flags = ["-Dscala.usejavacp=true"] + ctx.attr.jvm_flags,
args = args,
wrapper_preamble = """
wrapper = _write_java_wrapper(ctx, args, wrapper_preamble = """
# save stty like in bin/scala
saved_stty=$(stty -g 2>/dev/null)
if [[ ! $? ]]; then
Expand All @@ -816,7 +818,15 @@ function finish() {
fi
}
trap finish EXIT
""",
""")

out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper)
_write_executable(
ctx = ctx,
rjars = out.transitive_rjars,
main_class = "scala.tools.nsc.MainGenericRunner",
jvm_flags = ["-Dscala.usejavacp=true"] + ctx.attr.jvm_flags,
wrapper = wrapper
)

return out
Expand Down Expand Up @@ -857,14 +867,15 @@ def _scala_test_impl(ctx):
_scala_test_flags(ctx),
"-C io.bazel.rules.scala.JUnitXmlReporter ",
])
out = _scala_binary_common(ctx, cjars, transitive_rjars, transitive_compile_jars, jars_to_labels)
# main_class almost has to be "org.scalatest.tools.Runner" due to args....
_write_launcher(
wrapper = _write_java_wrapper(ctx, args, "")
out = _scala_binary_common(ctx, cjars, transitive_rjars, transitive_compile_jars, jars_to_labels, wrapper)
_write_executable(
ctx = ctx,
rjars = out.transitive_rjars,
main_class = ctx.attr.main_class,
jvm_flags = ctx.attr.jvm_flags,
args = args,
wrapper = wrapper
)
return out

Expand All @@ -890,14 +901,16 @@ def _scala_junit_test_impl(ctx):
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)
implicit_junit_deps_needed_for_java_compilation = [ctx.attr._junit, ctx.attr._hamcrest]

out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, implicit_junit_deps_needed_for_java_compilation)
wrapper = _write_java_wrapper(ctx, "", "")
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper, implicit_junit_deps_needed_for_java_compilation)
test_suite = _gen_test_suite_flags_based_on_prefixes_and_suffixes(ctx, out.scala.outputs.jars)
launcherJvmFlags = ["-ea", test_suite.archiveFlag, test_suite.prefixesFlag, test_suite.suffixesFlag, test_suite.printFlag, test_suite.testSuiteFlag]
_write_launcher(
_write_executable(
ctx = ctx,
rjars = out.transitive_rjars,
main_class = "com.google.testing.junit.runner.BazelTestRunner",
jvm_flags = launcherJvmFlags + ctx.attr.jvm_flags,
wrapper = wrapper
)

return out
Expand Down