Skip to content

Commit a202c03

Browse files
committed
Add (conservative) capture bound on async kernel launch
1 parent c74b542 commit a202c03

File tree

5 files changed

+153
-75
lines changed

5 files changed

+153
-75
lines changed

examples/print/src/main.rs

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@ pub enum Action {
2323

2424
#[rust_cuda::kernel::kernel(use link! for impl)]
2525
#[kernel(allow(ptx::local_memory_usage))]
26-
pub fn kernel(action: rust_cuda::kernel::param::PerThreadShallowCopy<Action>) {
26+
pub fn kernel<'a>(
27+
action: rust_cuda::kernel::param::PerThreadShallowCopy<Action>,
28+
_unused: &mut rust_cuda::kernel::param::DeepPerThreadBorrow<
29+
rust_cuda::utils::aliasing::SplitSliceOverCudaThreadsConstStride<
30+
rust_cuda::utils::exchange::buffer::CudaExchangeBuffer<u8, true, true>,
31+
1,
32+
>,
33+
>,
34+
) {
2735
match action {
2836
Action::Print => rust_cuda::device::utils::println!("println! from CUDA kernel"),
2937
Action::Panic => panic!("panic! from CUDA kernel"),
@@ -36,8 +44,10 @@ pub fn kernel(action: rust_cuda::kernel::param::PerThreadShallowCopy<Action>) {
3644
#[cfg(not(target_os = "cuda"))]
3745
fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
3846
// Link the non-generic CUDA kernel
39-
struct KernelPtx;
40-
link! { impl kernel for KernelPtx }
47+
struct KernelPtx<'a> {
48+
_marker: &'a [u8],
49+
}
50+
link! { impl kernel<'a> for KernelPtx }
4151

4252
// Initialize the CUDA API
4353
rust_cuda::deps::rustacuda::init(rust_cuda::deps::rustacuda::CudaFlags::empty())?;
@@ -69,13 +79,57 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
6979
ptx_jit: false,
7080
};
7181

72-
// Launch the CUDA kernel on the stream and synchronise to its completion
73-
println!("Launching print kernel ...");
74-
kernel.launch1(&stream, &config, Action::Print)?;
75-
println!("Launching panic kernel ...");
76-
kernel.launch1(&stream, &config, Action::Panic)?;
77-
println!("Launching alloc error kernel ...");
78-
kernel.launch1(&stream, &config, Action::AllocError)?;
82+
let mut slice = rust_cuda::utils::aliasing::SplitSliceOverCudaThreadsConstStride::<_, 1>::new(
83+
rust_cuda::utils::exchange::buffer::CudaExchangeBuffer::<_, true, true>::from_vec(vec![
84+
1_u8, 2, 3,
85+
])?,
86+
);
87+
88+
rust_cuda::lend::LendToCuda::lend_to_cuda_mut(&mut slice, |mut slice| {
89+
// let mut slice_async = slice.as_async(&stream);
90+
91+
// Launch the CUDA kernel on the stream and synchronise to its completion
92+
93+
println!("Launching print kernel ...");
94+
{
95+
let mut slice_async = slice.as_async(&stream);
96+
let slice_async_mut = slice_async.proj_mut();
97+
98+
let capture = rust_cuda::kernel::Capture;
99+
let r#async =
100+
kernel.launch2_async(&stream, &config, &capture, Action::Print, slice_async_mut)?;
101+
r#async.synchronize()?;
102+
}
103+
104+
println!("Launching panic kernel ...");
105+
{
106+
let mut slice_async = slice.as_async(&stream);
107+
let slice_async_mut = slice_async.proj_mut();
108+
109+
let capture = rust_cuda::kernel::Capture;
110+
let r#async =
111+
kernel.launch2_async(&stream, &config, &capture, Action::Panic, slice_async_mut)?;
112+
r#async.synchronize()?;
113+
}
114+
115+
println!("Launching alloc error kernel ...");
116+
{
117+
let mut slice_async = slice.as_async(&stream);
118+
let slice_async_mut = slice_async.proj_mut();
119+
120+
let capture = rust_cuda::kernel::Capture;
121+
let r#async = kernel.launch2_async(
122+
&stream,
123+
&config,
124+
&capture,
125+
Action::AllocError,
126+
slice_async_mut,
127+
)?;
128+
r#async.synchronize()?;
129+
}
130+
131+
Ok(())
132+
})?;
79133

80134
Ok(())
81135
}

rust-cuda-kernel/src/kernel/wrapper/generate/cuda_generic_function.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ pub(in super::super) fn quote_cuda_generic_function(
8282
)
8383
.collect::<Vec<_>>();
8484

85+
let generic_start_token = generic_start_token.unwrap_or_default();
86+
let generic_close_token = generic_close_token.unwrap_or_default();
87+
8588
quote! {
8689
#[cfg(target_os = "cuda")]
8790
#(#func_attrs)*

src/host/mod.rs

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::{
2222
DeviceConstPointer, DeviceConstRef, DeviceMutPointer, DeviceMutRef, DeviceOwnedPointer,
2323
DeviceOwnedRef,
2424
},
25-
r#async::{Async, NoCompletion},
25+
r#async::{Async, AsyncProj, NoCompletion},
2626
},
2727
};
2828

@@ -194,17 +194,16 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
194194
pub fn as_async<'b, 'stream>(
195195
&'b mut self,
196196
stream: &'stream Stream,
197-
) -> Async<'b, 'stream, HostAndDeviceMutRef<'b, T>, NoCompletion>
197+
) -> AsyncProj<'b, 'stream, HostAndDeviceMutRef<'b, T>>
198198
where
199199
'a: 'b,
200200
{
201-
Async::ready(
202-
HostAndDeviceMutRef {
203-
device_box: self.device_box,
204-
host_ref: self.host_ref,
205-
},
206-
stream,
207-
)
201+
let _ = stream;
202+
203+
AsyncProj::new(HostAndDeviceMutRef {
204+
device_box: self.device_box,
205+
host_ref: self.host_ref,
206+
})
208207
}
209208
}
210209

@@ -293,17 +292,16 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceConstRef<'a, T>
293292
pub const fn as_async<'b, 'stream>(
294293
&'b self,
295294
stream: &'stream Stream,
296-
) -> Async<'b, 'stream, HostAndDeviceConstRef<'b, T>, NoCompletion>
295+
) -> AsyncProj<'b, 'stream, HostAndDeviceConstRef<'b, T>>
297296
where
298297
'a: 'b,
299298
{
300-
Async::ready(
301-
HostAndDeviceConstRef {
302-
device_box: self.device_box,
303-
host_ref: self.host_ref,
304-
},
305-
stream,
306-
)
299+
let _ = stream;
300+
301+
AsyncProj::new(HostAndDeviceConstRef {
302+
device_box: self.device_box,
303+
host_ref: self.host_ref,
304+
})
307305
}
308306
}
309307

0 commit comments

Comments
 (0)