use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// A handle to a local pool, used for spawning `!Send` tasks.
#[derive(Clone)]
pub struct LocalPoolHandle {
    pool: Arc<LocalPool>,
}

impl LocalPoolHandle {
    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
    /// pool via [`LocalPoolHandle::spawn_pinned`].
    ///
    /// # Panics
    /// Panics if the pool size is less than one.
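    ///
    /// # Examples
    ///
    /// A minimal sketch: create a small pool and clone the handle. Cloning
    /// is cheap, since the handle just wraps an `Arc`.
    ///
    /// ```
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// let pool = LocalPoolHandle::new(4);
    /// let _clone = pool.clone();
    /// ```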
    pub fn new(pool_size: usize) -> LocalPoolHandle {
        assert!(pool_size > 0);

        let workers = (0..pool_size)
            .map(|_| LocalWorkerHandle::new_worker())
            .collect();

        let pool = Arc::new(LocalPool { workers });

        LocalPoolHandle { pool }
    }

    /// Spawn a task onto a worker thread and pin it there so it can't be moved
    /// off of the thread. Note that the future is not [`Send`], but the
    /// [`FnOnce`] which creates it is.
    ///
    /// # Examples
    /// ```
    /// use std::rc::Rc;
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     // Create the local pool
    ///     let pool = LocalPoolHandle::new(1);
    ///
    ///     // Spawn a !Send future onto the pool and await it
    ///     let output = pool
    ///         .spawn_pinned(|| {
    ///             // Rc is !Send + !Sync
    ///             let local_data = Rc::new("test");
    ///
    ///             // This future holds an Rc, so it is !Send
    ///             async move { local_data.to_string() }
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     assert_eq!(output, "test");
    /// }
    /// ```
    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool.spawn_pinned(create_task)
    }
}

impl Debug for LocalPoolHandle {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("LocalPoolHandle")
    }
}

struct LocalPool {
    workers: Vec<LocalWorkerHandle>,
}

impl LocalPool {
    /// Spawn a `!Send` future onto a worker
    fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();

        let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
        let worker_spawner = worker.spawner.clone();

        // Spawn a future onto the worker's runtime so we can immediately return
        // a join handle.
        worker.runtime_handle.spawn(async move {
            // Move the job guard into the task
            let _job_guard = job_guard;

            // Propagate aborts via Abortable/AbortHandle
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let _abort_guard = AbortGuard(abort_handle);

            // Inside the future we can't run spawn_local yet because we're not
            // in the context of a LocalSet. We need to send create_task to the
            // LocalSet task for spawning.
            let spawn_task = Box::new(move || {
                // Once we're in the LocalSet context we can call spawn_local
                let join_handle = spawn_local(async move {
                    Abortable::new(create_task(), abort_registration).await
                });

                // Send the join handle back to the spawner. If sending fails,
                // we assume the parent task was canceled, so cancel this task
                // as well.
                if let Err(join_handle) = sender.send(join_handle) {
                    join_handle.abort()
                }
            });

            // Send the callback to the LocalSet task
            if let Err(e) = worker_spawner.send(spawn_task) {
                // Propagate the error as a panic in the join handle.
                panic!("Failed to send job to worker: {}", e);
            }

            // Wait for the task's join handle
            let join_handle = match receiver.await {
                Ok(handle) => handle,
                Err(e) => {
                    // We sent the task successfully, but failed to get its
                    // join handle... We assume something happened to the worker
                    // and the task was not spawned. Propagate the error as a
                    // panic in the join handle.
                    panic!("Worker failed to send join handle: {}", e);
                }
            };

            // Wait for the task to complete
            let join_result = join_handle.await;

            match join_result {
                Ok(Ok(output)) => output,
                Ok(Err(_)) => {
                    // Pinned task was aborted. But that only happens if this
                    // task is aborted. So this is an impossible branch.
                    unreachable!(
                        "Reaching this branch means this task was previously \
                         aborted but it continued running anyway"
                    )
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else if e.is_cancelled() {
                        // No one else should have the join handle, so this is
                        // unexpected. Forward this error as a panic in the join
                        // handle.
                        panic!("spawn_pinned task was canceled: {}", e);
                    } else {
                        // Something unknown happened (not a panic or
                        // cancellation). Forward this error as a panic in the
                        // join handle.
                        panic!("spawn_pinned task failed: {}", e);
                    }
                }
            }
        })
    }

    /// Find the worker with the least number of tasks, increment its task
    /// count, and return its handle. Make sure to actually spawn a task on
    /// the worker so the task count is kept consistent with load.
    ///
    /// A job count guard is also returned to ensure the task count gets
    /// decremented when the job is done.
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                .expect("There must be at least one worker");

            // Make sure the task count hasn't changed since we chose this
            // worker. Otherwise, restart the search.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }
}

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Decrement the job count
        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(previous_value >= 1);
    }
}
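
// A small sketch of the guard's contract, not part of the original code:
// dropping a `JobCountGuard` decrements the shared counter exactly once.
#[cfg(test)]
mod job_count_guard_tests {
    use super::*;

    #[test]
    fn decrements_on_drop() {
        let count = Arc::new(AtomicUsize::new(1));
        let guard = JobCountGuard(Arc::clone(&count));
        drop(guard);
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }
}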

/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl Drop for AbortGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}

type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

struct LocalWorkerHandle {
    runtime_handle: tokio::runtime::Handle,
    spawner: UnboundedSender<PinnedFutureSpawner>,
    task_count: Arc<AtomicUsize>,
}

impl LocalWorkerHandle {
    /// Create a new worker for executing pinned tasks
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);

        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));

        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }

    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            while let Some(spawn_task) = task_receiver.recv().await {
                // Calls spawn_local(future)
                (spawn_task)();
            }
        });

        // If there are any tasks on the runtime associated with a LocalSet task
        // that has already completed, but whose output has not yet been
        // reported, let that task complete.
        //
        // Since the task_count is decremented when the runtime task exits,
        // reading that counter lets us know if any such tasks completed during
        // the call to `block_on`.
        //
        // Tasks on the LocalSet can't complete during this loop since they're
        // stored on the LocalSet and we aren't accessing it.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            // This call will also run tasks spawned on the runtime.
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }

        // It's now no longer possible for a task on the runtime to be
        // associated with a LocalSet task that has completed. Drop both the
        // LocalSet and runtime to let tasks on the runtime be cancelled if and
        // only if they are still on the LocalSet.
        //
        // Drop the LocalSet task first so that anyone awaiting the runtime
        // JoinHandle will see the cancelled error after the LocalSet task
        // destructor has completed.
        drop(local_set);
        drop(runtime);
    }
}
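
// A minimal end-to-end smoke test sketch for the pool above; it is not part
// of the original code and assumes tokio's `rt` and `macros` features are
// available as dev-dependencies (for `#[tokio::test]`).
#[cfg(test)]
mod spawn_pinned_smoke_test {
    use super::LocalPoolHandle;
    use std::rc::Rc;

    #[tokio::test]
    async fn spawns_and_awaits_a_not_send_future() {
        let pool = LocalPoolHandle::new(2);

        // The closure is Send, but the future it returns holds an Rc and is
        // therefore !Send; it runs pinned to one of the worker threads.
        let output = pool
            .spawn_pinned(|| {
                let local_data = Rc::new(7usize);
                async move { *local_data }
            })
            .await
            .unwrap();

        assert_eq!(output, 7);
    }
}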