Skip to content
This repository was archived by the owner on Nov 5, 2018. It is now read-only.

Commit 0b7f95c

Browse files
jeehoonkangStjepan Glavina
authored andcommitted
Add join for scoped thread API (#7)
* Add join for scoped thread API * Add documentation (@Vtec234's comments) * Remove Scope::join()
1 parent daed549 commit 0b7f95c

1 file changed

Lines changed: 79 additions & 62 deletions

File tree

src/scoped.rs

Lines changed: 79 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -109,22 +109,20 @@
109109

110110
use std::cell::RefCell;
111111
use std::fmt;
112+
use std::marker::PhantomData;
112113
use std::mem;
114+
use std::ops::DerefMut;
113115
use std::rc::Rc;
114-
use std::sync::atomic::Ordering;
115-
use std::sync::Arc;
116116
use std::thread;
117117
use std::io;
118118

119-
use atomic_option::AtomicOption;
120-
121119
#[doc(hidden)]
122-
trait FnBox {
123-
fn call_box(self: Box<Self>);
120+
trait FnBox<T> {
121+
fn call_box(self: Box<Self>) -> T;
124122
}
125123

126-
impl<F: FnOnce()> FnBox for F {
127-
fn call_box(self: Box<Self>) {
124+
impl<T, F: FnOnce() -> T> FnBox<T> for F {
125+
fn call_box(self: Box<Self>) -> T {
128126
(*self)()
129127
}
130128
}
@@ -146,47 +144,61 @@ pub unsafe fn builder_spawn_unsafe<'a, F>(
146144
where
147145
F: FnOnce() + Send + 'a,
148146
{
149-
use std::mem;
150-
151-
let closure: Box<FnBox + 'a> = Box::new(f);
152-
let closure: Box<FnBox + Send> = mem::transmute(closure);
147+
let closure: Box<FnBox<()> + 'a> = Box::new(f);
148+
let closure: Box<FnBox<()> + Send> = mem::transmute(closure);
153149
builder.spawn(move || closure.call_box())
154150
}
155151

156-
157152
pub struct Scope<'a> {
158-
dtors: RefCell<Option<DtorChain<'a>>>,
153+
/// The list of the deferred functions and thread join jobs.
154+
dtors: RefCell<Option<DtorChain<'a, ()>>>,
155+
// !Send + !Sync
156+
_marker: PhantomData<*const ()>,
159157
}
160158

161-
struct DtorChain<'a> {
162-
dtor: Box<FnBox + 'a>,
163-
next: Option<Box<DtorChain<'a>>>,
159+
struct DtorChain<'a, T> {
160+
dtor: Box<FnBox<T> + 'a>,
161+
next: Option<Box<DtorChain<'a, T>>>,
164162
}
165163

166-
enum JoinState {
167-
Running(thread::JoinHandle<()>),
168-
Joined,
164+
impl<'a, T> DtorChain<'a, T> {
165+
pub fn pop(chain: &mut Option<DtorChain<'a, T>>) -> Option<Box<FnBox<T> + 'a>> {
166+
chain.take().map(|mut node| {
167+
*chain = node.next.take().map(|b| *b);
168+
node.dtor
169+
})
170+
}
169171
}
170172

171-
impl JoinState {
172-
fn join(&mut self) {
173-
let mut state = JoinState::Joined;
174-
mem::swap(self, &mut state);
175-
if let JoinState::Running(handle) = state {
176-
let res = handle.join();
173+
struct JoinState<T> {
174+
join_handle: thread::JoinHandle<()>,
175+
result: usize,
176+
_marker: PhantomData<T>,
177+
}
177178

178-
if !thread::panicking() {
179-
res.unwrap();
180-
}
179+
impl<T: Send> JoinState<T> {
180+
fn new(join_handle: thread::JoinHandle<()>, result: usize) -> JoinState<T> {
181+
JoinState {
182+
join_handle: join_handle,
183+
result: result,
184+
_marker: PhantomData,
181185
}
182186
}
187+
188+
fn join(self) -> thread::Result<T> {
189+
let result = self.result;
190+
self.join_handle.join().map(|_| {
191+
unsafe { *Box::from_raw(result as *mut T) }
192+
})
193+
}
183194
}
184195

185196
/// A handle to a scoped thread
186-
pub struct ScopedJoinHandle<T> {
187-
inner: Rc<RefCell<JoinState>>,
188-
packet: Arc<AtomicOption<T>>,
197+
pub struct ScopedJoinHandle<'a, T: 'a> {
198+
// !Send + !Sync
199+
inner: Rc<RefCell<Option<JoinState<T>>>>,
189200
thread: thread::Thread,
201+
_marker: PhantomData<&'a T>,
190202
}
191203

192204
/// Create a new `scope`, for deferred destructors.
@@ -204,11 +216,18 @@ pub struct ScopedJoinHandle<T> {
204216
/// });
205217
/// // Prints messages in the reverse order written
206218
/// ```
219+
///
220+
/// # Panics
221+
///
222+
/// `scoped::scope()` panics if a spawned thread panics but it is not joined inside the scope.
207223
pub fn scope<'a, F, R>(f: F) -> R
208224
where
209225
F: FnOnce(&Scope<'a>) -> R,
210226
{
211-
let mut scope = Scope { dtors: RefCell::new(None) };
227+
let mut scope = Scope {
228+
dtors: RefCell::new(None),
229+
_marker: PhantomData,
230+
};
212231
let ret = f(&scope);
213232
scope.drop_all();
214233
ret
@@ -220,7 +239,7 @@ impl<'a> fmt::Debug for Scope<'a> {
220239
}
221240
}
222241

223-
impl<T> fmt::Debug for ScopedJoinHandle<T> {
242+
impl<'a, T> fmt::Debug for ScopedJoinHandle<'a, T> {
224243
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
225244
write!(f, "ScopedJoinHandle {{ ... }}")
226245
}
@@ -233,26 +252,16 @@ impl<'a> Scope<'a> {
233252
// method outside of any destructor, we avoid any leakage problems
234253
// due to @rust-lang/rust#14875.
235254
fn drop_all(&mut self) {
236-
loop {
237-
// use a separate scope to ensure that the RefCell borrow
238-
// is relinquished before running `dtor`
239-
let dtor = {
240-
let mut dtors = self.dtors.borrow_mut();
241-
if let Some(mut node) = dtors.take() {
242-
*dtors = node.next.take().map(|b| *b);
243-
node.dtor
244-
} else {
245-
return;
246-
}
247-
};
248-
dtor.call_box()
255+
while let Some(dtor) = DtorChain::pop(&mut self.dtors.borrow_mut()) {
256+
dtor.call_box();
249257
}
250258
}
251259

252260
/// Schedule code to be executed when exiting the scope.
253261
///
254262
/// This is akin to having a destructor on the stack, except that it is
255-
/// *guaranteed* to be run.
263+
/// *guaranteed* to be run. It is guaranteed that the function is called
264+
/// after all the spawned threads are joined.
256265
pub fn defer<F>(&self, f: F)
257266
where
258267
F: FnOnce() + 'a,
@@ -273,8 +282,9 @@ impl<'a> Scope<'a> {
273282
/// scope exits.
274283
///
275284
/// [spawn]: http://doc.rust-lang.org/std/thread/fn.spawn.html
276-
pub fn spawn<F, T>(&self, f: F) -> ScopedJoinHandle<T>
285+
pub fn spawn<'s, F, T>(&'s self, f: F) -> ScopedJoinHandle<'a, T>
277286
where
287+
'a: 's,
278288
F: FnOnce() -> T + Send + 'a,
279289
T: Send + 'a,
280290
{
@@ -313,42 +323,49 @@ impl<'s, 'a: 's> ScopedThreadBuilder<'s, 'a> {
313323
}
314324

315325
/// Spawns a new thread, and returns a join handle for it.
316-
pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<T>>
326+
pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<'a, T>>
317327
where
318328
F: FnOnce() -> T + Send + 'a,
319329
T: Send + 'a,
320330
{
321-
let their_packet = Arc::new(AtomicOption::new());
322-
let my_packet = their_packet.clone();
331+
// The `Box` constructed below is written only by the spawned thread,
332+
// and read by the current thread only after the spawned thread is
333+
// joined (`JoinState::join()`). Thus there are no data races.
334+
let result = Box::into_raw(Box::<T>::new(unsafe { mem::uninitialized() })) as usize;
323335

324336
let join_handle = try!(unsafe {
325337
builder_spawn_unsafe(self.builder, move || {
326-
their_packet.swap(f(), Ordering::Relaxed);
338+
let mut result = Box::from_raw(result as *mut T);
339+
*result = f();
340+
mem::forget(result);
327341
})
328342
});
329-
330343
let thread = join_handle.thread().clone();
331-
let deferred_handle = Rc::new(RefCell::new(JoinState::Running(join_handle)));
344+
345+
let join_state = JoinState::<T>::new(join_handle, result);
346+
let deferred_handle = Rc::new(RefCell::new(Some(join_state)));
332347
let my_handle = deferred_handle.clone();
333348

334349
self.scope.defer(move || {
335-
let mut state = deferred_handle.borrow_mut();
336-
state.join();
350+
let state = mem::replace(deferred_handle.borrow_mut().deref_mut(), None);
351+
if let Some(state) = state {
352+
state.join().unwrap();
353+
}
337354
});
338355

339356
Ok(ScopedJoinHandle {
340357
inner: my_handle,
341-
packet: my_packet,
342358
thread: thread,
359+
_marker: PhantomData,
343360
})
344361
}
345362
}
346363

347-
impl<T> ScopedJoinHandle<T> {
364+
impl<'a, T: Send + 'a> ScopedJoinHandle<'a, T> {
348365
/// Join the scoped thread, returning the result it produced.
349-
pub fn join(self) -> T {
350-
self.inner.borrow_mut().join();
351-
self.packet.take(Ordering::Relaxed).unwrap()
366+
pub fn join(self) -> thread::Result<T> {
367+
let state = mem::replace(self.inner.borrow_mut().deref_mut(), None);
368+
state.unwrap().join()
352369
}
353370

354371
/// Get the underlying thread handle.

0 commit comments

Comments
 (0)