Skip to content

Commit 383d574

Browse files
committed
future: spawn_blocking when waiting on cond var
1 parent 0202780 commit 383d574

File tree

2 files changed

+66
-26
lines changed

2 files changed

+66
-26
lines changed

scylla-rust-wrapper/src/future.rs

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -108,61 +108,96 @@ impl CassFuture {
108108
})
109109
}
110110

111-
pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
111+
pub fn with_waited_result<T>(
112+
self: &Arc<Self>,
113+
f: impl FnOnce(&mut CassFutureResult) -> T,
114+
) -> T {
112115
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
113116
}
114117

115-
fn with_waited_state<T>(&self, f: impl FnOnce(&mut CassFutureState) -> T) -> T {
118+
fn with_waited_state<T>(self: &Arc<Self>, f: impl FnOnce(&mut CassFutureState) -> T) -> T {
116119
let mut guard = self.state.lock().unwrap();
117120
let handle = guard.join_handle.take();
121+
mem::drop(guard);
122+
118123
if let Some(handle) = handle {
119-
mem::drop(guard);
120124
// unwrap: JoinError appears only when future either panic'ed or canceled.
121125
RUNTIME.block_on(handle).unwrap();
122126
guard = self.state.lock().unwrap();
123127
} else {
124-
guard = self
125-
.wait_for_value
126-
.wait_while(guard, |state| state.value.is_none())
127-
// unwrap: Error appears only when mutex is poisoned.
128-
.unwrap();
128+
// We need to spawn a thread that will be responsible for waiting
129+
// on a cond variable (this operation blocks the thread).
130+
// If we don't do that, this won't work for `current-thread` tokio runtime..
131+
let self_clone = Arc::clone(self);
132+
let wait_for_value = async {
133+
RUNTIME
134+
.spawn_blocking(move || {
135+
let mut guard = self_clone.state.lock().unwrap();
136+
guard = self_clone
137+
.wait_for_value
138+
.wait_while(guard, |state| state.value.is_none())
139+
// unwrap: Error appears only when mutex is poisoned.
140+
.unwrap();
141+
mem::drop(guard);
142+
})
143+
.await
144+
};
145+
RUNTIME.block_on(wait_for_value).unwrap();
146+
guard = self.state.lock().unwrap();
129147
}
130148
f(&mut guard)
131149
}
132150

133151
fn with_waited_result_timed<T>(
134-
&self,
152+
self: &Arc<Self>,
135153
f: impl FnOnce(&mut CassFutureResult) -> T,
136154
timeout_duration: Duration,
137155
) -> Result<T, TimeoutError> {
138156
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
139157
}
140158

141159
pub(self) fn with_waited_state_timed<T>(
142-
&self,
160+
self: &Arc<Self>,
143161
f: impl FnOnce(&mut CassFutureState) -> T,
144162
timeout_duration: Duration,
145163
) -> Result<T, TimeoutError> {
146164
let mut guard = self.state.lock().unwrap();
147165
let handle = guard.join_handle.take();
166+
mem::drop(guard);
167+
148168
if let Some(handle) = handle {
149-
mem::drop(guard);
150169
// Need to wrap it with async{} block, so the timeout is lazily executed inside the runtime.
151170
// See mention about panics: https://docs.rs/tokio/latest/tokio/time/fn.timeout.html.
152171
let timed = async { tokio::time::timeout(timeout_duration, handle).await };
153172
// unwrap: JoinError appears only when future either panic'ed or canceled.
154173
RUNTIME.block_on(timed).map_err(|_| TimeoutError)?.unwrap();
155174
guard = self.state.lock().unwrap();
156175
} else {
157-
let (guard_result, timeout_result) = self
158-
.wait_for_value
159-
.wait_timeout_while(guard, timeout_duration, |state| state.value.is_none())
160-
// unwrap: Error appears only when mutex is poisoned.
161-
.unwrap();
176+
// We need to spawn a thread that will be responsible for waiting
177+
// on a cond variable (this operation blocks the thread).
178+
// If we don't do that, this won't work for `current-thread` tokio runtime..
179+
let self_clone = Arc::clone(self);
180+
let wait_for_value = async {
181+
RUNTIME
182+
.spawn_blocking(move || {
183+
let guard = self_clone.state.lock().unwrap();
184+
let (guard_result, timeout_result) = self_clone
185+
.wait_for_value
186+
.wait_timeout_while(guard, timeout_duration, |state| {
187+
state.value.is_none()
188+
})
189+
// unwrap: Error appears only when mutex is poisoned.
190+
.unwrap();
191+
mem::drop(guard_result);
192+
timeout_result
193+
})
194+
.await
195+
};
196+
let timeout_result = RUNTIME.block_on(wait_for_value).unwrap();
162197
if timeout_result.timed_out() {
163198
return Err(TimeoutError);
164199
}
165-
guard = guard_result;
200+
guard = self.state.lock().unwrap();
166201
}
167202

168203
Ok(f(&mut guard))
@@ -208,15 +243,15 @@ pub unsafe extern "C" fn cass_future_set_callback(
208243

209244
#[no_mangle]
210245
pub unsafe extern "C" fn cass_future_wait(future_raw: *const CassFuture) {
211-
ptr_to_ref(future_raw).with_waited_result(|_| ());
246+
clone_arced(future_raw).with_waited_result(|_| ());
212247
}
213248

214249
#[no_mangle]
215250
pub unsafe extern "C" fn cass_future_wait_timed(
216251
future_raw: *const CassFuture,
217252
timeout_us: cass_duration_t,
218253
) -> cass_bool_t {
219-
ptr_to_ref(future_raw)
254+
clone_arced(future_raw)
220255
.with_waited_result_timed(|_| (), Duration::from_micros(timeout_us))
221256
.is_ok() as cass_bool_t
222257
}
@@ -232,7 +267,7 @@ pub unsafe extern "C" fn cass_future_ready(future_raw: *const CassFuture) -> cas
232267

233268
#[no_mangle]
234269
pub unsafe extern "C" fn cass_future_error_code(future_raw: *const CassFuture) -> CassError {
235-
ptr_to_ref(future_raw).with_waited_result(|r: &mut CassFutureResult| match r {
270+
clone_arced(future_raw).with_waited_result(|r: &mut CassFutureResult| match r {
236271
Ok(CassResultValue::QueryError(err)) => CassError::from(err.as_ref()),
237272
Err((err, _)) => *err,
238273
_ => CassError::CASS_OK,
@@ -245,7 +280,7 @@ pub unsafe extern "C" fn cass_future_error_message(
245280
message: *mut *const ::std::os::raw::c_char,
246281
message_length: *mut size_t,
247282
) {
248-
ptr_to_ref(future).with_waited_state(|state: &mut CassFutureState| {
283+
clone_arced(future).with_waited_state(|state: &mut CassFutureState| {
249284
let value = &state.value;
250285
let msg = state
251286
.err_string
@@ -267,7 +302,7 @@ pub unsafe extern "C" fn cass_future_free(future_raw: *const CassFuture) {
267302
pub unsafe extern "C" fn cass_future_get_result(
268303
future_raw: *const CassFuture,
269304
) -> *const CassResult {
270-
ptr_to_ref(future_raw)
305+
clone_arced(future_raw)
271306
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassResult>> {
272307
match r.as_ref().ok()? {
273308
CassResultValue::QueryResult(qr) => Some(qr.clone()),
@@ -281,7 +316,7 @@ pub unsafe extern "C" fn cass_future_get_result(
281316
pub unsafe extern "C" fn cass_future_get_error_result(
282317
future_raw: *const CassFuture,
283318
) -> *const CassErrorResult {
284-
ptr_to_ref(future_raw)
319+
clone_arced(future_raw)
285320
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassErrorResult>> {
286321
match r.as_ref().ok()? {
287322
CassResultValue::QueryError(qr) => Some(qr.clone()),
@@ -295,7 +330,7 @@ pub unsafe extern "C" fn cass_future_get_error_result(
295330
pub unsafe extern "C" fn cass_future_get_prepared(
296331
future_raw: *mut CassFuture,
297332
) -> *const CassPrepared {
298-
ptr_to_ref(future_raw)
333+
clone_arced(future_raw)
299334
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassPrepared>> {
300335
match r.as_ref().ok()? {
301336
CassResultValue::Prepared(p) => Some(p.clone()),
@@ -310,7 +345,7 @@ pub unsafe extern "C" fn cass_future_tracing_id(
310345
future: *const CassFuture,
311346
tracing_id: *mut CassUuid,
312347
) -> CassError {
313-
ptr_to_ref(future).with_waited_result(|r: &mut CassFutureResult| match r {
348+
clone_arced(future).with_waited_result(|r: &mut CassFutureResult| match r {
314349
Ok(CassResultValue::QueryResult(result)) => match result.metadata.tracing_id {
315350
Some(id) => {
316351
*tracing_id = CassUuid::from(id);

scylla-rust-wrapper/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ pub mod uuid;
3636
pub mod value;
3737

3838
lazy_static! {
39-
pub static ref RUNTIME: Runtime = Runtime::new().unwrap();
39+
// TODO: Revert to multi-thread. It's just for the sake of review.
40+
// To check that tests pass with current-thread runtime.
41+
pub static ref RUNTIME: Runtime = tokio::runtime::Builder::new_current_thread()
42+
.enable_all()
43+
.build()
44+
.unwrap();
4045
pub static ref LOGGER: RwLock<Logger> = RwLock::new(Logger {
4146
cb: Some(stderr_log_callback),
4247
data: std::ptr::null_mut(),

0 commit comments

Comments
 (0)