Skip to content

Phases can override providers from previous phases #948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/customizable_phase.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,12 @@ Currently phase architecture is used by 7 rules:
- scala_junit_test
- scala_repl

If you need to expose providers to downstream targets you need to return an array of providers from your phase under the `external_providers` attribute.
If you need to expose providers to downstream targets you need to return a dict of providers (provider-name to provider instance) from your phase under the `external_providers` attribute.

In each of the rule implementations, it calls `run_phases` and returns an accumulated `external_providers` array declared by the phases.
If you need to override a provider returned by a previous phase you can adjust your phase to be after it and return the same key from your phase and it will override it.
Note you probably have a good reason to override since you're meddling with the public return value of a different phase.

In each of the rule implementations, it calls `run_phases` and returns the accumulated values of the `external_providers` dict declared by the phases.

To make a new phase, you have to define a new `phase_<PHASE_NAME>.bzl` in `scala/private/phases/`. Function definition should have 2 arguments, `ctx` and `p`. You may expose the information for later phases by returning a `struct`. In some phases, there are multiple phase functions since different rules may take slightly different input arguemnts. You may want to re-expose the phase definition in `scala/private/phases/phases.bzl`, so it's more convenient to access in rule files.

Expand Down
10 changes: 7 additions & 3 deletions scala/private/phases/api.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ load(
"@io_bazel_rules_scala//scala:advanced_usage/providers.bzl",
_ScalaRulePhase = "ScalaRulePhase",
)
load("@bazel_skylib//lib:dicts.bzl", "dicts")

# A method to modify the built-in phase list
# - Insert new phases to the first/last position
Expand Down Expand Up @@ -63,7 +64,7 @@ def run_phases(ctx, builtin_customizable_phases):
# A placeholder for data shared with later phases
global_provider = {}
current_provider = struct(**global_provider)
acculmulated_external_providers = []
acculmulated_external_providers = {}
for (name, function) in adjusted_phases:
# Run a phase
new_provider = function(ctx, current_provider)
Expand All @@ -72,12 +73,15 @@ def run_phases(ctx, builtin_customizable_phases):
# for later phases to access
if new_provider != None:
if (hasattr(new_provider, "external_providers")):
acculmulated_external_providers.extend(new_provider.external_providers)
acculmulated_external_providers = dicts.add(
acculmulated_external_providers,
new_provider.external_providers,
)
global_provider[name] = new_provider
current_provider = struct(**global_provider)

# The final return of rules implementation
return acculmulated_external_providers
return acculmulated_external_providers.values()

# A method to pass in phase provider
def extras_phases(extras):
Expand Down
2 changes: 1 addition & 1 deletion scala/private/phases/phase_collect_jars.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _phase_collect_jars(
transitive_compile_jars = transitive_compile_jars,
transitive_runtime_jars = transitive_rjars,
deps_providers = deps_providers,
external_providers = [jars2labels],
external_providers = {"JarsToLabelsInfo": jars2labels},
)

def _collect_runtime_jars(dep_targets):
Expand Down
23 changes: 16 additions & 7 deletions scala/private/phases/phase_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DOCUMENT THIS
#
load("@bazel_skylib//lib:paths.bzl", _paths = "paths")
load("@bazel_skylib//lib:dicts.bzl", "dicts")
load("@bazel_tools//tools/jdk:toolchain_utils.bzl", "find_java_runtime_toolchain", "find_java_toolchain")
load(
"@io_bazel_rules_scala//scala/private:coverage_replacements_provider.bzl",
Expand All @@ -23,9 +24,10 @@ _scala_extension = ".scala"
_srcjar_extension = ".srcjar"

_empty_coverage_struct = struct(
instrumented_files = None,
providers = [],
replacements = {},
external = struct(
replacements = {},
),
providers_dict = {},
)

def phase_binary_compile(ctx, p):
Expand Down Expand Up @@ -162,15 +164,17 @@ def _phase_compile(
# TODO: simplify the return values and use provider
return struct(
class_jar = out.class_jar,
coverage = out.coverage,
coverage = out.coverage.external,
full_jars = out.full_jars,
ijar = out.ijar,
ijars = out.ijars,
rjars = depset(out.full_jars, transitive = [rjars]),
java_jar = out.java_jar,
source_jars = _pack_source_jars(ctx) + out.source_jars,
merged_provider = out.merged_provider,
external_providers = [out.merged_provider] + out.coverage.providers,
external_providers = dicts.add(out.coverage.providers_dict, {
"JavaInfo": out.merged_provider,
}),
)

def _compile_or_empty(
Expand Down Expand Up @@ -422,8 +426,13 @@ def _jacoco_offline_instrument(ctx, input_jar):
extensions = ["scala", "java"],
)
return struct(
providers = [provider, instrumented_files_provider],
replacements = replacements,
external = struct(
replacements = replacements,
),
providers_dict = {
"_CoverageReplacements": provider,
"InstrumentedFilesInfo": instrumented_files_provider,
},
)

def _jacoco_offline_instrument_format_each(in_out_pair):
Expand Down
18 changes: 9 additions & 9 deletions scala/private/phases/phase_default_info.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,32 @@
#
def phase_binary_default_info(ctx, p):
return struct(
external_providers = [
DefaultInfo(
external_providers = {
"DefaultInfo": DefaultInfo(
executable = p.declare_executable,
files = depset([p.declare_executable] + p.compile.full_jars),
runfiles = p.runfiles.runfiles,
),
],
},
)

def phase_library_default_info(ctx, p):
return struct(
external_providers = [
DefaultInfo(
external_providers = {
"DefaultInfo": DefaultInfo(
files = depset(p.compile.full_jars),
runfiles = p.runfiles.runfiles,
),
],
},
)

def phase_scalatest_default_info(ctx, p):
return struct(
external_providers = [
DefaultInfo(
external_providers = {
"DefaultInfo": DefaultInfo(
executable = p.declare_executable,
files = depset([p.declare_executable] + p.compile.full_jars),
runfiles = ctx.runfiles(p.coverage_runfiles.coverage_runfiles, transitive_files = p.runfiles.runfiles.files),
),
],
},
)
15 changes: 15 additions & 0 deletions test/phase/providers/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load(":phase_providers_expose.bzl", "phase_expose_provider_singleton", "rule_that_needs_custom_provider", "scala_library_that_exposes_custom_provider")
load(":phase_providers_override.bzl", "phase_override_provider_singleton", "rule_that_has_phases_which_override_providers", "rule_that_verifies_providers_are_overriden")

scala_library_that_exposes_custom_provider(
name = "scala_library_that_exposes_custom_provider",
Expand All @@ -13,3 +14,17 @@ phase_expose_provider_singleton(
name = "phase_expose_provider_singleton_target",
visibility = ["//visibility:public"],
)

rule_that_has_phases_which_override_providers(
name = "PhaseOverridesProvider",
)

rule_that_verifies_providers_are_overriden(
name = "PhaseOverridesProviderTest",
dep = ":PhaseOverridesProvider",
)

phase_override_provider_singleton(
name = "phase_override_provider_singleton_target",
visibility = ["//visibility:public"],
)
2 changes: 1 addition & 1 deletion test/phase/providers/phase_providers_expose.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ CustomProviderExposedByPhase = provider()

def _phase_expose_provider(ctx, p):
return struct(
external_providers = [CustomProviderExposedByPhase()],
external_providers = {"CustomProviderExposedByPhase": CustomProviderExposedByPhase()},
)

def _rule_that_needs_custom_provider_impl(ctx):
Expand Down
61 changes: 61 additions & 0 deletions test/phase/providers/phase_providers_override.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
load("@io_bazel_rules_scala//scala:advanced_usage/providers.bzl", "ScalaRulePhase")
load("@io_bazel_rules_scala//scala:advanced_usage/scala.bzl", "make_scala_library")

ext_phase_override_provider = {
"phase_providers": [
"//test/phase/providers:phase_override_provider_singleton_target",
],
}

rule_that_has_phases_which_override_providers = make_scala_library(ext_phase_override_provider)

def _phase_override_provider_singleton_implementation(ctx):
return [
ScalaRulePhase(
custom_phases = [
("last", "", "first_custom", _phase_original),
("after", "first_custom", "second_custom", _phase_override),
],
),
]

phase_override_provider_singleton = rule(
implementation = _phase_override_provider_singleton_implementation,
)

OverrideProvider = provider(fields = ["content"])

def _phase_original(ctx, p):
return struct(
external_providers = {
"OverrideProvider": OverrideProvider(
content = "original",
),
},
)

def _phase_override(ctx, p):
return struct(
external_providers = {
"OverrideProvider": OverrideProvider(
content = "override",
),
},
)

def _rule_that_verifies_providers_are_overriden_impl(ctx):
if (ctx.attr.dep[OverrideProvider].content != "override"):
fail(
"expected OverrideProvider of {label} to have content 'override' but got '{content}'".format(
label = ctx.label,
content = ctx.attr.dep[OverrideProvider].content,
),
)
return []

rule_that_verifies_providers_are_overriden = rule(
implementation = _rule_that_verifies_providers_are_overriden_impl,
attrs = {
"dep": attr.label(providers = [OverrideProvider]),
},
)