26 changes: 26 additions & 0 deletions rayon-core/src/lib.rs
@@ -200,6 +200,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
/// Closure invoked on worker thread exit.
exit_handler: Option<Box<ExitHandler>>,

/// Affects the blocking/work-stealing behavior when using nested thread pools.
full_blocking: bool,

/// Closure invoked to spawn threads.
spawn_handler: S,

@@ -245,6 +248,7 @@ impl Default for ThreadPoolBuilder {
exit_handler: None,
spawn_handler: DefaultSpawn,
breadth_first: false,
full_blocking: false,
}
}
}
@@ -455,6 +459,7 @@ impl<S> ThreadPoolBuilder<S> {
start_handler: self.start_handler,
exit_handler: self.exit_handler,
breadth_first: self.breadth_first,
full_blocking: self.full_blocking,
}
}

@@ -672,6 +677,25 @@ impl<S> ThreadPoolBuilder<S> {
self.exit_handler = Some(Box::new(exit_handler));
self
}

/// Changes the blocking behavior of nested thread pools.
///
/// By default, when a job is created on this thread pool by a job running in a separate
/// thread pool, the parent thread is allowed to start executing other jobs in the parent
/// thread pool while it waits for the nested work to finish.
///
/// With full blocking enabled, the parent thread instead blocks until the jobs it created
/// in this thread pool have completed. This is useful for avoiding deadlock when the
/// parent job holds a mutex that other jobs in its pool may also try to acquire.
///
/// Full blocking is disabled by default.
pub fn full_blocking(mut self) -> Self {
self.full_blocking = true;
self
}

fn get_full_blocking(&self) -> bool {
self.full_blocking
}
}
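
For orientation, here is a minimal usage sketch assuming only the builder API added above; the pool size and the closure body are illustrative, not part of this PR:

fn main() {
    // Build a pool whose cross-pool callers block until their injected jobs finish.
    let blocking_pool = rayon::ThreadPoolBuilder::new()
        .full_blocking()
        .num_threads(2)
        .build()
        .expect("failed to build thread pool");

    // A job that calls into `blocking_pool` from another pool now parks until its
    // nested work completes, instead of picking up other jobs from its own pool.
    blocking_pool.install(|| {
        // ... nested parallel work ...
    });
}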

#[allow(deprecated)]
@@ -811,6 +835,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
ref exit_handler,
spawn_handler: _,
ref breadth_first,
ref full_blocking,
} = *self;

// Just print `Some(<closure>)` or `None` to the debug
@@ -835,6 +860,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
.field("start_handler", &start_handler)
.field("exit_handler", &exit_handler)
.field("breadth_first", &breadth_first)
.field("full_blocking", &full_blocking)
.finish()
}
}
32 changes: 31 additions & 1 deletion rayon-core/src/registry.rs
@@ -135,6 +135,7 @@ pub(super) struct Registry {
panic_handler: Option<Box<PanicHandler>>,
start_handler: Option<Box<StartHandler>>,
exit_handler: Option<Box<ExitHandler>>,
full_blocking: bool,

// When this latch reaches 0, it means that all work on this
// registry must be complete. This is ensured in the following ways:
@@ -267,6 +268,7 @@ impl Registry {
panic_handler: builder.take_panic_handler(),
start_handler: builder.take_start_handler(),
exit_handler: builder.take_exit_handler(),
full_blocking: builder.get_full_blocking(),
});

// If we return early or panic, make sure to terminate existing threads.
@@ -493,7 +495,11 @@ impl Registry {
if worker_thread.is_null() {
self.in_worker_cold(op)
} else if (*worker_thread).registry().id() != self.id() {
self.in_worker_cross(&*worker_thread, op)
if self.full_blocking {
self.in_worker_cross_blocking(op)
} else {
self.in_worker_cross(&*worker_thread, op)
}
} else {
// Perfectly valid to give them a `&T`: this is the
// current thread, so we know the data structure won't be
@@ -552,6 +558,30 @@
job.into_result()
}

#[cold]
unsafe fn in_worker_cross_blocking<OP, R>(&self, op: OP) -> R
where
OP: FnOnce(&WorkerThread, bool) -> R + Send,
R: Send,
{
thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

LOCK_LATCH.with(|l| {
let job = StackJob::new(
|injected| {
let worker_thread = WorkerThread::current();
assert!(injected && !worker_thread.is_null());
op(&*worker_thread, true)
},
LatchRef::new(l),
);
self.inject(job.as_job_ref());
job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.

job.into_result()
})
}
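
For intuition, the `LockLatch` used here is essentially a resettable flag guarded by a mutex and condition variable: `inject` hands the job to the pool's workers, and `wait_and_reset` parks the calling thread until a worker sets the latch, then clears it so the thread-local latch can be reused. A simplified standalone analog (an illustration only, not rayon's actual `LockLatch`) could look like this:

use std::sync::{Condvar, Mutex};

// Resettable latch: `set` marks it, `wait_and_reset` blocks until it is set
// and then clears it so the same latch can be reused for the next job.
struct SimpleLockLatch {
    state: Mutex<bool>,
    cond: Condvar,
}

impl SimpleLockLatch {
    fn new() -> Self {
        SimpleLockLatch { state: Mutex::new(false), cond: Condvar::new() }
    }

    fn set(&self) {
        *self.state.lock().unwrap() = true;
        self.cond.notify_all();
    }

    fn wait_and_reset(&self) {
        let mut set = self.state.lock().unwrap();
        while !*set {
            set = self.cond.wait(set).unwrap();
        }
        *set = false;
    }
}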

/// Increments the terminate counter. This increment should be
/// balanced by a call to `terminate`, which will decrement. This
/// is used when spawning asynchronous work, which needs to
45 changes: 45 additions & 0 deletions rayon-core/src/thread_pool/test.rs
@@ -3,6 +3,8 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

use crate::{join, Scope, ScopeFifo, ThreadPool, ThreadPoolBuilder};

@@ -416,3 +418,46 @@ fn yield_local_to_spawn() {
// for it to finish if a different thread stole it first.
assert_eq!(22, rx.recv().unwrap());
}

#[test]
fn nested_thread_pools_deadlock() {
let global_pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
// The lock thread pool must be full_blocking for this test to pass.
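// Without it, the single global-pool thread would steal another spawned job while it waits
// on `lock_pool.scope`; that job can never acquire the mutex (it is still held further up
// the same thread's stack), so it spins until the 2-second timeout and the test fails.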
let lock_pool = Arc::new(
ThreadPoolBuilder::new()
.full_blocking()
.num_threads(1)
.build()
.unwrap(),
);
let mutex = Arc::new(Mutex::new(()));
let start_time = Instant::now();

global_pool.scope(|s| {
for i in 0..5 {
let mutex = mutex.clone();
let lock_pool = lock_pool.clone();
// Create 5 jobs that each try to acquire the lock.
// If a job cannot acquire the lock within 2 seconds, a deadlock has occurred.
s.spawn(move |_| {
let mut acquired = false;
while start_time.elapsed() < Duration::from_secs(2) {
if let Ok(_guard) = mutex.try_lock() {
println!("Thread {i} acquired the mutex");
lock_pool.scope(|lock_s| {
lock_s.spawn(|_| {
thread::sleep(Duration::from_millis(100));
});
});
acquired = true;
break;
}
thread::sleep(Duration::from_millis(10));
}
if !acquired {
panic!("Thread {i} failed to acquire the mutex within 2 seconds.");
}
});
}
});
}
25 changes: 25 additions & 0 deletions tests/issue592.rs
@@ -0,0 +1,25 @@
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rayon::ThreadPoolBuilder;
use std::sync::{Arc, Mutex};

fn mutex_and_par(mutex: Arc<Mutex<Vec<i32>>>, blocking_pool: &rayon::ThreadPool) {
// Lock the mutex and collect items using the full blocking thread pool
let vec = mutex.lock().unwrap();
let result: Vec<i32> = blocking_pool.install(|| vec.par_iter().cloned().collect());
println!("{:?}", result);
}

#[test]
fn test_issue592() {
let collection = vec![1, 2, 3, 4, 5];
let mutex = Arc::new(Mutex::new(collection));

let blocking_pool = ThreadPoolBuilder::new().full_blocking().num_threads(4).build().unwrap();

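// Without `full_blocking`, a global-pool worker that holds the mutex and calls `install` on
// `blocking_pool` may pick up another `for_each` job below while it waits for the nested work;
// that job then blocks on the same mutex, which is held further up the same thread's stack,
// and the test deadlocks.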
let dummy_collection: Vec<i32> = (1..=100).collect();
dummy_collection.par_iter().for_each(|_| {
mutex_and_par(mutex.clone(), &blocking_pool);
});
}