Skip to content

Commit 5114279

Browse files
committed
Backup of progress on compile-time PTX checking
1 parent 10fecf6 commit 5114279

File tree

8 files changed

+152
-8
lines changed

8 files changed

+152
-8
lines changed

examples/single-source/src/main.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
extern crate alloc;
1717

18+
#[cfg(target_os = "cuda")]
1819
use rc::utils::shared::r#static::ThreadBlockShared;
1920

2021
#[cfg(not(target_os = "cuda"))]
@@ -50,23 +51,25 @@ pub fn kernel<'a, T: rc::common::RustToCuda>(
5051
#[kernel(pass = LendRustToCuda)] _z: &ShallowCopy<Wrapper<T>>,
5152
#[kernel(pass = SafeDeviceCopy, jit)] _v @ _w: &'a core::sync::atomic::AtomicU64,
5253
#[kernel(pass = LendRustToCuda)] _: Wrapper<T>,
53-
#[kernel(pass = SafeDeviceCopy)] Tuple(_s, mut __t): Tuple,
54-
#[kernel(pass = LendRustToCuda)] shared3: ThreadBlockShared<u32>,
54+
#[kernel(pass = SafeDeviceCopy)] Tuple(s, mut __t): Tuple,
55+
// #[kernel(pass = LendRustToCuda)] shared3: ThreadBlockShared<u32>,
5556
) where
5657
<T as rc::common::RustToCuda>::CudaRepresentation: rc::safety::StackOnly,
5758
{
5859
let shared: ThreadBlockShared<[Tuple; 3]> = ThreadBlockShared::new_uninit();
5960
let shared2: ThreadBlockShared<[Tuple; 3]> = ThreadBlockShared::new_uninit();
6061

62+
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
6163
unsafe {
62-
(*shared.as_mut_ptr().cast::<Tuple>().add(1)).0 = 42;
64+
(*shared.as_mut_ptr().cast::<Tuple>().add(1)).0 = (f64::from(s) * 2.0) as u32;
6365
}
6466
unsafe {
6567
(*shared2.as_mut_ptr().cast::<Tuple>().add(2)).1 = 24;
6668
}
67-
unsafe {
68-
*shared3.as_mut_ptr() = 12;
69-
}
69+
unsafe { core::arch::asm!("hi") }
70+
// unsafe {
71+
// *shared3.as_mut_ptr() = 12;
72+
// }
7073
}
7174

7275
#[cfg(not(target_os = "cuda"))]

rust-cuda-derive/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1.0"
44
authors = ["Juniper Tyree <[email protected]>"]
55
license = "MIT OR Apache-2.0"
66
edition = "2021"
7+
links = "libnvptxcompiler_static"
78

89
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
910

@@ -24,3 +25,4 @@ colored = "2.0"
2425

2526
seahash = "4.1"
2627
ptx-builder = { git = "https://github.com/juntyr/rust-ptx-builder", rev = "1f1f49d" }
28+
ptx_compiler = "0.1"

rust-cuda-derive/build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
fn main() {
2+
println!("cargo:rustc-link-lib=nvptxcompiler_static");
3+
}

rust-cuda-derive/src/kernel/link/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::path::PathBuf;
33
#[allow(clippy::module_name_repetitions)]
44
pub(super) struct LinkKernelConfig {
55
pub(super) kernel: syn::Ident,
6+
pub(super) kernel_hash: syn::Ident,
67
pub(super) args: syn::Ident,
78
pub(super) crate_name: String,
89
pub(super) crate_path: PathBuf,
@@ -12,6 +13,7 @@ pub(super) struct LinkKernelConfig {
1213
impl syn::parse::Parse for LinkKernelConfig {
1314
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1415
let kernel: syn::Ident = input.parse()?;
16+
let kernel_hash: syn::Ident = input.parse()?;
1517
let args: syn::Ident = input.parse()?;
1618
let name: syn::LitStr = input.parse()?;
1719
let path: syn::LitStr = input.parse()?;
@@ -37,6 +39,7 @@ impl syn::parse::Parse for LinkKernelConfig {
3739

3840
Ok(Self {
3941
kernel,
42+
kernel_hash,
4043
args,
4144
crate_name: name.value(),
4245
crate_path: PathBuf::from(path.value()),

rust-cuda-derive/src/kernel/link/mod.rs

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
use std::{
2-
env, fs,
2+
env,
3+
ffi::CString,
4+
fs,
35
io::{Read, Write},
6+
mem::MaybeUninit,
7+
os::raw::c_int,
48
path::{Path, PathBuf},
9+
ptr::addr_of_mut,
510
sync::atomic::{AtomicBool, Ordering},
611
};
712

@@ -11,6 +16,7 @@ use ptx_builder::{
1116
builder::{BuildStatus, Builder, MessageFormat, Profile},
1217
error::{BuildErrorKind, Error, Result},
1318
};
19+
use ptx_compiler::sys::size_t;
1420

1521
use super::utils::skip_kernel_compilation;
1622

@@ -56,6 +62,7 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
5662

5763
let LinkKernelConfig {
5864
kernel,
65+
kernel_hash,
5966
args,
6067
crate_name,
6168
crate_path,
@@ -192,6 +199,119 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
192199
kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
193200
}
194201

202+
let mut compiler = MaybeUninit::uninit();
203+
let r = unsafe {
204+
ptx_compiler::sys::nvPTXCompilerCreate(
205+
compiler.as_mut_ptr(),
206+
kernel_ptx.len() as size_t,
207+
kernel_ptx.as_ptr().cast(),
208+
)
209+
};
210+
emit_call_site_warning!("PTX compiler create result {}", r);
211+
let compiler = unsafe { compiler.assume_init() };
212+
213+
let mut major = 0;
214+
let mut minor = 0;
215+
let r = unsafe {
216+
ptx_compiler::sys::nvPTXCompilerGetVersion(addr_of_mut!(major), addr_of_mut!(minor))
217+
};
218+
emit_call_site_warning!("PTX version result {}", r);
219+
emit_call_site_warning!("PTX compiler version {}.{}", major, minor);
220+
221+
let kernel_name = if specialisation.is_empty() {
222+
format!("{kernel_hash}_kernel")
223+
} else {
224+
format!(
225+
"{kernel_hash}_kernel_{:016x}",
226+
seahash::hash(specialisation.as_bytes())
227+
)
228+
};
229+
230+
let options = vec![
231+
CString::new("--entry").unwrap(),
232+
CString::new(kernel_name).unwrap(),
233+
CString::new("--verbose").unwrap(),
234+
CString::new("--warn-on-double-precision-use").unwrap(),
235+
CString::new("--warn-on-local-memory-usage").unwrap(),
236+
CString::new("--warn-on-spills").unwrap(),
237+
];
238+
let options_ptrs = options.iter().map(|o| o.as_ptr()).collect::<Vec<_>>();
239+
240+
let r = unsafe {
241+
ptx_compiler::sys::nvPTXCompilerCompile(
242+
compiler,
243+
options_ptrs.len() as c_int,
244+
options_ptrs.as_ptr().cast(),
245+
)
246+
};
247+
emit_call_site_warning!("PTX compile result {}", r);
248+
249+
let mut info_log_size = 0;
250+
let r = unsafe {
251+
ptx_compiler::sys::nvPTXCompilerGetInfoLogSize(compiler, addr_of_mut!(info_log_size))
252+
};
253+
emit_call_site_warning!("PTX info log size result {}", r);
254+
#[allow(clippy::cast_possible_truncation)]
255+
let mut info_log: Vec<u8> = Vec::with_capacity(info_log_size as usize);
256+
if info_log_size > 0 {
257+
let r = unsafe {
258+
ptx_compiler::sys::nvPTXCompilerGetInfoLog(compiler, info_log.as_mut_ptr().cast())
259+
};
260+
emit_call_site_warning!("PTX info log content result {}", r);
261+
#[allow(clippy::cast_possible_truncation)]
262+
unsafe {
263+
info_log.set_len(info_log_size as usize);
264+
}
265+
}
266+
let info_log = String::from_utf8_lossy(&info_log);
267+
268+
let mut error_log_size = 0;
269+
let r = unsafe {
270+
ptx_compiler::sys::nvPTXCompilerGetErrorLogSize(compiler, addr_of_mut!(error_log_size))
271+
};
272+
emit_call_site_warning!("PTX error log size result {}", r);
273+
#[allow(clippy::cast_possible_truncation)]
274+
let mut error_log: Vec<u8> = Vec::with_capacity(error_log_size as usize);
275+
if error_log_size > 0 {
276+
let r = unsafe {
277+
ptx_compiler::sys::nvPTXCompilerGetErrorLog(compiler, error_log.as_mut_ptr().cast())
278+
};
279+
emit_call_site_warning!("PTX error log content result {}", r);
280+
#[allow(clippy::cast_possible_truncation)]
281+
unsafe {
282+
error_log.set_len(error_log_size as usize);
283+
}
284+
}
285+
let error_log = String::from_utf8_lossy(&error_log);
286+
287+
// Ensure the compiler is not dropped
288+
let mut compiler = MaybeUninit::new(compiler);
289+
let r = unsafe { ptx_compiler::sys::nvPTXCompilerDestroy(compiler.as_mut_ptr()) };
290+
emit_call_site_warning!("PTX compiler destroy result {}", r);
291+
292+
if !info_log.is_empty() {
293+
emit_call_site_warning!("PTX compiler info log:\n{}", info_log);
294+
}
295+
if !error_log.is_empty() {
296+
let mut max_lines = kernel_ptx.chars().filter(|c| *c == '\n').count() + 1;
297+
let mut indent = 0;
298+
while max_lines > 0 {
299+
max_lines /= 10;
300+
indent += 1;
301+
}
302+
303+
abort_call_site!(
304+
"PTX compiler error log:\n{}\nPTX source:\n{}",
305+
error_log,
306+
kernel_ptx
307+
.lines()
308+
.enumerate()
309+
.map(|(i, l)| format!("{:indent$}| {l}", i + 1))
310+
.collect::<Vec<_>>()
311+
.join("\n")
312+
);
313+
}
314+
195315
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
196316
}
197317

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/get_ptx_str.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ pub(super) fn quote_get_ptx_str(
8484
quote! {
8585
fn get_ptx_str() -> &'static str {
8686
#crate_path::host::link_kernel!{
87-
#func_ident #args #crate_name #crate_manifest_dir #generic_start_token
87+
#func_ident #func_ident_hash #args #crate_name #crate_manifest_dir #generic_start_token
8888
#($#macro_type_ids),*
8989
#generic_close_token
9090
}

src/safety/device_copy.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,11 @@ mod sealed {
1919
for crate::utils::device_copy::SafeDeviceCopyWrapper<T>
2020
{
2121
}
22+
23+
// Only unsafe aliasing is possible since both only expose raw pointers
24+
// impl<T: 'static> SafeDeviceCopy for
25+
// crate::utils::shared::r#static::ThreadBlockShared<T> {}
26+
// impl<T: 'static + ~const const_type_layout::TypeGraphLayout>
27+
// SafeDeviceCopy for crate::utils::shared::slice::ThreadBlockSharedSlice<T>
28+
// {}
2229
}

src/safety/no_aliasing.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,10 @@ mod private {
2222
{
2323
}
2424
impl<T> NoAliasing for crate::utils::aliasing::SplitSliceOverCudaThreadsDynamicStride<T> {}
25+
26+
// Only unsafe aliasing is possible since both only expose raw pointers
27+
// impl<T: 'static> NoAliasing for
28+
// crate::utils::shared::r#static::ThreadBlockShared<T> {}
29+
// impl<T: 'static + ~const const_type_layout::TypeGraphLayout> NoAliasing
30+
// for crate::utils::shared::slice::ThreadBlockSharedSlice<T> {}
2531
}

0 commit comments

Comments
 (0)