diff --git a/scala/private/common.bzl b/scala/private/common.bzl index ca0d5d7d9..030b579db 100644 --- a/scala/private/common.bzl +++ b/scala/private/common.bzl @@ -1,5 +1,6 @@ load("@io_bazel_rules_scala//scala:jars_to_labels.bzl", "JarsToLabelsInfo") load("@io_bazel_rules_scala//scala:plusone.bzl", "PlusOneDeps") +load("@io_bazel_rules_scala//scala:providers.bzl", "ScalaInfo") load("@bazel_skylib//lib:paths.bzl", "paths") def write_manifest_file(actions, output_file, main_class): @@ -22,6 +23,7 @@ def collect_jars( compile_jars = [] runtime_jars = [] deps_providers = [] + macro_classpath = [] for dep_target in dep_targets: # we require a JavaInfo for dependencies @@ -50,11 +52,34 @@ def collect_jars( java_provider.compile_jars.to_list(), ) + # Macros are different from ordinary targets in that they’re used at compile time instead of at runtime. That + # means that both their compile-time classpath and runtime classpath are needed at compile time. We could have + # `scala_macro_library` targets include their runtime dependencies in their compile-time dependencies, but then + # we wouldn't have any guarantees classpath order. + # + # Consider the following scenario. Target A depends on targets B and C. Target C is a macro target, whereas + # target B isn't. Targets C depends on target B. If target A doesn't include the runtime version of target C on + # the compile classpath before the compile (`ijar`d) version of target B that target C depends on, then target A + # won't use the correct version of target B at compile-time when evaluating the macros contained in target C. + # + # For that reason, we opt for a different approach: have `scala_macro_library` targets export `JavaInfo` + # providers as normal, but put their transitive runtime dependencies first on the classpath. Note that we + # shouldn't encounter any issues with external dependencies, so long as they aren't `ijar`d. + if ScalaInfo in dep_target and dep_target[ScalaInfo].contains_macros: + macro_classpath.append(java_provider.transitive_runtime_jars) + + add_labels_of_jars_to( + jars2labels, + dep_target, + [], + java_provider.transitive_runtime_jars.to_list(), + ) + return struct( - compile_jars = depset(transitive = compile_jars), + compile_jars = depset(order = "preorder", transitive = macro_classpath + compile_jars), transitive_runtime_jars = depset(transitive = runtime_jars), jars2labels = JarsToLabelsInfo(jars_to_labels = jars2labels), - transitive_compile_jars = depset(transitive = transitive_compile_jars), + transitive_compile_jars = depset(order = "preorder", transitive = macro_classpath + transitive_compile_jars), deps_providers = deps_providers, ) diff --git a/scala/private/phases/phase_compile.bzl b/scala/private/phases/phase_compile.bzl index d1534aabb..c4f92967f 100644 --- a/scala/private/phases/phase_compile.bzl +++ b/scala/private/phases/phase_compile.bzl @@ -46,17 +46,6 @@ def phase_compile_library_for_plugin_bootstrapping(ctx, p): ) return _phase_compile_default(ctx, p, args) -def phase_compile_macro_library(ctx, p): - args = struct( - buildijar = False, - unused_dependency_checker_ignored_targets = [ - target.label - for target in p.scalac_provider.default_macro_classpath + ctx.attr.exports + - ctx.attr.unused_dependency_checker_ignored_targets - ], - ) - return _phase_compile_default(ctx, p, args) - def phase_compile_junit_test(ctx, p): args = struct( buildijar = False, diff --git a/scala/private/phases/phase_scalainfo_provider.bzl b/scala/private/phases/phase_scalainfo_provider.bzl new file mode 100644 index 000000000..c1264aeab --- /dev/null +++ b/scala/private/phases/phase_scalainfo_provider.bzl @@ -0,0 +1,14 @@ +load("//scala:providers.bzl", "ScalaInfo") + +def _phase_scalainfo_provider_implementation(contains_macros): + return struct( + external_providers = { + "ScalaInfo": ScalaInfo(contains_macros = contains_macros), + }, + ) + +def phase_scalainfo_provider_macro(ctx, p): + return _phase_scalainfo_provider_implementation(contains_macros = True) + +def phase_scalainfo_provider_non_macro(ctx, p): + return _phase_scalainfo_provider_implementation(contains_macros = False) diff --git a/scala/private/phases/phases.bzl b/scala/private/phases/phases.bzl index a497f4d07..295b5ba67 100644 --- a/scala/private/phases/phases.bzl +++ b/scala/private/phases/phases.bzl @@ -7,18 +7,6 @@ load( _extras_phases = "extras_phases", _run_phases = "run_phases", ) -load( - "@io_bazel_rules_scala//scala/private:phases/phase_write_executable.bzl", - _phase_write_executable_common = "phase_write_executable_common", - _phase_write_executable_junit_test = "phase_write_executable_junit_test", - _phase_write_executable_repl = "phase_write_executable_repl", - _phase_write_executable_scalatest = "phase_write_executable_scalatest", -) -load( - "@io_bazel_rules_scala//scala/private:phases/phase_java_wrapper.bzl", - _phase_java_wrapper_common = "phase_java_wrapper_common", - _phase_java_wrapper_repl = "phase_java_wrapper_repl", -) load( "@io_bazel_rules_scala//scala/private:phases/phase_collect_jars.bzl", _phase_collect_jars_common = "phase_collect_jars_common", @@ -27,6 +15,8 @@ load( _phase_collect_jars_repl = "phase_collect_jars_repl", _phase_collect_jars_scalatest = "phase_collect_jars_scalatest", ) +load("@io_bazel_rules_scala//scala/private:phases/phase_collect_exports_jars.bzl", _phase_collect_exports_jars = "phase_collect_exports_jars") +load("@io_bazel_rules_scala//scala/private:phases/phase_collect_srcjars.bzl", _phase_collect_srcjars = "phase_collect_srcjars") load( "@io_bazel_rules_scala//scala/private:phases/phase_compile.bzl", _phase_compile_binary = "phase_compile_binary", @@ -34,39 +24,53 @@ load( _phase_compile_junit_test = "phase_compile_junit_test", _phase_compile_library = "phase_compile_library", _phase_compile_library_for_plugin_bootstrapping = "phase_compile_library_for_plugin_bootstrapping", - _phase_compile_macro_library = "phase_compile_macro_library", _phase_compile_repl = "phase_compile_repl", _phase_compile_scalatest = "phase_compile_scalatest", ) -load( - "@io_bazel_rules_scala//scala/private:phases/phase_runfiles.bzl", - _phase_runfiles_common = "phase_runfiles_common", - _phase_runfiles_library = "phase_runfiles_library", - _phase_runfiles_scalatest = "phase_runfiles_scalatest", -) load( "@io_bazel_rules_scala//scala/private:phases/phase_coverage.bzl", _phase_coverage_common = "phase_coverage_common", _phase_coverage_library = "phase_coverage_library", ) +load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles") +load("@io_bazel_rules_scala//scala/private:phases/phase_declare_executable.bzl", _phase_declare_executable = "phase_declare_executable") load("@io_bazel_rules_scala//scala/private:phases/phase_default_info.bzl", _phase_default_info = "phase_default_info") -load("@io_bazel_rules_scala//scala/private:phases/phase_scalac_provider.bzl", _phase_scalac_provider = "phase_scalac_provider") -load("@io_bazel_rules_scala//scala/private:phases/phase_write_manifest.bzl", _phase_write_manifest = "phase_write_manifest") -load("@io_bazel_rules_scala//scala/private:phases/phase_collect_srcjars.bzl", _phase_collect_srcjars = "phase_collect_srcjars") -load("@io_bazel_rules_scala//scala/private:phases/phase_collect_exports_jars.bzl", _phase_collect_exports_jars = "phase_collect_exports_jars") load( "@io_bazel_rules_scala//scala/private:phases/phase_dependency.bzl", _phase_dependency_common = "phase_dependency_common", _phase_dependency_library_for_plugin_bootstrapping = "phase_dependency_library_for_plugin_bootstrapping", ) -load("@io_bazel_rules_scala//scala/private:phases/phase_declare_executable.bzl", _phase_declare_executable = "phase_declare_executable") -load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_merge_jars = "phase_merge_jars") +load( + "@io_bazel_rules_scala//scala/private:phases/phase_java_wrapper.bzl", + _phase_java_wrapper_common = "phase_java_wrapper_common", + _phase_java_wrapper_repl = "phase_java_wrapper_repl", +) load("@io_bazel_rules_scala//scala/private:phases/phase_jvm_flags.bzl", _phase_jvm_flags = "phase_jvm_flags") +load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_merge_jars = "phase_merge_jars") +load( + "@io_bazel_rules_scala//scala/private:phases/phase_runfiles.bzl", + _phase_runfiles_common = "phase_runfiles_common", + _phase_runfiles_library = "phase_runfiles_library", + _phase_runfiles_scalatest = "phase_runfiles_scalatest", +) +load("@io_bazel_rules_scala//scala/private:phases/phase_scalac_provider.bzl", _phase_scalac_provider = "phase_scalac_provider") load("@io_bazel_rules_scala//scala/private:phases/phase_scalacopts.bzl", _phase_scalacopts = "phase_scalacopts") -load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles") load("@io_bazel_rules_scala//scala/private:phases/phase_scalafmt.bzl", _phase_scalafmt = "phase_scalafmt") -load("@io_bazel_rules_scala//scala/private:phases/phase_test_environment.bzl", _phase_test_environment = "phase_test_environment") +load( + "@io_bazel_rules_scala//scala/private:phases/phase_scalainfo_provider.bzl", + _phase_scalainfo_provider_macro = "phase_scalainfo_provider_macro", + _phase_scalainfo_provider_non_macro = "phase_scalainfo_provider_non_macro", +) load("@io_bazel_rules_scala//scala/private:phases/phase_semanticdb.bzl", _phase_semanticdb = "phase_semanticdb") +load("@io_bazel_rules_scala//scala/private:phases/phase_test_environment.bzl", _phase_test_environment = "phase_test_environment") +load( + "@io_bazel_rules_scala//scala/private:phases/phase_write_executable.bzl", + _phase_write_executable_common = "phase_write_executable_common", + _phase_write_executable_junit_test = "phase_write_executable_junit_test", + _phase_write_executable_repl = "phase_write_executable_repl", + _phase_write_executable_scalatest = "phase_write_executable_scalatest", +) +load("@io_bazel_rules_scala//scala/private:phases/phase_write_manifest.bzl", _phase_write_manifest = "phase_write_manifest") # API run_phases = _run_phases @@ -75,6 +79,10 @@ extras_phases = _extras_phases # scalac_provider phase_scalac_provider = _phase_scalac_provider +# scalainfo_provider +phase_scalainfo_provider_macro = _phase_scalainfo_provider_macro +phase_scalainfo_provider_non_macro = _phase_scalainfo_provider_non_macro + # collect_srcjars phase_collect_srcjars = _phase_collect_srcjars @@ -128,7 +136,6 @@ phase_collect_jars_common = _phase_collect_jars_common phase_compile_binary = _phase_compile_binary phase_compile_library = _phase_compile_library phase_compile_library_for_plugin_bootstrapping = _phase_compile_library_for_plugin_bootstrapping -phase_compile_macro_library = _phase_compile_macro_library phase_compile_junit_test = _phase_compile_junit_test phase_compile_repl = _phase_compile_repl phase_compile_scalatest = _phase_compile_scalatest diff --git a/scala/private/rules/scala_binary.bzl b/scala/private/rules/scala_binary.bzl index 1b6b89b2c..02dffa1e4 100644 --- a/scala/private/rules/scala_binary.bzl +++ b/scala/private/rules/scala_binary.bzl @@ -24,6 +24,7 @@ load( "phase_runfiles_common", "phase_scalac_provider", "phase_scalacopts", + "phase_scalainfo_provider_non_macro", "phase_semanticdb", "phase_write_executable_common", "phase_write_manifest", @@ -36,6 +37,7 @@ def _scala_binary_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_non_macro), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_common), ("collect_jars", phase_collect_jars_common), diff --git a/scala/private/rules/scala_doc.bzl b/scala/private/rules/scala_doc.bzl index e40e6220c..f34c798ec 100644 --- a/scala/private/rules/scala_doc.bzl +++ b/scala/private/rules/scala_doc.bzl @@ -1,10 +1,12 @@ """Scaladoc support""" +load("@io_bazel_rules_scala//scala:providers.bzl", "ScalaInfo") load("@io_bazel_rules_scala//scala/private:common.bzl", "collect_plugin_paths") ScaladocAspectInfo = provider(fields = [ "src_files", #depset[File] "compile_jars", #depset[File] + "macro_classpath", #depset[File] "plugins", #depset[Target] ]) @@ -29,10 +31,15 @@ def _scaladoc_aspect_impl(target, ctx, transitive = True): if hasattr(ctx.rule.attr, "plugins"): plugins = depset(ctx.rule.attr.plugins) + macro_classpath = [] + + for dependency in ctx.rule.attr.deps: + if ScalaInfo in dependency and dependency[ScalaInfo].contains_macros: + macro_classpath.append(dependency[JavaInfo].transitive_runtime_jars) + # Sometimes we only want to generate scaladocs for a single target and not all of its # dependencies transitive_srcs = depset() - transitive_compile_jars = depset() transitive_plugins = depset() if transitive: @@ -40,12 +47,12 @@ def _scaladoc_aspect_impl(target, ctx, transitive = True): if ScaladocAspectInfo in dep: aspec_info = dep[ScaladocAspectInfo] transitive_srcs = aspec_info.src_files - transitive_compile_jars = aspec_info.compile_jars transitive_plugins = aspec_info.plugins return [ScaladocAspectInfo( src_files = depset(transitive = [src_files, transitive_srcs]), - compile_jars = depset(transitive = [compile_jars, transitive_compile_jars]), + compile_jars = depset(transitive = [compile_jars]), + macro_classpath = depset(transitive = macro_classpath), plugins = depset(transitive = [plugins, transitive_plugins]), )] @@ -73,11 +80,15 @@ def _scala_doc_impl(ctx): src_files = depset(transitive = [dep[ScaladocAspectInfo].src_files for dep in ctx.attr.deps]) compile_jars = depset(transitive = [dep[ScaladocAspectInfo].compile_jars for dep in ctx.attr.deps]) + # See the documentation for `collect_jars` in `scala/private/common.bzl` to understand why this is prepended to the + # classpath + macro_classpath = depset(transitive = [dep[ScaladocAspectInfo].macro_classpath for dep in ctx.attr.deps]) + # Get the 'real' paths to the plugin jars. plugins = collect_plugin_paths(depset(transitive = [dep[ScaladocAspectInfo].plugins for dep in ctx.attr.deps]).to_list()) # Construct the full classpath depset since we need to add compiler plugins too. - classpath = depset(transitive = [plugins, compile_jars]) + classpath = depset(transitive = [macro_classpath, plugins, compile_jars]) # Construct scaladoc args, which also include scalac args. # See `scaladoc -help` for more information. diff --git a/scala/private/rules/scala_junit_test.bzl b/scala/private/rules/scala_junit_test.bzl index 9ebc15714..f5f603fcd 100644 --- a/scala/private/rules/scala_junit_test.bzl +++ b/scala/private/rules/scala_junit_test.bzl @@ -25,6 +25,7 @@ load( "phase_runfiles_common", "phase_scalac_provider", "phase_scalacopts", + "phase_scalainfo_provider_non_macro", "phase_semanticdb", "phase_test_environment", "phase_write_executable_junit_test", @@ -42,6 +43,7 @@ def _scala_junit_test_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_non_macro), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_common), ("collect_jars", phase_collect_jars_junit_test), diff --git a/scala/private/rules/scala_library.bzl b/scala/private/rules/scala_library.bzl index f172787f5..64c08d1a2 100644 --- a/scala/private/rules/scala_library.bzl +++ b/scala/private/rules/scala_library.bzl @@ -25,7 +25,6 @@ load( "phase_collect_srcjars", "phase_compile_library", "phase_compile_library_for_plugin_bootstrapping", - "phase_compile_macro_library", "phase_coverage_common", "phase_coverage_library", "phase_default_info", @@ -35,6 +34,8 @@ load( "phase_runfiles_library", "phase_scalac_provider", "phase_scalacopts", + "phase_scalainfo_provider_macro", + "phase_scalainfo_provider_non_macro", "phase_semanticdb", "phase_write_manifest", "run_phases", @@ -63,6 +64,7 @@ def _scala_library_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_non_macro), ("collect_srcjars", phase_collect_srcjars), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_common), @@ -151,6 +153,7 @@ def _scala_library_for_plugin_bootstrapping_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_non_macro), ("collect_srcjars", phase_collect_srcjars), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_library_for_plugin_bootstrapping), @@ -226,13 +229,14 @@ def _scala_macro_library_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_macro), ("collect_srcjars", phase_collect_srcjars), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_common), ("collect_jars", phase_collect_jars_macro_library), ("scalacopts", phase_scalacopts), ("semanticdb", phase_semanticdb), - ("compile", phase_compile_macro_library), + ("compile", phase_compile_library), ("coverage", phase_coverage_common), ("merge_jars", phase_merge_jars), ("runfiles", phase_runfiles_library), diff --git a/scala/private/rules/scala_repl.bzl b/scala/private/rules/scala_repl.bzl index bae669598..b76cfa075 100644 --- a/scala/private/rules/scala_repl.bzl +++ b/scala/private/rules/scala_repl.bzl @@ -24,6 +24,7 @@ load( "phase_runfiles_common", "phase_scalac_provider", "phase_scalacopts", + "phase_scalainfo_provider_non_macro", "phase_semanticdb", "phase_write_executable_repl", "phase_write_manifest", @@ -36,6 +37,7 @@ def _scala_repl_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_non_macro), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_common), # need scala-compiler for MainGenericRunner below diff --git a/scala/private/rules/scala_test.bzl b/scala/private/rules/scala_test.bzl index 1fe3992a8..b8e6a87a8 100644 --- a/scala/private/rules/scala_test.bzl +++ b/scala/private/rules/scala_test.bzl @@ -25,6 +25,7 @@ load( "phase_runfiles_scalatest", "phase_scalac_provider", "phase_scalacopts", + "phase_scalainfo_provider_non_macro", "phase_semanticdb", "phase_test_environment", "phase_write_executable_scalatest", @@ -38,6 +39,7 @@ def _scala_test_impl(ctx): # customizable phases [ ("scalac_provider", phase_scalac_provider), + ("scalainfo_provider", phase_scalainfo_provider_non_macro), ("write_manifest", phase_write_manifest), ("dependency", phase_dependency_common), ("collect_jars", phase_collect_jars_scalatest), diff --git a/scala/providers.bzl b/scala/providers.bzl index 6ebcc0e90..f0f176a55 100644 --- a/scala/providers.bzl +++ b/scala/providers.bzl @@ -1,12 +1,3 @@ -ScalacProvider = provider( - doc = "ScalacProvider", - fields = [ - "default_classpath", - "default_macro_classpath", - "default_repl_classpath", - ], -) - DepsInfo = provider( doc = "Defines depset required by rules", fields = { @@ -30,3 +21,19 @@ declare_deps_provider = rule( "deps_id": attr.string(mandatory = True), }, ) + +ScalacProvider = provider( + doc = "ScalacProvider", + fields = [ + "default_classpath", + "default_macro_classpath", + "default_repl_classpath", + ], +) + +ScalaInfo = provider( + doc = "Contains information about Scala targets.", + fields = { + "contains_macros": "Whether this target contains macros.", + }, +) diff --git a/test/macros/BUILD b/test/macros/BUILD index b293ff61d..e070d65de 100644 --- a/test/macros/BUILD +++ b/test/macros/BUILD @@ -22,3 +22,27 @@ scala_library( srcs = ["MacroUser.scala"], deps = [":correct-macro"], ) + +scala_library( + name = "macro-dependency", + srcs = ["MacroDependency.scala"], +) + +scala_macro_library( + name = "macro-with-dependencies", + srcs = ["MacroWithDependencies.scala"], + deps = [":macro-dependency"], +) + +scala_library( + name = "macro-with-dependencies-user", + srcs = ["MacroWithDependenciesUser.scala"], + # Without this, `:macro-dependency` will be flagged as an unused dependency. But we want to test that despite it + # appearing before `:macro-with-dependencies` in `deps`, its runtime JAR is included before its compile JAR in the + # compile classpath + unused_dependency_checker_mode = "off", + deps = [ + ":macro-dependency", + ":macro-with-dependencies", + ], +) diff --git a/test/macros/MacroDependency.scala b/test/macros/MacroDependency.scala new file mode 100644 index 000000000..2c804cf1f --- /dev/null +++ b/test/macros/MacroDependency.scala @@ -0,0 +1,5 @@ +package macros + +object MacroDependency { + def isEven(number: Int): Boolean = number % 2 == 0 +} diff --git a/test/macros/MacroWithDependencies.scala b/test/macros/MacroWithDependencies.scala new file mode 100644 index 000000000..63b7d701a --- /dev/null +++ b/test/macros/MacroWithDependencies.scala @@ -0,0 +1,18 @@ +package macros + +import scala.language.experimental.macros +import scala.reflect.macros.blackbox + +object MacroWithDependencies { + def isEvenMacro(number: Int): Boolean = macro isEvenMacroImpl + def isEvenMacroImpl(context: blackbox.Context)(number: context.Expr[Int]): context.Expr[Boolean] = { + import context.universe._ + + val value = number.tree match { + case Literal(Constant(value: Int)) => value + case _ => throw new Exception(s"Expected ${number.tree} to be a literal.") + } + + context.Expr(Literal(Constant(MacroDependency.isEven(value)))) + } +} diff --git a/test/macros/MacroWithDependenciesUser.scala b/test/macros/MacroWithDependenciesUser.scala new file mode 100644 index 000000000..3bd980ebc --- /dev/null +++ b/test/macros/MacroWithDependenciesUser.scala @@ -0,0 +1,9 @@ +package macros + +object MacroWithDependenciesUser { + def main(arguments: Array[String]): Unit = { + println(s"0 is even via macro: ${MacroWithDependencies.isEvenMacro(0)}") + println(s"1 is even via macro: ${MacroWithDependencies.isEvenMacro(1)}") + println(s"1 + 1 is even macro: ${MacroWithDependencies.isEvenMacro(1 + 1)}") + } +} diff --git a/test/shell/test_macros.sh b/test/shell/test_macros.sh index 468c87bca..074a9bb6c 100755 --- a/test/shell/test_macros.sh +++ b/test/shell/test_macros.sh @@ -13,5 +13,10 @@ correct_macro_user_builds() { bazel build //test/macros:correct-macro-user } +macros_can_have_dependencies() { + bazel build //test/macros:macro-with-dependencies-user +} + $runner incorrect_macro_user_does_not_build $runner correct_macro_user_builds +$runner macros_can_have_dependencies