Skip to content

Commit 08c3ffa

Browse files
committed
Implement into parallel iterator for array, rcarray
1 parent 61ff5da commit 08c3ffa

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/iterators/par.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,50 @@ macro_rules! par_iter_wrapper {
117117
par_iter_wrapper!(AxisIter, [Sync]);
118118
par_iter_wrapper!(AxisIterMut, [Send + Sync]);
119119

120+
impl<'a, A, D> IntoParallelIterator for &'a Array<A, D>
121+
where D: Dimension,
122+
A: Sync
123+
{
124+
type Item = &'a A;
125+
type Iter = Parallel<ArrayView<'a, A, D>>;
126+
fn into_par_iter(self) -> Self::Iter {
127+
self.view().into_par_iter()
128+
}
129+
}
130+
131+
impl<'a, A, D> IntoParallelIterator for &'a RcArray<A, D>
132+
where D: Dimension,
133+
A: Sync
134+
{
135+
type Item = &'a A;
136+
type Iter = Parallel<ArrayView<'a, A, D>>;
137+
fn into_par_iter(self) -> Self::Iter {
138+
self.view().into_par_iter()
139+
}
140+
}
141+
142+
impl<'a, A, D> IntoParallelIterator for &'a mut Array<A, D>
143+
where D: Dimension,
144+
A: Sync + Send
145+
{
146+
type Item = &'a mut A;
147+
type Iter = Parallel<ArrayViewMut<'a, A, D>>;
148+
fn into_par_iter(self) -> Self::Iter {
149+
self.view_mut().into_par_iter()
150+
}
151+
}
152+
153+
impl<'a, A, D> IntoParallelIterator for &'a mut RcArray<A, D>
154+
where D: Dimension,
155+
A: Sync + Send + Clone,
156+
{
157+
type Item = &'a mut A;
158+
type Iter = Parallel<ArrayViewMut<'a, A, D>>;
159+
fn into_par_iter(self) -> Self::Iter {
160+
self.view_mut().into_par_iter()
161+
}
162+
}
163+
120164
macro_rules! par_iter_view_wrapper {
121165
// thread_bounds are either Sync or Send + Sync
122166
($view_name:ident, [$($thread_bounds:tt)*]) => {

tests/rayon.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ fn test_regular_iter() {
3737
for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() {
3838
v.fill(i as _);
3939
}
40-
let s = a.view().into_par_iter().map(|&x| x).sum();
40+
let s = a.par_iter().map(|&x| x).sum();
4141
println!("{:?}", a.slice(s![..10, ..5]));
4242
assert_eq!(s, a.scalar_sum());
4343
}

0 commit comments

Comments
 (0)