Skip to content

Commit 139adce

Browse files
committed
Ensure async launch mutable borrow safety with barriers on use and stream move
1 parent c74b542 commit 139adce

File tree

20 files changed

+358
-172
lines changed

20 files changed

+358
-172
lines changed

examples/print/src/main.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
5555
);
5656

5757
// Create a new CUDA stream to submit kernels to
58-
let stream =
58+
let mut stream =
5959
rust_cuda::host::CudaDropWrapper::from(rust_cuda::deps::rustacuda::stream::Stream::new(
6060
rust_cuda::deps::rustacuda::stream::StreamFlags::NON_BLOCKING,
6161
None,
@@ -70,12 +70,14 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
7070
};
7171

7272
// Launch the CUDA kernel on the stream and synchronise to its completion
73-
println!("Launching print kernel ...");
74-
kernel.launch1(&stream, &config, Action::Print)?;
75-
println!("Launching panic kernel ...");
76-
kernel.launch1(&stream, &config, Action::Panic)?;
77-
println!("Launching alloc error kernel ...");
78-
kernel.launch1(&stream, &config, Action::AllocError)?;
73+
rust_cuda::host::Stream::with(&mut stream, |stream| {
74+
println!("Launching print kernel ...");
75+
kernel.launch1(stream, &config, Action::Print)?;
76+
println!("Launching panic kernel ...");
77+
kernel.launch1(stream, &config, Action::Panic)?;
78+
println!("Launching alloc error kernel ...");
79+
kernel.launch1(stream, &config, Action::AllocError)
80+
})?;
7981

8082
Ok(())
8183
}

rust-cuda-derive/src/rust_to_cuda/impl.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ pub fn rust_to_cuda_async_trait(
191191
unsafe fn borrow_async<'stream, CudaAllocType: #crate_path::alloc::CudaAlloc>(
192192
&self,
193193
alloc: CudaAllocType,
194-
stream: &'stream #crate_path::deps::rustacuda::stream::Stream,
194+
stream: &'stream #crate_path::host::Stream,
195195
) -> #crate_path::deps::rustacuda::error::CudaResult<(
196196
#crate_path::utils::r#async::Async<
197197
'_, 'stream,
@@ -219,7 +219,7 @@ pub fn rust_to_cuda_async_trait(
219219
alloc: #crate_path::alloc::CombinedCudaAlloc<
220220
Self::CudaAllocationAsync, CudaAllocType
221221
>,
222-
stream: &'stream #crate_path::deps::rustacuda::stream::Stream,
222+
stream: &'stream #crate_path::host::Stream,
223223
) -> #crate_path::deps::rustacuda::error::CudaResult<(
224224
#crate_path::utils::r#async::Async<
225225
'a, 'stream,

rust-cuda-kernel/src/kernel/wrapper/generate/cuda_generic_function.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ pub(in super::super) fn quote_cuda_generic_function(
8282
)
8383
.collect::<Vec<_>>();
8484

85+
let generic_start_token = generic_start_token.unwrap_or_default();
86+
let generic_close_token = generic_close_token.unwrap_or_default();
87+
8588
quote! {
8689
#[cfg(target_os = "cuda")]
8790
#(#func_attrs)*

src/host/mod.rs

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use rustacuda::{
1111
event::Event,
1212
memory::{CopyDestination, DeviceBox, DeviceBuffer, LockedBox, LockedBuffer},
1313
module::Module,
14-
stream::Stream,
1514
};
1615

1716
use crate::{
@@ -26,6 +25,33 @@ use crate::{
2625
},
2726
};
2827

28+
#[repr(transparent)]
29+
pub struct Stream {
30+
stream: rustacuda::stream::Stream,
31+
}
32+
33+
impl Deref for Stream {
34+
type Target = rustacuda::stream::Stream;
35+
36+
fn deref(&self) -> &Self::Target {
37+
&self.stream
38+
}
39+
}
40+
41+
impl Stream {
42+
pub fn with<O>(
43+
stream: &mut rustacuda::stream::Stream,
44+
inner: impl for<'stream> FnOnce(&'stream Self) -> O,
45+
) -> O {
46+
// Safety:
47+
// - Stream is a newtype wrapper around rustacuda::stream::Stream
48+
// - we forge a unique lifetime for a unique reference
49+
let stream = unsafe { &*std::ptr::from_ref(stream).cast() };
50+
51+
inner(stream)
52+
}
53+
}
54+
2955
pub trait CudaDroppable: Sized {
3056
#[allow(clippy::missing_errors_doc)]
3157
fn drop(val: Self) -> Result<(), (rustacuda::error::CudaError, Self)>;
@@ -88,7 +114,7 @@ impl<T: rustacuda_core::DeviceCopy> CudaDroppable for LockedBuffer<T> {
88114
}
89115

90116
macro_rules! impl_sealed_drop_value {
91-
($type:ident) => {
117+
($type:ty) => {
92118
impl CudaDroppable for $type {
93119
fn drop(val: Self) -> Result<(), (CudaError, Self)> {
94120
Self::drop(val)
@@ -98,7 +124,7 @@ macro_rules! impl_sealed_drop_value {
98124
}
99125

100126
impl_sealed_drop_value!(Module);
101-
impl_sealed_drop_value!(Stream);
127+
impl_sealed_drop_value!(rustacuda::stream::Stream);
102128
impl_sealed_drop_value!(Context);
103129
impl_sealed_drop_value!(Event);
104130

@@ -142,7 +168,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
142168
/// # Safety
143169
///
144170
/// `device_box` must contain EXACTLY the device copy of `host_ref`
145-
pub unsafe fn new_unchecked(
171+
pub(crate) unsafe fn new_unchecked(
146172
device_box: &'a mut DeviceBox<DeviceCopyWithPortableBitSemantics<T>>,
147173
host_ref: &'a mut T,
148174
) -> Self {
@@ -180,7 +206,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
180206
}
181207

182208
#[must_use]
183-
pub fn as_mut<'b>(&'b mut self) -> HostAndDeviceMutRef<'b, T>
209+
pub fn into_mut<'b>(self) -> HostAndDeviceMutRef<'b, T>
184210
where
185211
'a: 'b,
186212
{
@@ -191,20 +217,14 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
191217
}
192218

193219
#[must_use]
194-
pub fn as_async<'b, 'stream>(
195-
&'b mut self,
220+
pub fn into_async<'b, 'stream>(
221+
self,
196222
stream: &'stream Stream,
197223
) -> Async<'b, 'stream, HostAndDeviceMutRef<'b, T>, NoCompletion>
198224
where
199225
'a: 'b,
200226
{
201-
Async::ready(
202-
HostAndDeviceMutRef {
203-
device_box: self.device_box,
204-
host_ref: self.host_ref,
205-
},
206-
stream,
207-
)
227+
Async::ready(self.into_mut(), stream)
208228
}
209229
}
210230

@@ -253,7 +273,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceConstRef<'a, T>
253273
/// # Safety
254274
///
255275
/// `device_box` must contain EXACTLY the device copy of `host_ref`
256-
pub const unsafe fn new_unchecked(
276+
pub(crate) const unsafe fn new_unchecked(
257277
device_box: &'a DeviceBox<DeviceCopyWithPortableBitSemantics<T>>,
258278
host_ref: &'a T,
259279
) -> Self {

src/kernel/mod.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use rustacuda::{
1111
error::{CudaError, CudaResult},
1212
function::Function,
1313
module::Module,
14-
stream::Stream,
1514
};
1615

1716
#[cfg(feature = "kernel")]
@@ -27,6 +26,8 @@ mod ptx_jit;
2726
#[cfg(feature = "host")]
2827
use ptx_jit::{PtxJITCompiler, PtxJITResult};
2928

29+
#[cfg(feature = "host")]
30+
use crate::host::Stream;
3031
use crate::safety::PortableBitSemantics;
3132

3233
pub mod param;
@@ -109,7 +110,7 @@ pub trait CudaKernelParameter: sealed::Sealed {
109110
#[allow(clippy::missing_errors_doc)] // FIXME
110111
fn with_new_async<'stream, 'param, O, E: From<rustacuda::error::CudaError>>(
111112
param: Self::SyncHostType,
112-
stream: &'stream rustacuda::stream::Stream,
113+
stream: &'stream crate::host::Stream,
113114
inner: impl WithNewAsync<'stream, Self, O, E>,
114115
) -> Result<O, E>
115116
where
@@ -206,7 +207,9 @@ macro_rules! impl_launcher_launch {
206207
pub fn $launch_async<$($T: CudaKernelParameter),*>(
207208
&mut self,
208209
$($arg: $T::AsyncHostType<'stream, '_>),*
209-
) -> CudaResult<()>
210+
) -> CudaResult<crate::utils::r#async::Async<
211+
'static, 'stream, (), crate::utils::r#async::NoCompletion,
212+
>>
210213
where
211214
Kernel: FnOnce(&mut Launcher<'stream, '_, Kernel>, $($T),*),
212215
{
@@ -375,13 +378,10 @@ macro_rules! impl_typed_kernel_launch {
375378
config,
376379
$($arg,)*
377380
|kernel, stream, config, $($arg),*| {
378-
let result = kernel.$launch_async::<$($T),*>(stream, config, $($arg),*);
381+
let r#async = kernel.$launch_async::<$($T),*>(stream, config, $($arg),*)?;
379382

380383
// important: always synchronise here, this function is sync!
381-
match (stream.synchronize(), result) {
382-
(Ok(()), result) => result,
383-
(Err(_), Err(err)) | (Err(err), Ok(())) => Err(err),
384-
}
384+
r#async.synchronize()
385385
},
386386
)
387387
}
@@ -422,7 +422,29 @@ macro_rules! impl_typed_kernel_launch {
422422
stream: &'stream Stream,
423423
config: &LaunchConfig,
424424
$($arg: $T::AsyncHostType<'stream, '_>),*
425-
) -> CudaResult<()>
425+
) -> CudaResult<crate::utils::r#async::Async<
426+
'static, 'stream, (), crate::utils::r#async::NoCompletion,
427+
>>
428+
// launch_async does not need to capture its parameters until kernel completion:
429+
// - moved parameters are moved and cannot be used again, deallocation will sync
430+
// - immutably borrowed parameters can be shared across multiple kernel launches
431+
// - mutably borrowed parameters are more tricky:
432+
// - Rust's borrowing rules ensure that a single mutable reference cannot be
433+
// passed into multiple parameters of the kernel (no mutable aliasing)
434+
// - CUDA guarantees that kernels launched on the same stream are executed
435+
// sequentially, so even immediate resubmissions for the same mutable data
436+
// will not have temporally overlapping mutation on the same stream
437+
// - however, we have to guarantee that mutable data cannot be used on several
438+
// different streams at the same time
439+
// - Async::move_to_stream always adds a synchronisation barrier between the
440+
// old and the new stream to ensure that all uses on the old stream happen
441+
// strictly before all uses on the new stream
442+
// - async launches take AsyncProj<&mut HostAndDeviceMutRef<..>>, which either
443+
// captures an Async, which must be moved to a different stream explicitly,
444+
// or contains data that cannot async move to a different stream without
445+
// - any use of a mutable borrow in an async kernel launch adds a sync barrier
446+
// on the launch stream s.t. the borrow is only complete once the kernel has
447+
// completed
426448
where
427449
Kernel: FnOnce(&mut Launcher<'stream, 'kernel, Kernel>, $($T),*),
428450
{
@@ -454,7 +476,11 @@ macro_rules! impl_typed_kernel_launch {
454476
&mut $T::async_to_ffi($arg, sealed::Token)?
455477
).cast::<core::ffi::c_void>()),*
456478
],
457-
) }
479+
) }?;
480+
481+
crate::utils::r#async::Async::pending(
482+
(), stream, crate::utils::r#async::NoCompletion,
483+
)
458484
}
459485
};
460486
(impl $func:ident () + ($($other:expr),*) $inner:block) => {

0 commit comments

Comments
 (0)