Skip to content

refactor(test): execute all #[rustup_macros::unit_test]s within a tokio context #3868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 41 additions & 66 deletions rustup-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ pub fn integration_test(
.into()
}

/// Custom wrapper macro around `#[test]` and `#[tokio::test]` for unit tests.
/// Custom wrapper macro around `#[tokio::test]` for unit tests.
///
/// Calls `rustup::test::before_test()` before the test body, and
/// `rustup::test::after_test()` after, even in the event of an unwinding panic.
/// For async functions calls the async variants of these functions.
///
/// This wrapper makes the underlying test function async even if it's sync in nature.
/// This ensures that a [`tokio`] runtime is always present during tests,
/// making it easier to setup [`tracing`] subscribers
/// (e.g. [`opentelemetry_otlp::OtlpTracePipeline`] always requires a [`tokio`] runtime to be
/// installed).
#[proc_macro_attribute]
pub fn unit_test(
args: proc_macro::TokenStream,
Expand Down Expand Up @@ -77,74 +82,44 @@ pub fn unit_test(
.into()
}

// False positive from clippy :/
#[allow(clippy::redundant_clone)]
fn test_inner(mod_path: String, mut input: ItemFn) -> syn::Result<TokenStream> {
if input.sig.asyncness.is_some() {
let before_ident = format!("{}::before_test_async", mod_path);
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
let after_ident = format!("{}::after_test_async", mod_path);
let after_ident = syn::parse_str::<Expr>(&after_ident)?;

let inner = input.block;
let name = input.sig.ident.clone();
let new_block: Block = parse_quote! {
{
#before_ident().await;
// Define a function with same name we can instrument inside the
// tracing enablement logic.
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
async fn #name() { #inner }
// Thunk through a new thread to permit catching the panic
// without grabbing the entire state machine defined by the
// outer test function.
let result = ::std::panic::catch_unwind(||{
let handle = tokio::runtime::Handle::current().clone();
::std::thread::spawn(move || handle.block_on(#name())).join().unwrap()
});
#after_ident().await;
match result {
Ok(result) => result,
Err(err) => ::std::panic::resume_unwind(err)
}
}
};
// Make the test function async even if it's sync.
input.sig.asyncness.get_or_insert_with(Default::default);

input.block = Box::new(new_block);
let before_ident = format!("{}::before_test_async", mod_path);
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
let after_ident = format!("{}::after_test_async", mod_path);
let after_ident = syn::parse_str::<Expr>(&after_ident)?;

Ok(quote! {
let inner = input.block;
let name = input.sig.ident.clone();
let new_block: Block = parse_quote! {
{
let _guard = #before_ident().await;
// Define a function with same name we can instrument inside the
// tracing enablement logic.
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
#[::tokio::test(flavor = "multi_thread", worker_threads = 1)]
#input
})
} else {
let before_ident = format!("{}::before_test", mod_path);
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
let after_ident = format!("{}::after_test", mod_path);
let after_ident = syn::parse_str::<Expr>(&after_ident)?;

let inner = input.block;
let name = input.sig.ident.clone();
let new_block: Block = parse_quote! {
{
#before_ident();
// Define a function with same name we can instrument inside the
// tracing enablement logic.
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
fn #name() { #inner }
let result = ::std::panic::catch_unwind(#name);
#after_ident();
match result {
Ok(result) => result,
Err(err) => ::std::panic::resume_unwind(err)
}
async fn #name() { #inner }
// Thunk through a new thread to permit catching the panic
// without grabbing the entire state machine defined by the
// outer test function.
let result = ::std::panic::catch_unwind(||{
let handle = tokio::runtime::Handle::current().clone();
::std::thread::spawn(move || handle.block_on(#name())).join().unwrap()
});
#after_ident().await;
match result {
Ok(result) => result,
Err(err) => ::std::panic::resume_unwind(err)
}
};
}
};

input.block = Box::new(new_block);
Ok(quote! {
#[::std::prelude::v1::test]
#input
})
}
input.block = Box::new(new_block);

Ok(quote! {
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
#[::tokio::test(flavor = "multi_thread", worker_threads = 1)]
#input
})
}
37 changes: 10 additions & 27 deletions src/bin/rustup-init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,17 @@ async fn maybe_trace_rustup() -> Result<utils::ExitCode> {
}
#[cfg(feature = "otel")]
{
use std::time::Duration;

use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{
propagation::TraceContextPropagator,
trace::{self, Sampler},
Resource,
use tracing_subscriber::{layer::SubscriberExt, Registry};

let telemetry = {
use opentelemetry::global;
use opentelemetry_sdk::propagation::TraceContextPropagator;

global::set_text_map_propagator(TraceContextPropagator::new());
rustup::cli::log::telemetry()
};
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};

global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_timeout(Duration::from_secs(3)),
)
.with_trace_config(
trace::config()
.with_sampler(Sampler::AlwaysOn)
.with_resource(Resource::new(vec![KeyValue::new("service.name", "rustup")])),
)
.install_batch(opentelemetry_sdk::runtime::Tokio)?;
let env_filter = EnvFilter::try_from_default_env().unwrap_or(EnvFilter::new("INFO"));
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let subscriber = Registry::default().with(env_filter).with(telemetry);

let subscriber = Registry::default().with(telemetry);
tracing::subscriber::set_global_default(subscriber)?;
let result = run_rustup().await;
// We're tracing, so block until all spans are exported.
Expand Down
57 changes: 55 additions & 2 deletions src/cli/log.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
use std::fmt;
use std::io::Write;
use std::{fmt, io::Write};

#[cfg(feature = "otel")]
use once_cell::sync::Lazy;
#[cfg(feature = "otel")]
use opentelemetry_sdk::trace::Tracer;
#[cfg(feature = "otel")]
use tracing::Subscriber;
#[cfg(feature = "otel")]
use tracing_subscriber::{registry::LookupSpan, EnvFilter, Layer};

use crate::currentprocess::{process, terminalsource};

Expand Down Expand Up @@ -71,3 +79,48 @@ pub(crate) fn debug_fmt(args: fmt::Arguments<'_>) {
let _ = writeln!(t.lock());
}
}

/// A [`tracing::Subscriber`] [`Layer`][`tracing_subscriber::Layer`] that corresponds to Rustup's
/// optional `opentelemetry` (a.k.a. `otel`) feature.
#[cfg(feature = "otel")]
pub fn telemetry<S>() -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
// NOTE: This reads from the real environment variables instead of `process().var_os()`.
let env_filter = EnvFilter::try_from_default_env().unwrap_or(EnvFilter::new("INFO"));
tracing_opentelemetry::layer()
.with_tracer(TELEMETRY_DEFAULT_TRACER.clone())
.with_filter(env_filter)
}

/// The default `opentelemetry` tracer used across Rustup.
///
/// # Note
/// The initializer function will panic if not called within the context of a [`tokio`] runtime.
#[cfg(feature = "otel")]
static TELEMETRY_DEFAULT_TRACER: Lazy<Tracer> = Lazy::new(|| {
use std::time::Duration;

use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{
trace::{self, Sampler},
Resource,
};

opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_timeout(Duration::from_secs(3)),
)
.with_trace_config(
trace::config()
.with_sampler(Sampler::AlwaysOn)
.with_resource(Resource::new(vec![KeyValue::new("service.name", "rustup")])),
)
.install_batch(opentelemetry_sdk::runtime::Tokio)
.expect("error installing `OtlpTracePipeline` in the current `tokio` runtime")
});
85 changes: 16 additions & 69 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,85 +224,32 @@ where
f(&rustup_home)
}

#[cfg(feature = "otel")]
use once_cell::sync::Lazy;

/// A tokio runtime for the sync tests, permitting the use of tracing. This is
/// never shutdown, instead it is just dropped at end of process.
#[cfg(feature = "otel")]
static TRACE_RUNTIME: Lazy<tokio::runtime::Runtime> =
Lazy::new(|| tokio::runtime::Runtime::new().unwrap());
/// A tracer for the tests.
#[cfg(feature = "otel")]
static TRACER: Lazy<opentelemetry_sdk::trace::Tracer> = Lazy::new(|| {
use std::time::Duration;

use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{
propagation::TraceContextPropagator,
trace::{self, Sampler},
Resource,
};
use tokio::runtime::Handle;
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};

// Use the current runtime, or the sync test runtime otherwise.
let handle = match Handle::try_current() {
Ok(handle) => handle,
Err(_) => TRACE_RUNTIME.handle().clone(),
};
let _guard = handle.enter();

let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_timeout(Duration::from_secs(3)),
)
.with_trace_config(
trace::config()
.with_sampler(Sampler::AlwaysOn)
.with_resource(Resource::new(vec![KeyValue::new("service.name", "rustup")])),
)
.install_batch(opentelemetry_sdk::runtime::Tokio)
.unwrap();

global::set_text_map_propagator(TraceContextPropagator::new());
let env_filter = EnvFilter::try_from_default_env().unwrap_or(EnvFilter::new("INFO"));
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer.clone());
let subscriber = Registry::default().with(env_filter).with(telemetry);
tracing::subscriber::set_global_default(subscriber).unwrap();
tracer
});

pub fn before_test() {
pub async fn before_test_async() -> Option<tracing::dispatcher::DefaultGuard> {
#[cfg(feature = "otel")]
{
Lazy::force(&TRACER);
}
}
use tracing_subscriber::{layer::SubscriberExt, Registry};

pub async fn before_test_async() {
#[cfg(feature = "otel")]
{
Lazy::force(&TRACER);
}
}
let telemetry = {
use opentelemetry::global;
use opentelemetry_sdk::propagation::TraceContextPropagator;

pub fn after_test() {
#[cfg(feature = "otel")]
global::set_text_map_propagator(TraceContextPropagator::new());
crate::cli::log::telemetry()
};

let subscriber = Registry::default().with(telemetry);
Some(tracing::subscriber::set_default(subscriber))
}
#[cfg(not(feature = "otel"))]
{
let handle = TRACE_RUNTIME.handle();
let _guard = handle.enter();
TRACER.provider().map(|p| p.force_flush());
None
}
}

pub async fn after_test_async() {
#[cfg(feature = "otel")]
{
TRACER.provider().map(|p| p.force_flush());
// We're tracing, so block until all spans are exported.
opentelemetry::global::shutdown_tracer_provider();
}
}
Loading
Loading