First pass #1


Merged
109 commits merged on Jan 21, 2022
Commits
b56e90c
wip
anderslanglands Oct 30, 2021
ff04700
bootstrap enough optix to get ex02 working
anderslanglands Oct 30, 2021
8609413
Add example 03
anderslanglands Oct 31, 2021
3856ed5
add logging callback
anderslanglands Oct 31, 2021
c4346ef
remove ustr
anderslanglands Oct 31, 2021
b436648
Manually create OptixShaderBindingTable
anderslanglands Oct 31, 2021
e21e5a7
Switch Module and Pipeline methods to their structs
anderslanglands Oct 31, 2021
73ad982
Switch Module, Pipeline, ProgramGroup methods to their structs
anderslanglands Oct 31, 2021
aeb670b
Merge branch 'optix' of github.com:RDambrosio016/Rust-CUDA into optix
anderslanglands Oct 31, 2021
0e24288
Refactor: remove dead imports
RDambrosio016 Oct 31, 2021
f76b1d9
derive DeviceCopy
anderslanglands Nov 1, 2021
deea117
typo
anderslanglands Nov 1, 2021
f945be3
Better error message
anderslanglands Nov 1, 2021
6b0011f
Move destroy to Drop impl
anderslanglands Nov 1, 2021
7c9c686
typo
anderslanglands Nov 1, 2021
51cbc04
rename OptixContext to DeviceContext
anderslanglands Nov 1, 2021
234aa93
Make launch params variable name optional
anderslanglands Nov 1, 2021
b1f88be
Remove Clone from Module and ProgramGroup
anderslanglands Nov 1, 2021
d16856b
Make log callback safe
anderslanglands Nov 1, 2021
82720bd
Merge branch 'optix' of github.com:RDambrosio016/Rust-CUDA into optix
anderslanglands Nov 1, 2021
e0e0092
add wip glam support
anderslanglands Nov 2, 2021
87208bf
dont panic in drop
anderslanglands Nov 2, 2021
a05270e
Rework DevicePointer on top of CUdeviceptr
anderslanglands Nov 2, 2021
970af09
wip
anderslanglands Oct 30, 2021
2527ada
bootstrap enough optix to get ex02 working
anderslanglands Oct 30, 2021
ad863c0
Add example 03
anderslanglands Oct 31, 2021
ee5bb1b
add logging callback
anderslanglands Oct 31, 2021
a50e587
remove ustr
anderslanglands Oct 31, 2021
2a5b787
Manually create OptixShaderBindingTable
anderslanglands Oct 31, 2021
2205484
Switch Module and Pipeline methods to their structs
anderslanglands Oct 31, 2021
9eb6a08
Switch Module, Pipeline, ProgramGroup methods to their structs
anderslanglands Oct 31, 2021
1b8059c
Refactor: remove dead imports
RDambrosio016 Oct 31, 2021
28c4b83
derive DeviceCopy
anderslanglands Nov 1, 2021
ba9f05b
typo
anderslanglands Nov 1, 2021
9be649d
Better error message
anderslanglands Nov 1, 2021
35b935b
Move destroy to Drop impl
anderslanglands Nov 1, 2021
ecff765
typo
anderslanglands Nov 1, 2021
71f8f37
rename OptixContext to DeviceContext
anderslanglands Nov 1, 2021
73cd601
Make launch params variable name optional
anderslanglands Nov 1, 2021
e396eff
Remove Clone from Module and ProgramGroup
anderslanglands Nov 1, 2021
aee77ab
Make log callback safe
anderslanglands Nov 1, 2021
d1d06f2
add wip glam support
anderslanglands Nov 2, 2021
9f77889
dont panic in drop
anderslanglands Nov 2, 2021
c307b53
wip accel support
anderslanglands Nov 2, 2021
2a7d98d
Add accel wip
anderslanglands Nov 3, 2021
ae726df
fix merge
anderslanglands Nov 3, 2021
4607ca3
Rework acceleration structure stuff
anderslanglands Nov 3, 2021
76dc524
add lifetime bound on Instance to referenced Accel
anderslanglands Nov 3, 2021
6cf07ff
Have DeviceCopy impl for lifetime markers use null type
anderslanglands Nov 3, 2021
e163a25
Add unsafe from_handle ctor
anderslanglands Nov 3, 2021
0c13305
Add update for DynamicAccel
anderslanglands Nov 3, 2021
cbd06a8
Hash build inputs to ensure update is sound
anderslanglands Nov 3, 2021
afd5ef5
Add relocation info
anderslanglands Nov 4, 2021
e4feec4
Add remaining DeviceContext methods
anderslanglands Nov 4, 2021
da817c1
Correct docstrings
anderslanglands Nov 4, 2021
8d1839f
Add doc comments
anderslanglands Nov 4, 2021
718bec7
Add a prelude
anderslanglands Nov 4, 2021
4d36fca
Own the geometry flags array
anderslanglands Nov 4, 2021
56ce8f0
Add prelude
anderslanglands Nov 4, 2021
100759f
own the geometry flags array and add support for pre_transform
anderslanglands Nov 4, 2021
6052229
Fill out context and add some module docs
anderslanglands Nov 4, 2021
0d7a78f
Add some module docs
anderslanglands Nov 4, 2021
8663bed
Update to latest library changes
anderslanglands Nov 4, 2021
a1e6598
Add more docs
anderslanglands Nov 4, 2021
b3af8ae
Remove mut requirement for getting pointer
anderslanglands Nov 6, 2021
6f2abe7
Add a simple memcpy_htod wrapper
anderslanglands Nov 6, 2021
7d45757
Add back pointer offset methods
anderslanglands Nov 6, 2021
b3425b3
Big structure reorg and documentation push
anderslanglands Nov 6, 2021
25f7311
Wrap SBT properly
anderslanglands Nov 7, 2021
6b002ec
Rename transform types
anderslanglands Nov 7, 2021
31e8e76
Simplify AccelBuildOptions creation
anderslanglands Nov 7, 2021
66f6aad
Hide programming guide in details tags
anderslanglands Nov 7, 2021
7ce2381
Adapt to latest changes
anderslanglands Nov 7, 2021
23fbe9a
Fix toolchain version
anderslanglands Nov 7, 2021
67106f4
Fix name of DeviceContext
anderslanglands Nov 7, 2021
be5fd90
first optix rust test
anderslanglands Nov 7, 2021
0ef56c7
Set ALLOW_COMPACTION in build options
anderslanglands Nov 8, 2021
f1f29cb
Use find_cuda_helper to get cuda path
anderslanglands Nov 8, 2021
b70ef0a
Handle differing enum representation on Windows and Linux
anderslanglands Nov 8, 2021
bb9c61e
Add DeviceVariable
anderslanglands Nov 11, 2021
4b855ee
Add DeviceMemory trait
anderslanglands Nov 11, 2021
fd68e3a
Add mem_get_info
anderslanglands Nov 11, 2021
8edb09e
Add external memory
anderslanglands Nov 11, 2021
b456b60
Add a few more types to prelude
anderslanglands Nov 11, 2021
3e7e7f4
Add more types
anderslanglands Nov 11, 2021
c31850a
Rework on top of new DeviceVariable
anderslanglands Nov 11, 2021
cb89463
first optix rust test
anderslanglands Nov 7, 2021
69f2cea
tweak build
anderslanglands Nov 7, 2021
c313f42
merge
anderslanglands Nov 11, 2021
c6718c0
update to latest optix changes
anderslanglands Nov 11, 2021
32f477c
Merge branch 'master' into optix
anderslanglands Nov 11, 2021
844726e
Merge branch 'optix' into optix_rust
anderslanglands Nov 11, 2021
0772dd3
Split DeviceCopy into cust_core
anderslanglands Nov 12, 2021
c318a0a
update to latest optix changes
anderslanglands Nov 12, 2021
d270e80
trying to get print working
anderslanglands Nov 12, 2021
7a7a2c2
tweak test kernel
anderslanglands Nov 12, 2021
131c843
stop llvm optimizing out LaunchParams
anderslanglands Nov 12, 2021
acd03a1
merge optix_rust into optix
anderslanglands Nov 13, 2021
9f488ea
Feat: first commit to update PR to master
RDambrosio016 Dec 12, 2021
eae4261
Chore: update cargo.toml dep versions
RDambrosio016 Dec 12, 2021
799d79b
Feat: second pass for fixing conflicts
RDambrosio016 Dec 13, 2021
e858fdc
Feat: delete as_ptr and as_mut_ptr on DeviceSlice
RDambrosio016 Dec 13, 2021
d0cdcbd
Revert "Feat: delete as_ptr and as_mut_ptr on DeviceSlice"
RDambrosio016 Dec 13, 2021
c48d8ee
Feat: experiment with deleting as_ptr and as_mut_ptr
RDambrosio016 Dec 15, 2021
d32555c
Merge branch 'master' into optix
RDambrosio016 Jan 21, 2022
766cb60
Fix issues and warnings
RDambrosio016 Jan 21, 2022
d9ef5f5
Chore: run formatting
RDambrosio016 Jan 21, 2022
04c2059
Chore: exclude examples from building in CI
RDambrosio016 Jan 21, 2022
741264a
Feat: update changelog with changes, misc changes before merge
RDambrosio016 Jan 21, 2022
6 changes: 3 additions & 3 deletions .github/workflows/rust.yml
@@ -59,7 +59,7 @@ jobs:
         run: cargo fmt --all -- --check
 
       - name: Build
-        run: cargo build --workspace --exclude "optix" --exclude "optix_sys" --exclude "path_tracer" --exclude "denoiser" --exclude "add"
+        run: cargo build --workspace --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
 
       # Don't currently test because many tests rely on the system having a CUDA GPU
       # - name: Test
@@ -69,9 +69,9 @@ jobs:
         if: contains(matrix.os, 'ubuntu')
         env:
           RUSTFLAGS: -Dwarnings
-        run: cargo clippy --workspace --exclude "optix" --exclude "optix_sys" --exclude "path_tracer" --exclude "denoiser" --exclude "add"
+        run: cargo clippy --workspace --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
 
       - name: Check documentation
         env:
           RUSTDOCFLAGS: -Dwarnings
-        run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix" --exclude "optix_sys" --exclude "path_tracer" --exclude "denoiser" --exclude "add"
+        run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
7 changes: 6 additions & 1 deletion Cargo.toml
@@ -1,7 +1,8 @@
 [workspace]
 members = [
     "crates/*",
+    "crates/optix/examples/ex*",
+    "crates/optix/examples/ex*/device",
     "xtask",
 
     "examples/optix/*",
@@ -10,5 +11,9 @@ members = [
 
 ]
 
+exclude = [
+    "crates/optix/examples/common"
+]
+
 [profile.dev.package.rustc_codegen_nvvm]
 opt-level = 3
2 changes: 1 addition & 1 deletion crates/blastoff/Cargo.toml
@@ -8,7 +8,7 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA"
 [dependencies]
 bitflags = "1.3.2"
 cublas_sys = { version = "0.1", path = "../cublas_sys" }
-cust = { version = "0.2", path = "../cust", features = ["num-complex"] }
+cust = { version = "0.2", path = "../cust", features = ["impl_num_complex"] }
 num-complex = "0.4.0"
 
 [package.metadata.docs.rs]
48 changes: 24 additions & 24 deletions crates/blastoff/src/level1.rs
@@ -46,9 +46,9 @@ impl CublasContext
             Ok(T::amin(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 stride.unwrap_or(1) as i32,
-                result.as_device_ptr().as_raw_mut(),
+                result.as_device_ptr().as_mut_ptr(),
             )
             .to_result()?)
         })
@@ -108,9 +108,9 @@ impl CublasContext
             Ok(T::amax(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 stride.unwrap_or(1) as i32,
-                result.as_device_ptr().as_raw_mut(),
+                result.as_device_ptr().as_mut_ptr(),
             )
             .to_result()?)
         })
@@ -172,10 +172,10 @@ impl CublasContext
             Ok(T::axpy(
                 ctx.raw,
                 n as i32,
-                alpha.as_device_ptr().as_raw(),
-                x.as_device_ptr().as_raw(),
+                alpha.as_device_ptr().as_ptr(),
+                x.as_device_ptr().as_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                y.as_device_ptr().as_raw_mut(),
+                y.as_device_ptr().as_mut_ptr(),
                 y_stride.unwrap_or(1) as i32,
             )
             .to_result()?)
@@ -245,9 +245,9 @@ impl CublasContext
             Ok(T::copy(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                y.as_device_ptr().as_raw_mut(),
+                y.as_device_ptr().as_mut_ptr(),
                 y_stride.unwrap_or(1) as i32,
             )
             .to_result()?)
@@ -314,11 +314,11 @@ impl CublasContext
             Ok(T::dot(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                y.as_device_ptr().as_raw(),
+                y.as_device_ptr().as_ptr(),
                 y_stride.unwrap_or(1) as i32,
-                result.as_device_ptr().as_raw_mut(),
+                result.as_device_ptr().as_mut_ptr(),
             )
             .to_result()?)
         })
@@ -390,11 +390,11 @@ impl CublasContext
             Ok(T::dotu(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                y.as_device_ptr().as_raw(),
+                y.as_device_ptr().as_ptr(),
                 y_stride.unwrap_or(1) as i32,
-                result.as_device_ptr().as_raw_mut(),
+                result.as_device_ptr().as_mut_ptr(),
             )
             .to_result()?)
         })
@@ -438,11 +438,11 @@ impl CublasContext
             Ok(T::dotc(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                y.as_device_ptr().as_raw(),
+                y.as_device_ptr().as_ptr(),
                 y_stride.unwrap_or(1) as i32,
-                result.as_device_ptr().as_raw_mut(),
+                result.as_device_ptr().as_mut_ptr(),
             )
             .to_result()?)
         })
@@ -483,9 +483,9 @@ impl CublasContext
             Ok(T::nrm2(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw(),
+                x.as_device_ptr().as_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                result.as_device_ptr().as_raw_mut(),
+                result.as_device_ptr().as_mut_ptr(),
             )
             .to_result()?)
         })
@@ -559,12 +559,12 @@ impl CublasContext
             Ok(T::rot(
                 ctx.raw,
                 n as i32,
-                x.as_device_ptr().as_raw_mut(),
+                x.as_device_ptr().as_mut_ptr(),
                 x_stride.unwrap_or(1) as i32,
-                y.as_device_ptr().as_raw_mut(),
+                y.as_device_ptr().as_mut_ptr(),
                 y_stride.unwrap_or(1) as i32,
-                c.as_device_ptr().as_raw(),
-                s.as_device_ptr().as_raw(),
+                c.as_device_ptr().as_ptr(),
+                s.as_device_ptr().as_ptr(),
             )
             .to_result()?)
         })
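The mechanical change in this file is a rename of cust's device-pointer accessors from `as_raw`/`as_raw_mut` to `as_ptr`/`as_mut_ptr`, mirroring `slice::as_ptr`/`slice::as_mut_ptr` in std. A minimal sketch of the shape of that API follows; `DevicePointer` here is a hypothetical stand-in written for illustration, not cust's actual definition:

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for a typed device pointer wrapping a raw
// CUdeviceptr-style address; not cust's actual definition.
#[derive(Copy, Clone)]
struct DevicePointer<T> {
    addr: usize,
    _marker: PhantomData<*mut T>,
}

impl<T> DevicePointer<T> {
    fn new(addr: usize) -> Self {
        Self {
            addr,
            _marker: PhantomData,
        }
    }

    // Previously named `as_raw` in this PR; renamed to match slice::as_ptr.
    fn as_ptr(&self) -> *const T {
        self.addr as *const T
    }

    // Previously named `as_raw_mut`; renamed to match slice::as_mut_ptr.
    fn as_mut_ptr(&self) -> *mut T {
        self.addr as *mut T
    }
}

fn main() {
    let p = DevicePointer::<f32>::new(0x1000);
    assert_eq!(p.as_ptr() as usize, 0x1000);
    assert_eq!(p.as_mut_ptr() as usize, 0x1000);
    println!("ok");
}
```

Reusing std's naming convention means callers migrating between host slices and device pointers see the same accessor vocabulary on both sides.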
70 changes: 35 additions & 35 deletions crates/cudnn/src/context.rs
@@ -397,9 +397,9 @@ impl CudnnContext
         let x_data = x.data().as_device_ptr().as_raw();
 
         let y_desc = y.descriptor();
-        let y_data = y.data().as_device_ptr().as_raw_mut();
+        let y_data = y.data().as_device_ptr().as_ptr();
 
-        let reserve_space_ptr = reserve_space.as_device_ptr().as_raw_mut();
+        let reserve_space_ptr = reserve_space.as_device_ptr().as_ptr();
 
         unsafe {
             sys::cudnnDropoutForward(
@@ -454,9 +454,9 @@ impl CudnnContext
         let dy_data = dy.data().as_device_ptr().as_raw();
 
         let dx_desc = dx.descriptor();
-        let dx_data = dx.data().as_device_ptr().as_raw_mut();
+        let dx_data = dx.data().as_device_ptr().as_ptr();
 
-        let reserve_space_ptr = reserve_space.as_device_ptr().as_raw_mut();
+        let reserve_space_ptr = reserve_space.as_device_ptr().as_ptr();
 
         unsafe {
             sys::cudnnDropoutBackward(
@@ -528,7 +528,7 @@ impl CudnnContext
                 raw,
                 self.raw,
                 dropout,
-                states.as_device_ptr().as_raw_mut() as *mut std::ffi::c_void,
+                states.as_device_ptr().as_ptr() as *mut std::ffi::c_void,
                 states.len(),
                 seed,
             )
@@ -1185,14 +1185,14 @@ impl CudnnContext
         let w_data = w.data().as_device_ptr().as_raw();
         let w_desc = w.descriptor();
 
-        let y_data = y.data().as_device_ptr().as_raw_mut();
+        let y_data = y.data().as_device_ptr().as_ptr();
         let y_desc = y.descriptor();
 
         // If the _ size is 0 then the algorithm can work in-place and cuDNN expects a null
         // pointer.
         let (work_space_ptr, work_space_size): (*mut u8, usize) = {
             work_space.map_or((std::ptr::null_mut(), 0), |work_space| {
-                (work_space.as_device_ptr().as_raw_mut(), work_space.len())
+                (work_space.as_device_ptr().as_mut_ptr(), work_space.len())
             })
         };
 
@@ -1287,12 +1287,12 @@ impl CudnnContext
         let dy_data = dy.data().as_device_ptr().as_raw();
         let dy_desc = dy.descriptor();
 
-        let dx_data = dx.data().as_device_ptr().as_raw_mut();
+        let dx_data = dx.data().as_device_ptr().as_ptr();
         let dx_desc = dx.descriptor();
 
         let (work_space_ptr, work_space_size): (*mut u8, usize) = {
             work_space.map_or((std::ptr::null_mut(), 0), |work_space| {
-                (work_space.as_device_ptr().as_raw_mut(), work_space.len())
+                (work_space.as_device_ptr().as_mut_ptr(), work_space.len())
             })
         };
 
@@ -1388,12 +1388,12 @@ impl CudnnContext
         let dy_data = dy.data().as_device_ptr().as_raw();
         let dy_desc = dy.descriptor();
 
-        let dw_data = dw.data().as_device_ptr().as_raw_mut();
+        let dw_data = dw.data().as_device_ptr().as_ptr();
         let dw_desc = dw.descriptor();
 
         let (work_space_ptr, work_space_size): (*mut u8, usize) = {
             work_space.map_or((std::ptr::null_mut(), 0), |work_space| {
-                (work_space.as_device_ptr().as_raw_mut(), work_space.len())
+                (work_space.as_device_ptr().as_mut_ptr(), work_space.len())
             })
         };
 
@@ -1615,28 +1615,28 @@ impl CudnnContext
         L: RnnDataLayout,
         NCHW: SupportedType<T1>,
     {
-        let device_sequence_lengths_ptr = device_seq_lengths.as_device_ptr().as_raw();
+        let device_sequence_lengths_ptr = device_seq_lengths.as_device_ptr().as_ptr();
 
         let x_ptr = x.as_device_ptr().as_raw();
-        let y_ptr = y.as_device_ptr().as_raw_mut();
+        let y_ptr = y.as_device_ptr().as_ptr();
 
-        let hx_ptr = hx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_raw());
+        let hx_ptr = hx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_ptr());
         let hy_ptr = hy.map_or(std::ptr::null_mut(), |buff| {
-            buff.as_device_ptr().as_raw_mut()
+            buff.as_device_ptr().as_mut_ptr()
        });
 
         let c_desc = c_desc.map_or(std::ptr::null_mut(), |desc| desc.raw);
 
-        let cx_ptr = cx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_raw());
+        let cx_ptr = cx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_ptr());
         let cy_ptr = cy.map_or(std::ptr::null_mut(), |buff| {
-            buff.as_device_ptr().as_raw_mut()
+            buff.as_device_ptr().as_mut_ptr()
         });
 
-        let weight_space_ptr = weight_space.as_device_ptr().as_raw_mut();
-        let work_space_ptr = work_space.as_device_ptr().as_raw_mut();
+        let weight_space_ptr = weight_space.as_device_ptr().as_ptr();
+        let work_space_ptr = work_space.as_device_ptr().as_ptr();
         let (reserve_space_ptr, reserve_space_size) = reserve_space
             .map_or((std::ptr::null_mut(), 0), |buff| {
-                (buff.as_device_ptr().as_raw_mut(), buff.len())
+                (buff.as_device_ptr().as_mut_ptr(), buff.len())
             });
 
         unsafe {
@@ -1814,32 +1814,32 @@ impl CudnnContext
         L: RnnDataLayout,
         NCHW: SupportedType<T1>,
     {
-        let device_sequence_lengths_ptr = device_seq_lengths.as_device_ptr().as_raw();
+        let device_sequence_lengths_ptr = device_seq_lengths.as_device_ptr().as_ptr();
 
         let y_ptr = y.as_device_ptr().as_raw();
         let dy_ptr = dy.as_device_ptr().as_raw();
 
-        let dx_ptr = dx.as_device_ptr().as_raw_mut();
+        let dx_ptr = dx.as_device_ptr().as_ptr();
 
         let h_desc = h_desc.map_or(std::ptr::null_mut(), |desc| desc.raw);
 
-        let hx_ptr = hx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_raw());
-        let dhy_ptr = dhy.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_raw());
+        let hx_ptr = hx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_ptr());
+        let dhy_ptr = dhy.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_ptr());
         let dhx_ptr = dhx.map_or(std::ptr::null_mut(), |buff| {
-            buff.as_device_ptr().as_raw_mut()
+            buff.as_device_ptr().as_mut_ptr()
         });
 
         let c_desc = c_desc.map_or(std::ptr::null_mut(), |desc| desc.raw);
 
-        let cx_ptr = cx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_raw());
-        let dcy_ptr = dcy.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_raw());
+        let cx_ptr = cx.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_ptr());
+        let dcy_ptr = dcy.map_or(std::ptr::null(), |buff| buff.as_device_ptr().as_mut_ptr());
         let dcx_ptr = dcx.map_or(std::ptr::null_mut(), |buff| {
-            buff.as_device_ptr().as_raw_mut()
+            buff.as_device_ptr().as_mut_ptr()
         });
 
-        let weight_space_ptr = weight_space.as_device_ptr().as_raw_mut();
-        let work_space_ptr = work_space.as_device_ptr().as_raw_mut();
-        let reserve_space_ptr = reserve_space.as_device_ptr().as_raw_mut();
+        let weight_space_ptr = weight_space.as_device_ptr().as_ptr();
+        let work_space_ptr = work_space.as_device_ptr().as_ptr();
+        let reserve_space_ptr = reserve_space.as_device_ptr().as_ptr();
 
         unsafe {
             sys::cudnnRNNBackwardData_v8(
@@ -1947,15 +1947,15 @@ impl CudnnContext
         L: RnnDataLayout,
         NCHW: SupportedType<T1>,
     {
-        let device_sequence_lengths_ptr = device_seq_lengths.as_device_ptr().as_raw();
+        let device_sequence_lengths_ptr = device_seq_lengths.as_device_ptr().as_mut_ptr();
 
         let x_ptr = x.as_device_ptr().as_raw();
         let hx_ptr = x.as_device_ptr().as_raw();
         let y_ptr = y.as_device_ptr().as_raw();
 
-        let dweight_space_ptr = dweight_space.as_device_ptr().as_raw_mut();
-        let work_space_ptr = work_space.as_device_ptr().as_raw_mut();
-        let reserve_space_ptr = reserve_space.as_device_ptr().as_raw_mut();
+        let dweight_space_ptr = dweight_space.as_device_ptr().as_mut_ptr();
+        let work_space_ptr = work_space.as_device_ptr().as_mut_ptr();
+        let reserve_space_ptr = reserve_space.as_device_ptr().as_mut_ptr();
 
         unsafe {
             sys::cudnnRNNBackwardWeights_v8(
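A pattern repeated throughout this file is mapping an optional device buffer to either its pointer or null, because cuDNN's C API expects a null pointer wherever an optional tensor is absent. A self-contained sketch of that idiom follows; `DeviceBuf` is a hypothetical stand-in written for illustration, not cust's actual buffer type:

```rust
// Hypothetical stand-in for a device buffer exposing raw pointers.
struct DeviceBuf {
    addr: usize,
}

impl DeviceBuf {
    fn as_ptr(&self) -> *const u8 {
        self.addr as *const u8
    }
    fn as_mut_ptr(&self) -> *mut u8 {
        self.addr as *mut u8
    }
}

fn main() {
    let hx: Option<&DeviceBuf> = None; // absent optional input
    let hy_buf = DeviceBuf { addr: 0x2000 };
    let hy: Option<&DeviceBuf> = Some(&hy_buf); // present optional output

    // Absent buffer -> null pointer, as the C API requires.
    let hx_ptr = hx.map_or(std::ptr::null(), |b| b.as_ptr());
    // Present buffer -> its device pointer.
    let hy_ptr = hy.map_or(std::ptr::null_mut(), |b| b.as_mut_ptr());

    assert!(hx_ptr.is_null());
    assert_eq!(hy_ptr as usize, 0x2000);
    println!("ok");
}
```

`Option::map_or` keeps the null-vs-pointer decision to a single expression per argument, which is why the diff can swap accessor names without touching any surrounding control flow.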