Skip to content

[rebased/cherry-picked] Switch to JarsToLabels provider and rework/cleanup scala_import #487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions scala/private/common.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
load("@io_bazel_rules_scala//scala:providers.bzl", "JarsToLabels")

def write_manifest(ctx):
# TODO(bazel-team): I don't think this classpath is what you want
manifest = "Class-Path: \n"
Expand Down Expand Up @@ -104,7 +106,7 @@ def _add_label_of_indirect_jar_to(jars2labels, dependency, jar):
# skylark exposes only labels of direct dependencies.
# to get labels of indirect dependencies we collect them from the providers transitively
if _provider_of_dependency_contains_label_of(dependency, jar):
jars2labels[jar.path] = dependency.jars_to_labels[jar.path]
jars2labels[jar.path] = dependency[JarsToLabels].lookup[jar.path]
else:
jars2labels[jar.path] = "Unknown label of file {jar_path} which came from {dependency_label}".format(
jar_path = jar.path,
Expand All @@ -115,7 +117,7 @@ def _label_already_exists(jars2labels, jar):
return jar.path in jars2labels

def _provider_of_dependency_contains_label_of(dependency, jar):
return hasattr(dependency, "jars_to_labels") and jar.path in dependency.jars_to_labels
return JarsToLabels in dependency and jar.path in dependency[JarsToLabels].lookup

def create_java_provider(scalaattr, transitive_compile_time_jars):
# This is needed because Bazel >=0.7.0 requires ctx.actions and a Java
Expand Down
15 changes: 9 additions & 6 deletions scala/private/rule_impls.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Rules for supporting the Scala language."""
load("@io_bazel_rules_scala//scala:scala_toolchain.bzl", "scala_toolchain")
load("@io_bazel_rules_scala//scala:providers.bzl", "create_scala_provider")
load("@io_bazel_rules_scala//scala:providers.bzl", "create_scala_provider", "JarsToLabels")
load(":common.bzl",
"add_labels_of_jars_to",
"create_java_provider",
Expand Down Expand Up @@ -122,12 +122,14 @@ def _collect_plugin_paths(plugins):
for p in plugins:
if hasattr(p, "path"):
paths.append(p)
elif p[JavaInfo] and p[JavaInfo].full_compile_jars:
paths.extend(p[JavaInfo].full_compile_jars.to_list())
elif hasattr(p, "scala"):
paths.append(p.scala.outputs.jar)
elif hasattr(p, "java"):
paths.extend([j.class_jar for j in p.java.outputs.jars])
# support http_file pointed at a jar. http_jar uses ijar,
# which breaks scala macros
# support http_file pointed at a jar. http_jar uses ijar,
# which breaks scala macros
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is on the elif, can we move it below or unindent? It feels awkward to me being indented but not related to the current block.

elif hasattr(p, "files"):
paths.extend([f for f in p.files if not_sources_jar(f.basename)])
return depset(paths)
Expand Down Expand Up @@ -606,10 +608,11 @@ def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2l

return struct(
files=depset([ctx.outputs.executable]),
providers = [java_provider],
providers = [
JarsToLabels(lookup = jars2labels),
java_provider],
scala = scalaattr,
transitive_rjars = rjars, #calling rules need this for the classpath in the launcher
jars_to_labels = jars2labels,
transitive_rjars = rjars, #calling rules need this for the classpath in the launcher
runfiles=runfiles)

def scala_binary_impl(ctx):
Expand Down
7 changes: 7 additions & 0 deletions scala/providers.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ def create_scala_provider(
transitive_runtime_jars = transitive_runtime_jars,
transitive_exports = [] #needed by intellij plugin
)

JarsToLabels = provider(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make a link to the bazel discussions about putting labels in the jars?

bazelbuild/bazel#4584

doc = 'provides a mapping from jar files to defining labels for improved end user experience',
fields = {
'lookup' : 'dictionary with jar files as keys and labels as values',
},
)
2 changes: 1 addition & 1 deletion scala/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ load("@io_bazel_rules_scala//scala/private:rule_impls.bzl",

load(
"@io_bazel_rules_scala//specs2:specs2_junit.bzl",
"specs2_junit_dependencies"
"specs2_junit_dependencies",
)

_jar_filetype = FileType([".jar"])
Expand Down
253 changes: 148 additions & 105 deletions scala/scala_import.bzl
Original file line number Diff line number Diff line change
@@ -1,112 +1,155 @@
#intellij part is tested manually, tread lightly when changing there
#if you change make sure to manually re-import an intellij project and see imports
#are resolved (not red) and clickable
load(":providers.bzl", "JarsToLabels")

# Note to future authors:
#
# Tread lightly when modifying this code! IntelliJ support needs
# to be tested manually: manually [re-]import an intellij project
# and ensure imports are resolved (not red) and clickable
#

def _scala_import_impl(ctx):
target_data = _code_jars_and_intellij_metadata_from(ctx.attr.jars)
(current_target_compile_jars, intellij_metadata) = (target_data.code_jars, target_data.intellij_metadata)
current_jars = depset(current_target_compile_jars)
exports = _collect(ctx.attr.exports)
transitive_runtime_jars = _collect_runtime(ctx.attr.runtime_deps)
jars = _collect(ctx.attr.deps)
jars2labels = {}
_collect_labels(ctx.attr.deps, jars2labels)
_collect_labels(ctx.attr.exports, jars2labels) #untested
_add_labels_of_current_code_jars(depset(transitive=[current_jars, exports.compile_jars]), ctx.label, jars2labels) #last to override the label of the export compile jars to the current target

direct_binary_jars = []
all_jar_files = []
for jar in ctx.attr.jars:
for file in jar.files.to_list():
all_jar_files.append(file)
if not file.basename.endswith("-sources.jar"):
direct_binary_jars += [file]

default_info = DefaultInfo(
files = depset(all_jar_files)
)

source_jar = None
if (ctx.attr.srcjar):
source_jar = ctx.file.srcjar

return struct(
scala = struct(
outputs = struct (
jars = intellij_metadata
),
),
jars_to_labels = jars2labels,
scala = _create_intellij_provider(direct_binary_jars, source_jar),
providers = [
_create_provider(current_jars, transitive_runtime_jars, jars, exports)
],
default_info,
_scala_import_java_info(ctx, direct_binary_jars, source_jar),
_scala_import_jars_to_labels(ctx, direct_binary_jars),
]
)

# The IntelliJ plugin currently does not support JavaInfo. It has its own
# provider. We build that provider and return it in addition to JavaInfo.
# From reading the IntelliJ plugin code, best I can tell it expects a provider
# that looks like this.
# {
# scala: {
# annotation_processing: {
# # see https://docs.bazel.build/versions/master/skylark/lib/java_annotation_processing.html
# },
# outputs: {
# # see https://docs.bazel.build/versions/master/skylark/lib/java_output_jars.html
# jdeps: <file>
# jars: [
# {
# # see https://docs.bazel.build/versions/master/skylark/lib/java_output.html
# class_jar: <file>,
# ijar: <file>
# source_jar: <file>
# source_jars: [<file>...]
# }
# ]
# }
# },
# }
def _create_intellij_provider(jars, source_jar):
return struct(
# TODO: should we support annotation_processing and jdeps?
outputs = struct(
jars = [_create_intellij_output(jar, source_jar) for jar in jars]
)
)

def _create_intellij_output(class_jar, source_jar):
source_jars = [source_jar] if source_jar else []
return struct(
class_jar = class_jar,
ijar = None,
source_jar = source_jar,
source_jars = source_jars,
)
def _create_provider(current_target_compile_jars, transitive_runtime_jars, jars, exports):
test_provider = java_common.create_provider()
if hasattr(test_provider, "full_compile_jars"):
return java_common.create_provider(
use_ijar = False,
compile_time_jars = depset(transitive = [current_target_compile_jars, exports.compile_jars]),
transitive_compile_time_jars = depset(transitive = [jars.transitive_compile_jars, current_target_compile_jars, exports.transitive_compile_jars]) ,
transitive_runtime_jars = depset(transitive = [transitive_runtime_jars, jars.transitive_runtime_jars, current_target_compile_jars, exports.transitive_runtime_jars]) ,
)
else:
return java_common.create_provider(
compile_time_jars = current_target_compile_jars,
runtime_jars = transitive_runtime_jars + jars.transitive_runtime_jars,
transitive_compile_time_jars = jars.transitive_compile_jars + current_target_compile_jars,
transitive_runtime_jars = transitive_runtime_jars + jars.transitive_runtime_jars + current_target_compile_jars,
)

def _add_labels_of_current_code_jars(code_jars, label, jars2labels):
for jar in code_jars.to_list():
jars2labels[jar.path] = label

def _code_jars_and_intellij_metadata_from(jars):
code_jars = []
intellij_metadata = []
for jar in jars:
current_jar_code_jars = _filter_out_non_code_jars(jar.files)
code_jars += current_jar_code_jars
for current_class_jar in current_jar_code_jars: #intellij, untested
intellij_metadata.append(struct(
ijar = None,
class_jar = current_class_jar,
source_jar = None,
source_jars = [],
)
)
return struct(code_jars = code_jars, intellij_metadata = intellij_metadata)

def _filter_out_non_code_jars(files):
return [file for file in files.to_list() if not _is_source_jar(file)]

def _is_source_jar(file):
return file.basename.endswith("-sources.jar")

def _collect(deps):
transitive_compile_jars = []
runtime_jars = []
compile_jars = []

for dep_target in deps:
java_provider = dep_target[java_common.provider]
compile_jars.append(java_provider.compile_jars)
transitive_compile_jars.append(java_provider.transitive_compile_time_jars)
runtime_jars.append(java_provider.transitive_runtime_jars)

return struct(transitive_runtime_jars = depset(transitive = runtime_jars),
transitive_compile_jars = depset(transitive = transitive_compile_jars),
compile_jars = depset(transitive = compile_jars))

def _collect_labels(deps, jars2labels):
for dep_target in deps:
java_provider = dep_target[java_common.provider]
_transitively_accumulate_labels(dep_target, java_provider,jars2labels)

def _transitively_accumulate_labels(dep_target, java_provider, jars2labels):
if hasattr(dep_target, "jars_to_labels"):
jars2labels.update(dep_target.jars_to_labels)
#scala_library doesn't add labels to the direct dependency itself
for jar in java_provider.compile_jars.to_list():
jars2labels[jar.path] = dep_target.label

def _collect_runtime(runtime_deps):
jar_deps = []
for dep_target in runtime_deps:
java_provider = dep_target[java_common.provider]
jar_deps.append(java_provider.transitive_runtime_jars)

return depset(transitive = jar_deps)

def _scala_import_java_info(ctx, direct_binary_jars, source_jar = None):
s_deps = java_common.merge(_collect(JavaInfo, ctx.attr.deps))
s_exports = java_common.merge(_collect(JavaInfo, ctx.attr.exports))
s_runtime_deps = java_common.merge(_collect(JavaInfo, ctx.attr.runtime_deps))

# build up our final JavaInfo provider

compile_time_jars = depset(
direct = direct_binary_jars,
transitive = [
s_exports.transitive_compile_time_jars])

transitive_compile_time_jars = depset(
transitive = [
compile_time_jars,
s_deps.transitive_compile_time_jars,
s_exports.transitive_compile_time_jars])

transitive_runtime_jars = depset(
transitive = [
compile_time_jars,
s_deps.transitive_runtime_jars,
s_exports.transitive_runtime_jars,
s_runtime_deps.transitive_runtime_jars])

source_jars = [source_jar] if source_jar else []

return java_common.create_provider(
ctx.actions,
use_ijar = False,
compile_time_jars = compile_time_jars,
transitive_compile_time_jars = transitive_compile_time_jars,
transitive_runtime_jars = transitive_runtime_jars,
source_jars = source_jars)

def _scala_import_jars_to_labels(ctx, direct_binary_jars):
# build up JarsToLabels
# note: consider moving this to an aspect

lookup = {}
for jar in direct_binary_jars:
lookup[jar.path] = ctx.label

for entry in ctx.attr.deps:
if JavaInfo in entry:
for jar in entry[JavaInfo].compile_jars:
lookup[jar.path] = entry.label
if JarsToLabels in entry:
lookup.update(entry[JarsToLabels].lookup)

for entry in ctx.attr.exports:
if JavaInfo in entry:
for jar in entry[JavaInfo].compile_jars.to_list():
lookup[jar.path] = entry.label
if JarsToLabels in entry:
lookup.update(entry[JarsToLabels].lookup)

return JarsToLabels(lookup = lookup)

# Filters an iterable for entries that contain a particular
# index and returns a collection of the indexed values.
def _collect(index, iterable):
return [
entry[index]
for entry in iterable
if index in entry
]

scala_import = rule(
implementation=_scala_import_impl,
attrs={
"jars": attr.label_list(allow_files=True), #current hidden assumption is that these point to full, not ijar'd jars
"deps": attr.label_list(),
"runtime_deps": attr.label_list(),
"exports": attr.label_list()
},
implementation = _scala_import_impl,
attrs = {
"jars": attr.label_list(allow_files=True), #current hidden assumption is that these point to full, not ijar'd jars
"srcjar": attr.label(allow_single_file=True),
"deps": attr.label_list(),
"runtime_deps": attr.label_list(),
"exports": attr.label_list(),
},
)
5 changes: 3 additions & 2 deletions scala_proto/scala_proto.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,12 @@ def _gen_proto_srcjar_impl(ctx):
acc_imports.append(target.proto.transitive_sources)
#inline this if after 0.12.0 is the oldest supported version
if hasattr(target.proto, 'transitive_proto_path'):
transitive_proto_paths.append(target.proto.transitive_proto_path)
transitive_proto_paths.append(target.proto.transitive_proto_path)
else:
jvm_deps.append(target)

acc_imports = depset(transitive = acc_imports)
transitive_proto_paths = depset(transitive = transitive_proto_paths)
if "java_conversions" in ctx.attr.flags and len(jvm_deps) == 0:
fail("must have at least one jvm dependency if with_java is True (java_conversions is turned on)")

Expand All @@ -352,7 +353,7 @@ def _gen_proto_srcjar_impl(ctx):
# Command line args to worker cannot be empty so using padding
flags_arg = "-" + ",".join(ctx.attr.flags),
# Command line args to worker cannot be empty so using padding
packages = "-" + ":".join(transitive_proto_paths)
packages = "-" + ":".join(transitive_proto_paths.to_list())
)
argfile = ctx.actions.declare_file("%s_worker_input" % ctx.label.name, sibling = ctx.outputs.srcjar)
ctx.actions.write(output=argfile, content=worker_content)
Expand Down