diff --git a/docs/customizable_phase.md b/docs/customizable_phase.md index 58a81648c..bf689e645 100644 --- a/docs/customizable_phase.md +++ b/docs/customizable_phase.md @@ -128,9 +128,12 @@ Currently phase architecture is used by 7 rules: - scala_junit_test - scala_repl -If you need to expose providers to downstream targets you need to return an array of providers from your phase under the `external_providers` attribute. +If you need to expose providers to downstream targets you need to return a dict of providers (provider-name to provider instance) from your phase under the `external_providers` attribute. -In each of the rule implementations, it calls `run_phases` and returns an accumulated `external_providers` array declared by the phases. +If you need to override a provider returned by a previous phase you can adjust your phase to be after it and return the same key from your phase and it will override it. +Note you probably have a good reason to override since you're meddling with the public return value of a different phase. + +In each of the rule implementations, it calls `run_phases` and returns the accumulated values of the `external_providers` dict declared by the phases. To make a new phase, you have to define a new `phase_.bzl` in `scala/private/phases/`. Function definition should have 2 arguments, `ctx` and `p`. You may expose the information for later phases by returning a `struct`. In some phases, there are multiple phase functions since different rules may take slightly different input arguemnts. You may want to re-expose the phase definition in `scala/private/phases/phases.bzl`, so it's more convenient to access in rule files. diff --git a/scala/private/phases/api.bzl b/scala/private/phases/api.bzl index a1bd4a60c..d8da71cd5 100644 --- a/scala/private/phases/api.bzl +++ b/scala/private/phases/api.bzl @@ -6,6 +6,7 @@ load( "@io_bazel_rules_scala//scala:advanced_usage/providers.bzl", _ScalaRulePhase = "ScalaRulePhase", ) +load("@bazel_skylib//lib:dicts.bzl", "dicts") # A method to modify the built-in phase list # - Insert new phases to the first/last position @@ -63,7 +64,7 @@ def run_phases(ctx, builtin_customizable_phases): # A placeholder for data shared with later phases global_provider = {} current_provider = struct(**global_provider) - acculmulated_external_providers = [] + acculmulated_external_providers = {} for (name, function) in adjusted_phases: # Run a phase new_provider = function(ctx, current_provider) @@ -72,12 +73,15 @@ def run_phases(ctx, builtin_customizable_phases): # for later phases to access if new_provider != None: if (hasattr(new_provider, "external_providers")): - acculmulated_external_providers.extend(new_provider.external_providers) + acculmulated_external_providers = dicts.add( + acculmulated_external_providers, + new_provider.external_providers, + ) global_provider[name] = new_provider current_provider = struct(**global_provider) # The final return of rules implementation - return acculmulated_external_providers + return acculmulated_external_providers.values() # A method to pass in phase provider def extras_phases(extras): diff --git a/scala/private/phases/phase_collect_jars.bzl b/scala/private/phases/phase_collect_jars.bzl index 43183826d..becf6e518 100644 --- a/scala/private/phases/phase_collect_jars.bzl +++ b/scala/private/phases/phase_collect_jars.bzl @@ -108,7 +108,7 @@ def _phase_collect_jars( transitive_compile_jars = transitive_compile_jars, transitive_runtime_jars = transitive_rjars, deps_providers = deps_providers, - external_providers = [jars2labels], + external_providers = {"JarsToLabelsInfo": jars2labels}, ) def _collect_runtime_jars(dep_targets): diff --git a/scala/private/phases/phase_compile.bzl b/scala/private/phases/phase_compile.bzl index a3dacc480..23393bdac 100644 --- a/scala/private/phases/phase_compile.bzl +++ b/scala/private/phases/phase_compile.bzl @@ -4,6 +4,7 @@ # DOCUMENT THIS # load("@bazel_skylib//lib:paths.bzl", _paths = "paths") +load("@bazel_skylib//lib:dicts.bzl", "dicts") load("@bazel_tools//tools/jdk:toolchain_utils.bzl", "find_java_runtime_toolchain", "find_java_toolchain") load( "@io_bazel_rules_scala//scala/private:coverage_replacements_provider.bzl", @@ -23,9 +24,10 @@ _scala_extension = ".scala" _srcjar_extension = ".srcjar" _empty_coverage_struct = struct( - instrumented_files = None, - providers = [], - replacements = {}, + external = struct( + replacements = {}, + ), + providers_dict = {}, ) def phase_binary_compile(ctx, p): @@ -162,7 +164,7 @@ def _phase_compile( # TODO: simplify the return values and use provider return struct( class_jar = out.class_jar, - coverage = out.coverage, + coverage = out.coverage.external, full_jars = out.full_jars, ijar = out.ijar, ijars = out.ijars, @@ -170,7 +172,9 @@ def _phase_compile( java_jar = out.java_jar, source_jars = _pack_source_jars(ctx) + out.source_jars, merged_provider = out.merged_provider, - external_providers = [out.merged_provider] + out.coverage.providers, + external_providers = dicts.add(out.coverage.providers_dict, { + "JavaInfo": out.merged_provider, + }), ) def _compile_or_empty( @@ -422,8 +426,13 @@ def _jacoco_offline_instrument(ctx, input_jar): extensions = ["scala", "java"], ) return struct( - providers = [provider, instrumented_files_provider], - replacements = replacements, + external = struct( + replacements = replacements, + ), + providers_dict = { + "_CoverageReplacements": provider, + "InstrumentedFilesInfo": instrumented_files_provider, + }, ) def _jacoco_offline_instrument_format_each(in_out_pair): diff --git a/scala/private/phases/phase_default_info.bzl b/scala/private/phases/phase_default_info.bzl index 4f5731515..920977453 100644 --- a/scala/private/phases/phase_default_info.bzl +++ b/scala/private/phases/phase_default_info.bzl @@ -5,32 +5,32 @@ # def phase_binary_default_info(ctx, p): return struct( - external_providers = [ - DefaultInfo( + external_providers = { + "DefaultInfo": DefaultInfo( executable = p.declare_executable, files = depset([p.declare_executable] + p.compile.full_jars), runfiles = p.runfiles.runfiles, ), - ], + }, ) def phase_library_default_info(ctx, p): return struct( - external_providers = [ - DefaultInfo( + external_providers = { + "DefaultInfo": DefaultInfo( files = depset(p.compile.full_jars), runfiles = p.runfiles.runfiles, ), - ], + }, ) def phase_scalatest_default_info(ctx, p): return struct( - external_providers = [ - DefaultInfo( + external_providers = { + "DefaultInfo": DefaultInfo( executable = p.declare_executable, files = depset([p.declare_executable] + p.compile.full_jars), runfiles = ctx.runfiles(p.coverage_runfiles.coverage_runfiles, transitive_files = p.runfiles.runfiles.files), ), - ], + }, ) diff --git a/test/phase/providers/BUILD.bazel b/test/phase/providers/BUILD.bazel index 63966c84c..e7fb39d78 100644 --- a/test/phase/providers/BUILD.bazel +++ b/test/phase/providers/BUILD.bazel @@ -1,4 +1,5 @@ load(":phase_providers_expose.bzl", "phase_expose_provider_singleton", "rule_that_needs_custom_provider", "scala_library_that_exposes_custom_provider") +load(":phase_providers_override.bzl", "phase_override_provider_singleton", "rule_that_has_phases_which_override_providers", "rule_that_verifies_providers_are_overriden") scala_library_that_exposes_custom_provider( name = "scala_library_that_exposes_custom_provider", @@ -13,3 +14,17 @@ phase_expose_provider_singleton( name = "phase_expose_provider_singleton_target", visibility = ["//visibility:public"], ) + +rule_that_has_phases_which_override_providers( + name = "PhaseOverridesProvider", +) + +rule_that_verifies_providers_are_overriden( + name = "PhaseOverridesProviderTest", + dep = ":PhaseOverridesProvider", +) + +phase_override_provider_singleton( + name = "phase_override_provider_singleton_target", + visibility = ["//visibility:public"], +) diff --git a/test/phase/providers/phase_providers_expose.bzl b/test/phase/providers/phase_providers_expose.bzl index 8654b725a..10660c430 100644 --- a/test/phase/providers/phase_providers_expose.bzl +++ b/test/phase/providers/phase_providers_expose.bzl @@ -28,7 +28,7 @@ CustomProviderExposedByPhase = provider() def _phase_expose_provider(ctx, p): return struct( - external_providers = [CustomProviderExposedByPhase()], + external_providers = {"CustomProviderExposedByPhase": CustomProviderExposedByPhase()}, ) def _rule_that_needs_custom_provider_impl(ctx): diff --git a/test/phase/providers/phase_providers_override.bzl b/test/phase/providers/phase_providers_override.bzl new file mode 100644 index 000000000..c432cc55f --- /dev/null +++ b/test/phase/providers/phase_providers_override.bzl @@ -0,0 +1,61 @@ +load("@io_bazel_rules_scala//scala:advanced_usage/providers.bzl", "ScalaRulePhase") +load("@io_bazel_rules_scala//scala:advanced_usage/scala.bzl", "make_scala_library") + +ext_phase_override_provider = { + "phase_providers": [ + "//test/phase/providers:phase_override_provider_singleton_target", + ], +} + +rule_that_has_phases_which_override_providers = make_scala_library(ext_phase_override_provider) + +def _phase_override_provider_singleton_implementation(ctx): + return [ + ScalaRulePhase( + custom_phases = [ + ("last", "", "first_custom", _phase_original), + ("after", "first_custom", "second_custom", _phase_override), + ], + ), + ] + +phase_override_provider_singleton = rule( + implementation = _phase_override_provider_singleton_implementation, +) + +OverrideProvider = provider(fields = ["content"]) + +def _phase_original(ctx, p): + return struct( + external_providers = { + "OverrideProvider": OverrideProvider( + content = "original", + ), + }, + ) + +def _phase_override(ctx, p): + return struct( + external_providers = { + "OverrideProvider": OverrideProvider( + content = "override", + ), + }, + ) + +def _rule_that_verifies_providers_are_overriden_impl(ctx): + if (ctx.attr.dep[OverrideProvider].content != "override"): + fail( + "expected OverrideProvider of {label} to have content 'override' but got '{content}'".format( + label = ctx.label, + content = ctx.attr.dep[OverrideProvider].content, + ), + ) + return [] + +rule_that_verifies_providers_are_overriden = rule( + implementation = _rule_that_verifies_providers_are_overriden_impl, + attrs = { + "dep": attr.label(providers = [OverrideProvider]), + }, +)