diff --git a/library/core/src/error.rs b/library/core/src/error.rs index 9dbea57fa1f86..49fd3bacf77c8 100644 --- a/library/core/src/error.rs +++ b/library/core/src/error.rs @@ -650,6 +650,68 @@ impl<'a> Request<'a> { self.provide_with::>>(fulfil) } + /// Provides a mutable reference. The referee type must be bounded by `'static`, + /// but may be unsized. + /// + /// # Examples + /// + /// Provides a mutable reference to a field as a `&mut str`. + /// + /// ```rust + /// #![feature(context_provider)] + /// + /// use core::error::Request; + /// + /// #[derive(Debug)] + /// struct SomeConcreteType { field: String } + /// + /// impl std::task::Provider for SomeConcreteType { + /// fn provide_mut<'a>(&'a mut self, request: &mut Request<'a>) { + /// request.provide_mut::(&mut self.field); + /// } + /// } + /// ``` + #[unstable(feature = "context_provider", issue = "none")] + pub fn provide_mut(&mut self, value: &'a mut T) -> &mut Self { + self.provide::>>(value) + } + + /// Provides a mutable reference computed using a closure. The referee type + /// must be bounded by `'static`, but may be unsized. + /// + /// # Examples + /// + /// Provides a reference to a field as a `&mut str`. + /// + /// ```rust + /// #![feature(context_provider)] + /// + /// use core::error::Request; + /// + /// #[derive(Debug)] + /// struct SomeConcreteType { business: String, party: String } + /// fn today_is_a_weekday() -> bool { true } + /// + /// impl std::task::Provider for SomeConcreteType { + /// fn provide_mut<'a>(&'a mut self, request: &mut Request<'a>) { + /// request.provide_mut_with::(|| { + /// if today_is_a_weekday() { + /// &mut self.business + /// } else { + /// &mut self.party + /// } + /// }); + /// } + /// } + /// ``` + #[unstable(feature = "context_provider", issue = "none")] + pub fn provide_mut_with( + &mut self, + fulfil: impl FnOnce() -> &'a mut T, + ) -> &mut Self { + self.provide_with::>>(fulfil) + } + /// Provides a value with the given `Type` tag. fn provide(&mut self, value: I::Reified) -> &mut Self where @@ -922,6 +984,15 @@ pub(crate) mod tags { impl<'a, I: MaybeSizedType<'a>> Type<'a> for Ref { type Reified = &'a I::Reified; } + + /// Type-based tag for mutable reference types (`&'a mut T`, where T is represented by + /// `>::Reified`. + #[derive(Debug)] + pub(crate) struct RefMut(PhantomData); + + impl<'a, I: MaybeSizedType<'a>> Type<'a> for RefMut { + type Reified = &'a mut I::Reified; + } } /// An `Option` with a type tag `I`. @@ -948,9 +1019,9 @@ unsafe trait Erased<'a>: 'a {} unsafe impl<'a, I: tags::Type<'a>> Erased<'a> for TaggedOption<'a, I> {} -struct Tagged { - tag_id: TypeId, - value: E, +pub(crate) struct Tagged { + pub tag_id: TypeId, + pub value: E, } impl<'a> Tagged + 'a> { diff --git a/library/core/src/task/mod.rs b/library/core/src/task/mod.rs index f1a789e32a7a7..a48e610102840 100644 --- a/library/core/src/task/mod.rs +++ b/library/core/src/task/mod.rs @@ -8,7 +8,9 @@ pub use self::poll::Poll; mod wake; #[stable(feature = "futures_api", since = "1.36.0")] -pub use self::wake::{Context, ContextBuilder, LocalWaker, RawWaker, RawWakerVTable, Waker}; +pub use self::wake::{ + Context, ContextBuilder, LocalWaker, Provider, RawWaker, RawWakerVTable, Waker, +}; mod ready; #[stable(feature = "ready_macro", since = "1.64.0")] diff --git a/library/core/src/task/wake.rs b/library/core/src/task/wake.rs index ba429005fab3d..defe40d9371ff 100644 --- a/library/core/src/task/wake.rs +++ b/library/core/src/task/wake.rs @@ -1,6 +1,7 @@ #![stable(feature = "futures_api", since = "1.36.0")] -use crate::any::Any; +use crate::any::{Any, TypeId}; +use crate::error::{Request, Tagged, TaggedOption, tags}; use crate::marker::PhantomData; use crate::mem::{ManuallyDrop, transmute}; use crate::panic::AssertUnwindSafe; @@ -211,6 +212,155 @@ enum ExtData<'a> { None(()), } +/// `Provider` is a trait that allows querying for arbitrary context data. +#[unstable(feature = "context_provider", issue = "none")] +pub trait Provider: fmt::Debug { + /// Provides type-based access to additional context data. + /// + /// Used in conjunction with [`Request::provide_value`] and [`Request::provide_ref`] to extract + /// references to member variables from `dyn Provider` trait objects. + /// + /// # Example + /// + /// ```rust + /// #![feature(error_generic_member_access)] + /// #![feature(context_provider)] + /// #![feature(noop_waker)] + /// use core::fmt; + /// + /// #[derive(Debug)] + /// enum MyLittleTeaPot { + /// Empty, + /// } + /// + /// #[derive(Debug)] + /// struct MyExtensionData { + /// // ... + /// } + /// + /// impl MyExtensionData { + /// fn new() -> MyExtensionData { + /// // ... + /// # MyExtensionData {} + /// } + /// } + /// + /// #[derive(Debug)] + /// struct MyProvider { + /// ext: MyExtensionData, + /// } + /// + /// impl std::task::Provider for MyProvider { + /// fn provide<'a>(&'a self, request: &mut Request<'a>) { + /// request + /// .provide_ref::(&self.ext); + /// } + /// } + /// + /// fn main() { + /// let ext = MyExtensionData::new(); + /// let mut provider = MyProvider { ext }; + /// let ext_ref_orig = &provider.ext as *const MyExtensionData; + /// let dyn_provider = &mut provider as &mut dyn std::task::Provider; + /// let cx = std::task::ContextBuilder::from_waker(Waker::noop()).provider(dyn_provider).build(); + /// let ext_ref = cx.request_ref::().unwrap(); + /// + /// assert!(core::ptr::eq(ext_ref_orig, ext_ref)); + /// assert!(cx.request_ref::().is_none()); + /// } + /// ``` + #[allow(unused_variables)] + fn provide<'a>(&'a self, request: &mut Request<'a>) {} + + /// Provides type-based access to additional mutable context data. + /// + /// Used in conjunction with [`Request::provide_mut`] to extract + /// mutable references to member variables from `dyn Provider` trait objects. + /// + /// # Example + /// + /// ```rust + /// #![feature(error_generic_member_access)] + /// #![feature(context_provider)] + /// #![feature(noop_waker)] + /// use core::fmt; + /// + /// #[derive(Debug)] + /// enum MyLittleTeaPot { + /// Empty, + /// } + /// + /// #[derive(Debug)] + /// struct MyExtensionData { + /// // ... + /// } + /// + /// impl MyExtensionData { + /// fn new() -> MyExtensionData { + /// // ... + /// # MyExtensionData {} + /// } + /// } + /// + /// #[derive(Debug)] + /// struct MyProvider { + /// ext: MyExtensionData, + /// } + /// + /// impl std::task::Provider for MyProvider { + /// fn provide_mut<'a>(&'a mut self, request: &mut Request<'a>) { + /// request + /// .provide_mut::(&mut self.ext); + /// } + /// } + /// + /// fn main() { + /// let ext = MyExtensionData::new(); + /// let mut provider = MyProvider { ext }; + /// let ext_ref_orig = &mut provider.ext as *mut MyExtensionData; + /// let dyn_provider = &mut provider as &mut dyn std::task::Provider; + /// let cx = std::task::ContextBuilder::from_waker(Waker::noop()).provider(dyn_provider).build(); + /// let ext_ref = cx.request_mut::().unwrap(); + /// + /// assert!(core::ptr::eq(ext_ref_orig, ext_ref)); + /// assert!(cx.request_mut::().is_none()); + /// } + /// ``` + #[allow(unused_variables)] + fn provide_mut<'a>(&'a mut self, request: &mut Request<'a>) {} +} + +#[unstable(feature = "context_provider", issue = "none")] +impl Provider for &mut (dyn Provider + 'static) { + fn provide<'a>(&'a self, request: &mut Request<'a>) { + Provider::provide(*self, request) + } + + fn provide_mut<'a>(&'a mut self, request: &mut Request<'a>) { + Provider::provide_mut(*self, request) + } +} + +/// Request a specific value by tag from the `Provider`. +fn request_by_type_tag<'a, I>(provider: &'a (impl Provider + ?Sized)) -> Option +where + I: tags::Type<'a>, +{ + let mut tagged = Tagged { tag_id: TypeId::of::(), value: TaggedOption::<'a, I>(None) }; + provider.provide(tagged.as_request()); + tagged.value.0 +} + +/// Request a specific value by tag from the `Provider`. +fn request_by_type_tag_mut<'a, I>(provider: &'a mut (impl Provider + ?Sized)) -> Option +where + I: tags::Type<'a>, +{ + let mut tagged = Tagged { tag_id: TypeId::of::(), value: TaggedOption::<'a, I>(None) }; + provider.provide_mut(tagged.as_request()); + tagged.value.0 +} + /// The context of an asynchronous task. /// /// Currently, `Context` only serves to provide access to a [`&Waker`](Waker) @@ -220,6 +370,8 @@ enum ExtData<'a> { pub struct Context<'a> { waker: &'a Waker, local_waker: &'a LocalWaker, + provider: Option<&'a mut (dyn Provider + 'static)>, + parent_provider: Option<&'a mut (dyn Provider + 'static)>, ext: AssertUnwindSafe>, // Ensure we future-proof against variance changes by forcing // the lifetime to be invariant (argument-position lifetimes @@ -257,6 +409,96 @@ impl<'a> Context<'a> { &self.local_waker } + /// Requests a value of type `T` from the context's provider, if available. + /// + /// # Examples + /// + /// Get a string value from a context. + /// + /// ```rust + /// #![feature(context_provider)] + /// use core::task::Context; + /// + /// fn get_string(cx: &Context) -> String { + /// cx.request_value::().unwrap() + /// } + /// ``` + #[unstable(feature = "context_provider", issue = "none")] + pub fn request_value(&self) -> Option + where + T: 'static, + { + let Some(provider) = &self.provider else { return None }; + + request_by_type_tag::>(provider).or_else(|| { + if let Some(parent) = &self.parent_provider { + request_by_type_tag::>(parent) + } else { + None + } + }) + } + + /// Requests a reference of type `T` from the context's provider, if available. + /// + /// # Examples + /// + /// Get a string reference from a context. + /// + /// ```rust + /// #![feature(context_provider)] + /// use core::task::Context; + /// + /// fn get_str(cx: &Context) -> &str { + /// cx.request_ref::().unwrap() + /// } + /// ``` + #[unstable(feature = "context_provider", issue = "none")] + pub fn request_ref(&self) -> Option<&T> + where + T: 'static + ?Sized, + { + let Some(provider) = &self.provider else { return None }; + + request_by_type_tag::>>(provider).or_else(|| { + if let Some(parent) = &self.parent_provider { + request_by_type_tag::>>(parent) + } else { + None + } + }) + } + + /// Requests a mutable reference of type `T` from the context's provider, if available. + /// + /// # Examples + /// + /// Get a mutable string reference from a context. + /// + /// ```rust + /// #![feature(context_provider)] + /// use core::task::Context; + /// + /// fn get_str(cx: &mut Context) -> &mut str { + /// cx.request_mut::().unwrap() + /// } + /// ``` + #[unstable(feature = "context_provider", issue = "none")] + pub fn request_mut(&mut self) -> Option<&mut T> + where + T: 'static + ?Sized, + { + let Some(provider) = &mut self.provider else { return None }; + + request_by_type_tag_mut::>>(provider).or_else(|| { + if let Some(parent) = &mut self.parent_provider { + request_by_type_tag_mut::>>(parent) + } else { + None + } + }) + } + /// Returns a reference to the extension data for the current task. #[inline] #[unstable(feature = "context_ext", issue = "123392")] @@ -303,6 +545,8 @@ impl fmt::Debug for Context<'_> { pub struct ContextBuilder<'a> { waker: &'a Waker, local_waker: &'a LocalWaker, + provider: Option<&'a mut (dyn Provider + 'static)>, + parent_provider: Option<&'a mut (dyn Provider + 'static)>, ext: ExtData<'a>, // Ensure we future-proof against variance changes by forcing // the lifetime to be invariant (argument-position lifetimes @@ -324,6 +568,8 @@ impl<'a> ContextBuilder<'a> { Self { waker, local_waker, + provider: None, + parent_provider: None, ext: ExtData::None(()), _marker: PhantomData, _marker2: PhantomData, @@ -333,7 +579,7 @@ impl<'a> ContextBuilder<'a> { /// Creates a ContextBuilder from an existing Context. #[inline] #[unstable(feature = "context_ext", issue = "123392")] - pub const fn from(cx: &'a mut Context<'_>) -> Self { + pub fn from(cx: &'a mut Context<'_>) -> Self { let ext = match &mut cx.ext.0 { ExtData::Some(ext) => ExtData::Some(*ext), ExtData::None(()) => ExtData::None(()), @@ -341,6 +587,8 @@ impl<'a> ContextBuilder<'a> { Self { waker: cx.waker, local_waker: cx.local_waker, + provider: None, + parent_provider: cx.provider.as_deref_mut(), ext, _marker: PhantomData, _marker2: PhantomData, @@ -361,6 +609,13 @@ impl<'a> ContextBuilder<'a> { Self { local_waker, ..self } } + /// Sets the value for the provider on `Context`. + #[inline] + #[unstable(feature = "context_provider", issue = "none")] + pub const fn provider(self, provider: &'a mut (dyn Provider + 'static)) -> Self { + Self { provider: Some(provider), ..self } + } + /// Sets the value for the extension data on `Context`. #[inline] #[unstable(feature = "context_ext", issue = "123392")] @@ -372,8 +627,24 @@ impl<'a> ContextBuilder<'a> { #[inline] #[unstable(feature = "local_waker", issue = "118959")] pub const fn build(self) -> Context<'a> { - let ContextBuilder { waker, local_waker, ext, _marker, _marker2 } = self; - Context { waker, local_waker, ext: AssertUnwindSafe(ext), _marker, _marker2 } + let ContextBuilder { + waker, + local_waker, + provider, + parent_provider, + ext, + _marker, + _marker2, + } = self; + Context { + waker, + local_waker, + provider, + parent_provider, + ext: AssertUnwindSafe(ext), + _marker, + _marker2, + } } } diff --git a/library/core/tests/lib.rs b/library/core/tests/lib.rs index 18feee9fb2545..0e2943ace0a2a 100644 --- a/library/core/tests/lib.rs +++ b/library/core/tests/lib.rs @@ -17,6 +17,7 @@ #![feature(const_eval_select)] #![feature(const_swap_nonoverlapping)] #![feature(const_trait_impl)] +#![feature(context_provider)] #![feature(core_intrinsics)] #![feature(core_intrinsics_fallbacks)] #![feature(core_io_borrowed_buf)] @@ -57,6 +58,7 @@ #![feature(iterator_try_reduce)] #![feature(layout_for_ptr)] #![feature(lazy_get)] +#![feature(local_waker)] #![feature(maybe_uninit_fill)] #![feature(maybe_uninit_uninit_array_transpose)] #![feature(maybe_uninit_write_slice)] diff --git a/library/core/tests/waker.rs b/library/core/tests/waker.rs index 4889b8959ece4..17b2dbe3511cf 100644 --- a/library/core/tests/waker.rs +++ b/library/core/tests/waker.rs @@ -1,5 +1,6 @@ +use std::error::Request; use std::ptr; -use std::task::{RawWaker, RawWakerVTable, Waker}; +use std::task::{ContextBuilder, Provider, RawWaker, RawWakerVTable, Waker}; #[test] fn test_waker_getters() { @@ -13,6 +14,40 @@ fn test_waker_getters() { assert!(ptr::eq(waker2.vtable(), &WAKER_VTABLE)); } +// Test the `Request` API. +#[derive(Debug)] +struct SomeConcreteType { + some_string: String, +} + +impl Provider for SomeConcreteType { + fn provide<'a>(&'a self, request: &mut Request<'a>) { + request + .provide_ref::(&self.some_string) + .provide_ref::(&self.some_string) + .provide_value_with::(|| "bye".to_owned()); + } + + fn provide_mut<'a>(&'a mut self, request: &mut Request<'a>) { + request.provide_mut::(&mut self.some_string); + } +} + +#[test] +fn test_context_provider() { + let obj = &mut SomeConcreteType { some_string: "hello".to_owned() }; + let builder = ContextBuilder::from_waker(Waker::noop()).provider(obj); + let mut cx = builder.build(); + + assert_eq!(cx.request_ref::().unwrap(), "hello"); + assert_eq!(cx.request_value::().unwrap(), "bye"); + assert_eq!(cx.request_value::(), None); + + cx.request_mut::().unwrap().push_str(" world"); + + assert_eq!(obj.some_string, "hello world"); +} + static WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( |data| RawWaker::new(ptr::without_provenance_mut(data as usize + 1), &WAKER_VTABLE), |_| {},