Skip to content

Commit b75aea8

Browse files
committed
Rewrite SIMD implementation to use Iterator
1 parent ea8f0a8 commit b75aea8

File tree

2 files changed

+83
-31
lines changed

2 files changed

+83
-31
lines changed

examples/mandelbrot/src/scalar_par.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use *;
66
/// Complex number
77
#[repr(align(16))]
88
#[derive(Copy, Clone)]
9-
pub struct Complex {
9+
struct Complex {
1010
real: f64,
1111
imag: f64,
1212
}
@@ -26,7 +26,7 @@ impl Complex {
2626
}
2727

2828
/// An iterator yielding the infinite Mandelbrot sequence
29-
pub struct MandelbrotIter {
29+
struct MandelbrotIter {
3030
/// Initial value which generated this sequence
3131
start: Complex,
3232
/// Current iteration value

examples/mandelbrot/src/simd_par.rs

Lines changed: 81 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,96 @@ use packed_simd::*;
55
use rayon::prelude::*;
66
use *;
77

8-
pub type u64s = u64x8;
9-
pub type u32s = u32x8;
10-
pub type f64s = f64x8;
11-
12-
/// This function will operate on N complex numbers at once, where N is the
13-
/// number of lanes in a SIMD vector of doubles.
14-
fn mandelbrot(c_x: f64s, c_y: f64s) -> u32s {
15-
let mut x = c_x;
16-
let mut y = c_y;
8+
type u64s = u64x8;
9+
type u32s = u32x8;
10+
type f64s = f64x8;
11+
type m64s = m64x8;
12+
13+
/// Storage for complex numbers in SIMD format.
14+
/// The real and imaginary parts are kept in separate registers.
15+
#[derive(Copy, Clone)]
16+
struct Complex {
17+
real: f64s,
18+
imag: f64s,
19+
}
1720

18-
let mut count = u64s::splat(0);
21+
impl Complex {
22+
/// Returns a mask describing which members of the Mandelbrot sequence
23+
/// haven't diverged yet
24+
#[inline]
25+
fn undiverged(&self) -> m64s {
26+
let Self { real: x, imag: y } = *self;
1927

20-
for i in 0..ITER_LIMIT {
2128
let xx = x * x;
2229
let yy = y * y;
23-
let xy = x * y;
24-
let new_x = c_x + xx - yy;
25-
let new_y = c_y + xy + xy;
30+
let sum = xx + yy;
2631

27-
let sum = x * x + y * y;
32+
sum.le(f64s::splat(THRESHOLD))
33+
}
34+
}
2835

29-
// Keep track of those lanes which haven't diverged yet. The other ones
30-
// will be masked off.
31-
let undiverged = sum.le(f64s::splat(4.));
36+
/// Mandelbrot sequence iterator using SIMD.
37+
struct MandelbrotIter {
38+
/// Initial value which generated this sequence
39+
start: Complex,
40+
/// Current iteration value
41+
current: Complex,
42+
}
43+
44+
impl MandelbrotIter {
45+
/// Creates a new Mandelbrot sequence iterator for a given starting point
46+
fn new(start: Complex) -> Self {
47+
Self { start, current: start }
48+
}
3249

33-
// Stop the iteration if they all diverged. Note that we don't do this
34-
// check every iteration, since a branch misprediction can hurt more
35-
// than doing some extra calculations.
36-
if i % 5 == 0 && undiverged.none() {
37-
break;
50+
/// Returns the number of iterations it takes for each member of the Mandelbrot
51+
/// sequence to diverge at this point, or `ITER_LIMIT` if they don't diverge.
52+
///
53+
/// This function will operate on N complex numbers at once, where N is the
54+
/// number of lanes in a SIMD vector of doubles.
55+
fn count(mut self) -> u32s {
56+
let mut z = self.start;
57+
let mut count = u64s::splat(0);
58+
for _ in 0..ITER_LIMIT {
59+
// Keep track of those lanes which haven't diverged yet. The other ones
60+
// will be masked off.
61+
let undiverged = z.undiverged();
62+
63+
// Stop the iteration once every lane has diverged. Unlike the previous
64+
// version, which only tested this every fifth iteration to limit branch
65+
// mispredictions, the check now runs on every iteration.
66+
if undiverged.none() {
67+
break;
68+
}
69+
70+
count += undiverged.select(u64s::splat(1), u64s::splat(0));
71+
72+
z = self.next().unwrap();
3873
}
74+
count.cast()
75+
}
76+
}
3977

40-
count += undiverged.select(u64s::splat(1), u64s::splat(0));
78+
impl Iterator for MandelbrotIter {
79+
type Item = Complex;
4180

42-
x = new_x;
43-
y = new_y;
44-
}
81+
/// Generates the next values in the sequence
82+
#[inline]
83+
fn next(&mut self) -> Option<Complex> {
84+
let Complex { real: c_x, imag: c_y } = self.start;
85+
let Complex { real: x, imag: y } = self.current;
4586

46-
count.cast()
87+
let xx = x * x;
88+
let yy = y * y;
89+
let xy = x * y;
90+
91+
let new_x = c_x + (xx - yy);
92+
let new_y = c_y + (xy + xy);
93+
94+
self.current = Complex { real: new_x, imag: new_y };
95+
96+
Some(self.current)
97+
}
4798
}
4899

49100
pub fn generate(dims: Dimensions, xr: Range, yr: Range) -> Vec<u32> {
@@ -89,7 +140,8 @@ pub fn generate(dims: Dimensions, xr: Range, yr: Range) -> Vec<u32> {
89140
let y = f64s::splat(yr.start + dy * (i as f64));
90141
row.iter_mut().enumerate().for_each(|(j, count)| {
91142
let x = xs[j];
92-
*count = mandelbrot(x, y);
143+
let z = Complex { real: x, imag: y };
144+
*count = MandelbrotIter::new(z).count();
93145
});
94146
});
95147

0 commit comments

Comments
 (0)