Skip to content

Commit 7eb8063

Browse files
authored
Merge pull request #17 from rust-dd/feat/cuda-support
feat: support generating fgn on cuda
2 parents 5b36c87 + 93130b4 commit 7eb8063

File tree

14 files changed

+256
-34
lines changed

14 files changed

+256
-34
lines changed

.devcontainer/devcontainer.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"image": "rapidsai/devcontainers:23.12-cpp-llvm16-rust-cuda12.3-ubuntu22.04",
3+
"hostRequirements": { "gpu": true },
4+
"workspaceFolder": "/home/coder/${localWorkspaceFolderBasename}",
5+
"workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/${localWorkspaceFolderBasename},type=bind"
6+
}

.zed/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"rust-analyzer": {
88
"initialization_options": {
99
"cargo": {
10-
"features": [""]
10+
"features": ["cuda"]
1111
}
1212
}
1313
}

Cargo.toml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ gauss-quad = "0.2.1"
2525
impl-new-derive = "0.1.2"
2626
implied-vol = "1.0.0"
2727
indicatif = "0.17.8"
28-
itransformer = "1.0.1"
28+
# itransformer = "1.0.1"
2929
kendalls = "0.2.2"
3030
levenberg-marquardt = "0.14.0"
3131
linreg = "0.2.0"
@@ -67,12 +67,23 @@ yahoo_finance_api = { version = "2.3.0", optional = true }
6767
[dev-dependencies]
6868

6969
[features]
70-
default = ["jemalloc"]
70+
cuda = ["dep:cudarc", "dep:libloading"]
71+
default = ["cuda"]
7172
jemalloc = ["dep:tikv-jemallocator"]
7273
malliavin = []
7374
mimalloc = ["dep:mimalloc"]
7475
yahoo = ["dep:time", "dep:yahoo_finance_api"]
7576

77+
[target.'cfg(target_os = "macos")'.dependencies]
78+
ndarray-linalg = { version = "0.17.0", features = ["openblas-static"] }
79+
80+
[target.'cfg(not(target_os = "macos"))'.dependencies]
81+
cudarc = { version = "0.13.9", optional = true, features = [
82+
"cuda-12080",
83+
"cuda-version-from-build-system",
84+
] }
85+
libloading = { version = "0.8.6", optional = true }
86+
7687
[lib]
7788
name = "stochastic_rs"
7889
crate-type = ["cdylib", "lib"]

src/ai.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use candle_core::Tensor;
2-
pub use itransformer::ITransformer;
2+
// pub use itransformer::ITransformer;
33

44
pub mod fou;
55
pub mod utils;

src/main.rs

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
use ndarray::Array1;
12
use prettytable::{format, row, Cell, Row, Table};
3+
use stochastic_rs::plot_1d;
4+
use stochastic_rs::stochastic::noise::fgn::FGN;
5+
use stochastic_rs::stochastic::Sampling;
26
use std::error::Error;
37
use std::fs::File;
48
use std::io::{BufRead, BufReader};
@@ -15,36 +19,25 @@ use stochastic_rs::stats::fou_estimator::{
1519
// use your_crate::{FOUParameterEstimationV1, FOUParameterEstimationV2, FilterType};
1620
const N: usize = 10000;
1721
fn main() -> Result<(), Box<dyn Error>> {
18-
for _ in 0..10 {
19-
let mut table = Table::new();
20-
21-
table.add_row(Row::new(vec![
22-
Cell::new("Test Case"),
23-
Cell::new("Elapsed Time (ms)"),
24-
]));
25-
26-
let start = Instant::now();
27-
let fbm = FGN::new(0.7, N, Some(1.0), None);
28-
let _ = fbm.sample();
29-
let duration = start.elapsed();
30-
table.add_row(Row::new(vec![
31-
Cell::new("Single Sample"),
32-
Cell::new(&format!("{:.2?}", duration.as_millis())),
33-
]));
22+
let fbm = FGN::new(0.7, 10_000, Some(1.0), Some(10000));
23+
let fgn = fbm.sample_cuda().unwrap();
24+
let fgn = fgn.row(0);
25+
plot_1d!(fgn, "Fractional Brownian Motion (H = 0.7)");
26+
let mut path = Array1::<f64>::zeros(500);
27+
for i in 1..500 {
28+
path[i] += path[i-1] + fgn[i];
29+
}
30+
plot_1d!(path, "Fractional Brownian Motion (H = 0.7)");
3431

35-
let start = Instant::now();
36-
let fbm = FGN::new(0.7, N, Some(1.0), None);
37-
for _ in 0..N {
38-
let _ = fbm.sample();
39-
}
40-
let duration = start.elapsed();
41-
table.add_row(Row::new(vec![
42-
Cell::new("Repeated Samples"),
43-
Cell::new(&format!("{:.2?}", duration.as_millis())),
44-
]));
32+
let start = std::time::Instant::now();
33+
let _ = fbm.sample_cuda();
34+
let end = start.elapsed().as_millis();
35+
println!("20000 fgn generated on cuda in: {end}");
4536

46-
table.printstd();
47-
}
37+
let start = std::time::Instant::now();
38+
let _ = fbm.sample_par();
39+
let end = start.elapsed().as_millis();
40+
println!("20000 fgn generated on cuda in: {end}");
4841
// File paths
4942
// let paths = vec![
5043
// "./test/kecskekut_original.txt",

src/stochastic.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod noise;
2626
pub mod process;
2727
pub mod volatility;
2828

29+
use std::error::Error;
2930
use std::sync::{Arc, Mutex};
3031

3132
use ndarray::parallel::prelude::*;
@@ -42,6 +43,13 @@ pub trait Sampling<T: Clone + Send + Sync + Zero>: Send + Sync {
4243
/// Sample the process
4344
fn sample(&self) -> Array1<T>;
4445

46+
/// Sample the process with CUDA support
47+
#[cfg(not(target_os = "macos"))]
48+
#[cfg(feature = "cuda")]
49+
fn sample_cuda(&self) -> Result<Array2<T>, Box<dyn Error>> {
50+
unimplemented!()
51+
}
52+
4553
/// Parallel sampling
4654
fn sample_par(&self) -> Array2<T> {
4755
if self.m().is_none() {

src/stochastic/cuda/fgn.cu

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#include <stdio.h>
2+
#include <cuda_runtime.h>
3+
#include <curand_kernel.h>
4+
#include <cufft.h>
5+
#include <cuComplex.h>
6+
#include <math.h>
7+
8+
#ifdef _WIN32
9+
#define EXPORT __declspec(dllexport)
10+
#else
11+
#define EXPORT
12+
#endif
13+
14+
__global__ void fill_random_with_eigs(
15+
cuComplex* d_data,
16+
const cuComplex* d_sqrt_eigs,
17+
int traj_size,
18+
int m,
19+
unsigned long seed)
20+
{
21+
int tid = blockIdx.x * blockDim.x + threadIdx.x;
22+
if (tid >= m * traj_size) return;
23+
int traj_id = tid / traj_size;
24+
int idx = tid % traj_size;
25+
curandState state;
26+
curand_init(seed + traj_id, idx, 0, &state);
27+
float re = curand_normal(&state);
28+
float im = curand_normal(&state);
29+
cuComplex noise = make_cuComplex(re, im);
30+
d_data[tid] = cuCmulf(noise, d_sqrt_eigs[idx]);
31+
}
32+
33+
__global__ void scale_and_copy_to_output(
34+
const cuComplex* d_data,
35+
float* d_output,
36+
int n,
37+
int m,
38+
int offset,
39+
float hurst,
40+
float t)
41+
{
42+
int out_size = n - offset;
43+
int tid = blockIdx.x * blockDim.x + threadIdx.x;
44+
if (tid >= m * out_size) return;
45+
int traj_id = tid / out_size;
46+
int idx = tid % out_size;
47+
int data_idx = traj_id * (2 * n) + (idx + 1);
48+
float scale = powf((float)n, -hurst) * powf(t, hurst);
49+
d_output[tid] = d_data[data_idx].x * scale;
50+
}
51+
52+
extern "C" EXPORT void fgn_kernel(
53+
const cuComplex* d_sqrt_eigs,
54+
float* d_output,
55+
int n,
56+
int m,
57+
int offset,
58+
float hurst,
59+
float t,
60+
unsigned long seed)
61+
{
62+
int traj_size = 2 * n;
63+
cuComplex* d_data = nullptr;
64+
cudaMalloc(&d_data, (size_t)m * traj_size * sizeof(cuComplex));
65+
{
66+
int totalThreads = m * traj_size;
67+
int blockSize = 512;
68+
int gridSize = (totalThreads + blockSize - 1) / blockSize;
69+
fill_random_with_eigs<<<gridSize, blockSize>>>(d_data, d_sqrt_eigs, traj_size, m, seed);
70+
cudaDeviceSynchronize();
71+
}
72+
{
73+
cufftHandle plan;
74+
cufftPlan1d(&plan, traj_size, CUFFT_C2C, m);
75+
cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD);
76+
cudaDeviceSynchronize();
77+
cufftDestroy(plan);
78+
}
79+
{
80+
int out_size = n - offset;
81+
int totalThreads = m * out_size;
82+
int blockSize = 512;
83+
int gridSize = (totalThreads + blockSize - 1) / blockSize;
84+
scale_and_copy_to_output<<<gridSize, blockSize>>>(d_data, d_output, n, m, offset, hurst, t);
85+
cudaDeviceSynchronize();
86+
}
87+
cudaFree(d_data);
88+
}
827 Bytes
Binary file not shown.
1.84 KB
Binary file not shown.
837 KB
Binary file not shown.

0 commit comments

Comments
 (0)