Skip to content

refactor(cust_raw): consolidate CUDA, cuDNN, OptiX bindgen and remove find_cuda_helper #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ jobs:
run: cargo fmt --all -- --check

- name: Build
run: cargo build --workspace --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
run: cargo build --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*"

# Don't currently test because many tests rely on the system having a CUDA GPU
# - name: Test
Expand All @@ -112,9 +112,9 @@ jobs:
if: contains(matrix.os, 'ubuntu')
env:
RUSTFLAGS: -Dwarnings
run: cargo clippy --workspace --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
run: cargo clippy --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*"

- name: Check documentation
env:
RUSTDOCFLAGS: -Dwarnings
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ members = [
"examples/optix/*",
"examples/cuda/cpu/*",
"examples/cuda/gpu/*",

]

exclude = [
"crates/optix/examples/common"
"crates/optix/examples/common",
]

[profile.dev.package.rustc_codegen_nvvm]
Expand Down
2 changes: 1 addition & 1 deletion crates/blastoff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA"

[dependencies]
bitflags = "2.8"
cublas_sys = { version = "0.1", path = "../cublas_sys" }
cust = { version = "0.3", path = "../cust", features = ["impl_num_complex"] }
cust_raw = { path = "../cust_raw", features = ["cublas"] }
num-complex = "0.4.6"
half = { version = "2.4.1", optional = true }

Expand Down
57 changes: 31 additions & 26 deletions crates/blastoff/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
use crate::{error::*, sys};
use cust::stream::Stream;
use std::ffi::CString;
use std::mem::{self, MaybeUninit};
use std::os::raw::c_char;
use std::ptr;

type Result<T, E = Error> = std::result::Result<T, E>;
use cust::stream::Stream;
use cust_raw::cublas_sys;
use cust_raw::driver_sys;

use super::error::DropResult;
use super::error::ToResult as _;

type Result<T, E = super::error::Error> = std::result::Result<T, E>;

bitflags::bitflags! {
/// Configures precision levels for the math in cuBLAS.
#[derive(Default)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MathMode: u32 {
/// Highest performance mode which uses compute and intermediate storage precisions
/// with at least the same number of mantissa and exponent bits as requested. Will
Expand Down Expand Up @@ -68,7 +73,7 @@ bitflags::bitflags! {
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
#[derive(Debug)]
pub struct CublasContext {
pub(crate) raw: sys::v2::cublasHandle_t,
pub(crate) raw: cublas_sys::cublasHandle_t,
}

impl CublasContext {
Expand All @@ -87,10 +92,10 @@ impl CublasContext {
pub fn new() -> Result<Self> {
let mut raw = MaybeUninit::uninit();
unsafe {
sys::v2::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
sys::v2::cublasSetPointerMode_v2(
cublas_sys::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
cublas_sys::cublasSetPointerMode_v2(
raw.assume_init(),
sys::v2::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
)
.to_result()?;
Ok(Self {
Expand All @@ -107,7 +112,7 @@ impl CublasContext {

unsafe {
let inner = mem::replace(&mut ctx.raw, ptr::null_mut());
match sys::v2::cublasDestroy_v2(inner).to_result() {
match cublas_sys::cublasDestroy_v2(inner).to_result() {
Ok(()) => {
mem::forget(ctx);
Ok(())
Expand All @@ -122,7 +127,7 @@ impl CublasContext {
let mut raw = MaybeUninit::<u32>::uninit();
unsafe {
// getVersion can't fail
sys::v2::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
cublas_sys::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
.to_result()
.unwrap();

Expand All @@ -140,17 +145,17 @@ impl CublasContext {
) -> Result<T> {
unsafe {
// cudaStream_t is the same as CUstream
sys::v2::cublasSetStream_v2(
cublas_sys::cublasSetStream_v2(
self.raw,
mem::transmute::<*mut cust::sys::CUstream_st, *mut cublas_sys::v2::CUstream_st>(
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
stream.as_inner(),
),
)
.to_result()?;
let res = func(self)?;
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
// execute a raw sys function with the context's handle.
sys::v2::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
cublas_sys::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
Ok(res)
}
}
Expand Down Expand Up @@ -180,12 +185,12 @@ impl CublasContext {
/// ```
pub fn set_atomics_mode(&self, allowed: bool) -> Result<()> {
unsafe {
Ok(sys::v2::cublasSetAtomicsMode(
Ok(cublas_sys::cublasSetAtomicsMode(
self.raw,
if allowed {
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
} else {
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
},
)
.to_result()?)
Expand All @@ -210,10 +215,10 @@ impl CublasContext {
pub fn get_atomics_mode(&self) -> Result<bool> {
let mut mode = MaybeUninit::uninit();
unsafe {
sys::v2::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
cublas_sys::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
Ok(match mode.assume_init() {
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
})
}
}
Expand All @@ -233,9 +238,9 @@ impl CublasContext {
/// ```
pub fn set_math_mode(&self, math_mode: MathMode) -> Result<()> {
unsafe {
Ok(sys::v2::cublasSetMathMode(
Ok(cublas_sys::cublasSetMathMode(
self.raw,
mem::transmute::<u32, cublas_sys::v2::cublasMath_t>(math_mode.bits()),
mem::transmute::<u32, cublas_sys::cublasMath_t>(math_mode.bits()),
)
.to_result()?)
}
Expand All @@ -258,7 +263,7 @@ impl CublasContext {
pub fn get_math_mode(&self) -> Result<MathMode> {
let mut mode = MaybeUninit::uninit();
unsafe {
sys::v2::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
cublas_sys::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
Ok(MathMode::from_bits(mode.assume_init() as u32)
.expect("Invalid MathMode from cuBLAS"))
}
Expand Down Expand Up @@ -298,7 +303,7 @@ impl CublasContext {
let path = log_file_name.map(|p| CString::new(p).expect("nul in log_file_name"));
let path_ptr = path.map_or(ptr::null(), |s| s.as_ptr());

sys::v2::cublasLoggerConfigure(
cublas_sys::cublasLoggerConfigure(
enable as i32,
log_to_stdout as i32,
log_to_stderr as i32,
Expand All @@ -315,7 +320,7 @@ impl CublasContext {
///
/// The callback must not panic and unwind.
pub unsafe fn set_logger_callback(callback: Option<unsafe extern "C" fn(*const c_char)>) {
sys::v2::cublasSetLoggerCallback(callback)
cublas_sys::cublasSetLoggerCallback(callback)
.to_result()
.unwrap();
}
Expand All @@ -324,7 +329,7 @@ impl CublasContext {
pub fn get_logger_callback() -> Option<unsafe extern "C" fn(*const c_char)> {
let mut cb = MaybeUninit::uninit();
unsafe {
sys::v2::cublasGetLoggerCallback(cb.as_mut_ptr())
cublas_sys::cublasGetLoggerCallback(cb.as_mut_ptr())
.to_result()
.unwrap();
cb.assume_init()
Expand All @@ -335,7 +340,7 @@ impl CublasContext {
impl Drop for CublasContext {
fn drop(&mut self) {
unsafe {
sys::v2::cublasDestroy_v2(self.raw);
cublas_sys::cublasDestroy_v2(self.raw);
}
}
}
51 changes: 27 additions & 24 deletions crates/blastoff/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::sys;
use cust::error::CudaError;
use std::{ffi::CStr, fmt::Display};

use cust::error::CudaError;
use cust_raw::cublas_sys;

/// Result that contains the un-dropped value on error.
pub type DropResult<T> = std::result::Result<(), (CublasError, T)>;

Expand All @@ -24,7 +25,7 @@ impl std::error::Error for CublasError {}
impl Display for CublasError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
unsafe {
let ptr = sys::v2::cublasGetStatusString(self.into_raw());
let ptr = cublas_sys::cublasGetStatusString(self.into_raw());
let cow = CStr::from_ptr(ptr).to_string_lossy();
f.write_str(cow.as_ref())
}
Expand All @@ -35,39 +36,41 @@ pub trait ToResult {
fn to_result(self) -> Result<(), CublasError>;
}

impl ToResult for sys::v2::cublasStatus_t {
impl ToResult for cublas_sys::cublasStatus_t {
fn to_result(self) -> Result<(), CublasError> {
use cust_raw::cublas_sys::cublasStatus_t::*;
use CublasError::*;

Err(match self {
sys::v2::cublasStatus_t::CUBLAS_STATUS_SUCCESS => return Ok(()),
sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => NotInitialized,
sys::v2::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => AllocFailed,
sys::v2::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE => InvalidValue,
sys::v2::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => ArchMismatch,
sys::v2::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR => MappingError,
sys::v2::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => ExecutionFailed,
sys::v2::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR => InternalError,
sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED => NotSupported,
sys::v2::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR => LicenseError,
CUBLAS_STATUS_SUCCESS => return Ok(()),
CUBLAS_STATUS_NOT_INITIALIZED => NotInitialized,
CUBLAS_STATUS_ALLOC_FAILED => AllocFailed,
CUBLAS_STATUS_INVALID_VALUE => InvalidValue,
CUBLAS_STATUS_ARCH_MISMATCH => ArchMismatch,
CUBLAS_STATUS_MAPPING_ERROR => MappingError,
CUBLAS_STATUS_EXECUTION_FAILED => ExecutionFailed,
CUBLAS_STATUS_INTERNAL_ERROR => InternalError,
CUBLAS_STATUS_NOT_SUPPORTED => NotSupported,
CUBLAS_STATUS_LICENSE_ERROR => LicenseError,
})
}
}

impl CublasError {
pub fn into_raw(self) -> sys::v2::cublasStatus_t {
pub fn into_raw(self) -> cublas_sys::cublasStatus_t {
use cust_raw::cublas_sys::cublasStatus_t::*;
use CublasError::*;

match self {
NotInitialized => sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED,
AllocFailed => sys::v2::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED,
InvalidValue => sys::v2::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
ArchMismatch => sys::v2::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH,
MappingError => sys::v2::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR,
ExecutionFailed => sys::v2::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED,
InternalError => sys::v2::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR,
NotSupported => sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED,
LicenseError => sys::v2::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR,
NotInitialized => CUBLAS_STATUS_NOT_INITIALIZED,
AllocFailed => CUBLAS_STATUS_ALLOC_FAILED,
InvalidValue => CUBLAS_STATUS_INVALID_VALUE,
ArchMismatch => CUBLAS_STATUS_ARCH_MISMATCH,
MappingError => CUBLAS_STATUS_MAPPING_ERROR,
ExecutionFailed => CUBLAS_STATUS_EXECUTION_FAILED,
InternalError => CUBLAS_STATUS_INTERNAL_ERROR,
NotSupported => CUBLAS_STATUS_NOT_SUPPORTED,
LicenseError => CUBLAS_STATUS_LICENSE_ERROR,
}
}
}
Expand Down
28 changes: 14 additions & 14 deletions crates/blastoff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#![allow(clippy::too_many_arguments)]
#![cfg_attr(docsrs, feature(doc_cfg))]

pub use cublas_sys as sys;
pub use cust_raw::cublas_sys;
use num_complex::{Complex32, Complex64};

pub use context::*;
Expand Down Expand Up @@ -39,34 +39,34 @@ pub trait BlasDatatype: private::Sealed + cust::memory::DeviceCopy {
/// The corresponding float type. For complex numbers this means their backing
/// precision, and for floats it is just themselves.
type FloatTy: Float;
fn to_raw(&self) -> sys::v2::cudaDataType;
fn to_raw(&self) -> cublas_sys::cudaDataType;
}

impl BlasDatatype for f32 {
type FloatTy = f32;
fn to_raw(&self) -> sys::v2::cudaDataType {
sys::v2::cudaDataType::CUDA_R_32F
fn to_raw(&self) -> cublas_sys::cudaDataType {
cublas_sys::cudaDataType::CUDA_R_32F
}
}

impl BlasDatatype for f64 {
type FloatTy = f64;
fn to_raw(&self) -> sys::v2::cudaDataType {
sys::v2::cudaDataType::CUDA_R_64F
fn to_raw(&self) -> cublas_sys::cudaDataType {
cublas_sys::cudaDataType::CUDA_R_64F
}
}

impl BlasDatatype for Complex32 {
type FloatTy = f32;
fn to_raw(&self) -> sys::v2::cudaDataType {
sys::v2::cudaDataType::CUDA_C_32F
fn to_raw(&self) -> cublas_sys::cudaDataType {
cublas_sys::cudaDataType::CUDA_C_32F
}
}

impl BlasDatatype for Complex64 {
type FloatTy = f64;
fn to_raw(&self) -> sys::v2::cudaDataType {
sys::v2::cudaDataType::CUDA_C_64F
fn to_raw(&self) -> cublas_sys::cudaDataType {
cublas_sys::cudaDataType::CUDA_C_64F
}
}

Expand Down Expand Up @@ -106,11 +106,11 @@ pub enum MatrixOp {

impl MatrixOp {
/// Returns the corresponding `cublasOperation_t` for this operation.
pub fn to_raw(self) -> sys::v2::cublasOperation_t {
pub fn to_raw(self) -> cublas_sys::cublasOperation_t {
match self {
MatrixOp::None => sys::v2::cublasOperation_t::CUBLAS_OP_N,
MatrixOp::Transpose => sys::v2::cublasOperation_t::CUBLAS_OP_T,
MatrixOp::ConjugateTranspose => sys::v2::cublasOperation_t::CUBLAS_OP_C,
MatrixOp::None => cublas_sys::cublasOperation_t::CUBLAS_OP_N,
MatrixOp::Transpose => cublas_sys::cublasOperation_t::CUBLAS_OP_T,
MatrixOp::ConjugateTranspose => cublas_sys::cublasOperation_t::CUBLAS_OP_C,
}
}
}
7 changes: 5 additions & 2 deletions crates/blastoff/src/raw/level1.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::{sys::v2::*, BlasDatatype};
use num_complex::{Complex32, Complex64};
use std::os::raw::c_int;

use cust_raw::cublas_sys::*;
use num_complex::{Complex32, Complex64};

use crate::BlasDatatype;

pub trait Level1: BlasDatatype {
unsafe fn amax(
handle: cublasHandle_t,
Expand Down
7 changes: 5 additions & 2 deletions crates/blastoff/src/raw/level3.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::{sys::v2::*, GemmDatatype};
use num_complex::{Complex32, Complex64};
use std::os::raw::c_int;

use cust_raw::cublas_sys::*;
use num_complex::{Complex32, Complex64};

use crate::GemmDatatype;

pub trait GemmOps: GemmDatatype {
unsafe fn gemm(
handle: cublasHandle_t,
Expand Down
Loading
Loading