Skip to content

Commit 257053e

Browse files
authored
util: add spawn_pinned (#3370)
1 parent 5af9e0d commit 257053e

File tree

5 files changed

+506
-1
lines changed

5 files changed

+506
-1
lines changed

tokio-util/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ codec = []
3030
time = ["tokio/time","slab"]
3131
io = []
3232
io-util = ["io", "tokio/rt", "tokio/io-util"]
33-
rt = ["tokio/rt"]
33+
rt = ["tokio/rt", "tokio/sync", "futures-util"]
3434

3535
__docs_rs = ["futures-util"]
3636

tokio-util/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ cfg_io! {
4343

4444
cfg_rt! {
4545
pub mod context;
46+
pub mod task;
4647
}
4748

4849
cfg_time! {

tokio-util/src/task/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
//! Extra utilities for spawning tasks
2+
3+
mod spawn_pinned;
4+
pub use spawn_pinned::LocalPoolHandle;
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
use futures_util::future::{AbortHandle, Abortable};
2+
use std::fmt;
3+
use std::fmt::{Debug, Formatter};
4+
use std::future::Future;
5+
use std::sync::atomic::{AtomicUsize, Ordering};
6+
use std::sync::Arc;
7+
use tokio::runtime::Builder;
8+
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
9+
use tokio::sync::oneshot;
10+
use tokio::task::{spawn_local, JoinHandle, LocalSet};
11+
12+
/// A handle to a local pool, used for spawning `!Send` tasks.
13+
#[derive(Clone)]
14+
pub struct LocalPoolHandle {
15+
pool: Arc<LocalPool>,
16+
}
17+
18+
impl LocalPoolHandle {
19+
/// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
20+
/// pool via [`LocalPoolHandle::spawn_pinned`].
21+
///
22+
/// # Panics
23+
/// Panics if the pool size is less than one.
24+
pub fn new(pool_size: usize) -> LocalPoolHandle {
25+
assert!(pool_size > 0);
26+
27+
let workers = (0..pool_size)
28+
.map(|_| LocalWorkerHandle::new_worker())
29+
.collect();
30+
31+
let pool = Arc::new(LocalPool { workers });
32+
33+
LocalPoolHandle { pool }
34+
}
35+
36+
/// Spawn a task onto a worker thread and pin it there so it can't be moved
37+
/// off of the thread. Note that the future is not [`Send`], but the
38+
/// [`FnOnce`] which creates it is.
39+
///
40+
/// # Examples
41+
/// ```
42+
/// use std::rc::Rc;
43+
/// use tokio_util::task::LocalPoolHandle;
44+
///
45+
/// #[tokio::main]
46+
/// async fn main() {
47+
/// // Create the local pool
48+
/// let pool = LocalPoolHandle::new(1);
49+
///
50+
/// // Spawn a !Send future onto the pool and await it
51+
/// let output = pool
52+
/// .spawn_pinned(|| {
53+
/// // Rc is !Send + !Sync
54+
/// let local_data = Rc::new("test");
55+
///
56+
/// // This future holds an Rc, so it is !Send
57+
/// async move { local_data.to_string() }
58+
/// })
59+
/// .await
60+
/// .unwrap();
61+
///
62+
/// assert_eq!(output, "test");
63+
/// }
64+
/// ```
65+
pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
66+
where
67+
F: FnOnce() -> Fut,
68+
F: Send + 'static,
69+
Fut: Future + 'static,
70+
Fut::Output: Send + 'static,
71+
{
72+
self.pool.spawn_pinned(create_task)
73+
}
74+
}
75+
76+
impl Debug for LocalPoolHandle {
77+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
78+
f.write_str("LocalPoolHandle")
79+
}
80+
}
81+
82+
struct LocalPool {
83+
workers: Vec<LocalWorkerHandle>,
84+
}
85+
86+
impl LocalPool {
87+
/// Spawn a `?Send` future onto a worker
88+
fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
89+
where
90+
F: FnOnce() -> Fut,
91+
F: Send + 'static,
92+
Fut: Future + 'static,
93+
Fut::Output: Send + 'static,
94+
{
95+
let (sender, receiver) = oneshot::channel();
96+
97+
let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
98+
let worker_spawner = worker.spawner.clone();
99+
100+
// Spawn a future onto the worker's runtime so we can immediately return
101+
// a join handle.
102+
worker.runtime_handle.spawn(async move {
103+
// Move the job guard into the task
104+
let _job_guard = job_guard;
105+
106+
// Propagate aborts via Abortable/AbortHandle
107+
let (abort_handle, abort_registration) = AbortHandle::new_pair();
108+
let _abort_guard = AbortGuard(abort_handle);
109+
110+
// Inside the future we can't run spawn_local yet because we're not
111+
// in the context of a LocalSet. We need to send create_task to the
112+
// LocalSet task for spawning.
113+
let spawn_task = Box::new(move || {
114+
// Once we're in the LocalSet context we can call spawn_local
115+
let join_handle =
116+
spawn_local(
117+
async move { Abortable::new(create_task(), abort_registration).await },
118+
);
119+
120+
// Send the join handle back to the spawner. If sending fails,
121+
// we assume the parent task was canceled, so cancel this task
122+
// as well.
123+
if let Err(join_handle) = sender.send(join_handle) {
124+
join_handle.abort()
125+
}
126+
});
127+
128+
// Send the callback to the LocalSet task
129+
if let Err(e) = worker_spawner.send(spawn_task) {
130+
// Propagate the error as a panic in the join handle.
131+
panic!("Failed to send job to worker: {}", e);
132+
}
133+
134+
// Wait for the task's join handle
135+
let join_handle = match receiver.await {
136+
Ok(handle) => handle,
137+
Err(e) => {
138+
// We sent the task successfully, but failed to get its
139+
// join handle... We assume something happened to the worker
140+
// and the task was not spawned. Propagate the error as a
141+
// panic in the join handle.
142+
panic!("Worker failed to send join handle: {}", e);
143+
}
144+
};
145+
146+
// Wait for the task to complete
147+
let join_result = join_handle.await;
148+
149+
match join_result {
150+
Ok(Ok(output)) => output,
151+
Ok(Err(_)) => {
152+
// Pinned task was aborted. But that only happens if this
153+
// task is aborted. So this is an impossible branch.
154+
unreachable!(
155+
"Reaching this branch means this task was previously \
156+
aborted but it continued running anyways"
157+
)
158+
}
159+
Err(e) => {
160+
if e.is_panic() {
161+
std::panic::resume_unwind(e.into_panic());
162+
} else if e.is_cancelled() {
163+
// No one else should have the join handle, so this is
164+
// unexpected. Forward this error as a panic in the join
165+
// handle.
166+
panic!("spawn_pinned task was canceled: {}", e);
167+
} else {
168+
// Something unknown happened (not a panic or
169+
// cancellation). Forward this error as a panic in the
170+
// join handle.
171+
panic!("spawn_pinned task failed: {}", e);
172+
}
173+
}
174+
}
175+
})
176+
}
177+
178+
/// Find the worker with the least number of tasks, increment its task
179+
/// count, and return its handle. Make sure to actually spawn a task on
180+
/// the worker so the task count is kept consistent with load.
181+
///
182+
/// A job count guard is also returned to ensure the task count gets
183+
/// decremented when the job is done.
184+
fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
185+
loop {
186+
let (worker, task_count) = self
187+
.workers
188+
.iter()
189+
.map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
190+
.min_by_key(|&(_, count)| count)
191+
.expect("There must be more than one worker");
192+
193+
// Make sure the task count hasn't changed since when we choose this
194+
// worker. Otherwise, restart the search.
195+
if worker
196+
.task_count
197+
.compare_exchange(
198+
task_count,
199+
task_count + 1,
200+
Ordering::SeqCst,
201+
Ordering::Relaxed,
202+
)
203+
.is_ok()
204+
{
205+
return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
206+
}
207+
}
208+
}
209+
}
210+
211+
/// Automatically decrements a worker's job count when a job finishes (when
212+
/// this gets dropped).
213+
struct JobCountGuard(Arc<AtomicUsize>);
214+
215+
impl Drop for JobCountGuard {
216+
fn drop(&mut self) {
217+
// Decrement the job count
218+
let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
219+
debug_assert!(previous_value >= 1);
220+
}
221+
}
222+
223+
/// Calls abort on the handle when dropped.
224+
struct AbortGuard(AbortHandle);
225+
226+
impl Drop for AbortGuard {
227+
fn drop(&mut self) {
228+
self.0.abort();
229+
}
230+
}
231+
232+
type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
233+
234+
struct LocalWorkerHandle {
235+
runtime_handle: tokio::runtime::Handle,
236+
spawner: UnboundedSender<PinnedFutureSpawner>,
237+
task_count: Arc<AtomicUsize>,
238+
}
239+
240+
impl LocalWorkerHandle {
241+
/// Create a new worker for executing pinned tasks
242+
fn new_worker() -> LocalWorkerHandle {
243+
let (sender, receiver) = unbounded_channel();
244+
let runtime = Builder::new_current_thread()
245+
.enable_all()
246+
.build()
247+
.expect("Failed to start a pinned worker thread runtime");
248+
let runtime_handle = runtime.handle().clone();
249+
let task_count = Arc::new(AtomicUsize::new(0));
250+
let task_count_clone = Arc::clone(&task_count);
251+
252+
std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
253+
254+
LocalWorkerHandle {
255+
runtime_handle,
256+
spawner: sender,
257+
task_count,
258+
}
259+
}
260+
261+
fn run(
262+
runtime: tokio::runtime::Runtime,
263+
mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
264+
task_count: Arc<AtomicUsize>,
265+
) {
266+
let local_set = LocalSet::new();
267+
local_set.block_on(&runtime, async {
268+
while let Some(spawn_task) = task_receiver.recv().await {
269+
// Calls spawn_local(future)
270+
(spawn_task)();
271+
}
272+
});
273+
274+
// If there are any tasks on the runtime associated with a LocalSet task
275+
// that has already completed, but whose output has not yet been
276+
// reported, let that task complete.
277+
//
278+
// Since the task_count is decremented when the runtime task exits,
279+
// reading that counter lets us know if any such tasks completed during
280+
// the call to `block_on`.
281+
//
282+
// Tasks on the LocalSet can't complete during this loop since they're
283+
// stored on the LocalSet and we aren't accessing it.
284+
let mut previous_task_count = task_count.load(Ordering::SeqCst);
285+
loop {
286+
// This call will also run tasks spawned on the runtime.
287+
runtime.block_on(tokio::task::yield_now());
288+
let new_task_count = task_count.load(Ordering::SeqCst);
289+
if new_task_count == previous_task_count {
290+
break;
291+
} else {
292+
previous_task_count = new_task_count;
293+
}
294+
}
295+
296+
// It's now no longer possible for a task on the runtime to be
297+
// associated with a LocalSet task that has completed. Drop both the
298+
// LocalSet and runtime to let tasks on the runtime be cancelled if and
299+
// only if they are still on the LocalSet.
300+
//
301+
// Drop the LocalSet task first so that anyone awaiting the runtime
302+
// JoinHandle will see the cancelled error after the LocalSet task
303+
// destructor has completed.
304+
drop(local_set);
305+
drop(runtime);
306+
}
307+
}

0 commit comments

Comments
 (0)