diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs
index 0f15c4b50..882eaaa77 100644
--- a/src/iterators/mod.rs
+++ b/src/iterators/mod.rs
@@ -825,6 +825,19 @@ impl AxisIterCore {
};
(left, right)
}
+
+ /// Does the same thing as `.next()` but also returns the index of the item
+ /// relative to the start of the axis.
+ fn next_with_index(&mut self) -> Option<(usize, *mut A)> {
+ let index = self.index;
+ self.next().map(|ptr| (index, ptr))
+ }
+
+ /// Does the same thing as `.next_back()` but also returns the index of the
+ /// item relative to the start of the axis.
+ fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> {
+ self.next_back().map(|ptr| (self.end, ptr))
+ }
}
impl Iterator for AxisIterCore
@@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information.
pub struct AxisChunksIter<'a, A, D> {
iter: AxisIterCore,
- n_whole_chunks: usize,
- /// Dimension of the last (and possibly uneven) chunk
- last_dim: D,
+ /// Index of the partial chunk (the chunk smaller than the specified chunk
+ /// size due to the axis length not being evenly divisible). If the axis
+ /// length is evenly divisible by the chunk size, this index is larger than
+ /// the maximum valid index.
+ partial_chunk_index: usize,
+ /// Dimension of the partial chunk.
+ partial_chunk_dim: D,
life: PhantomData<&'a A>,
}
@@ -1193,10 +1210,10 @@ clone_bounds!(
AxisChunksIter['a, A, D] {
@copy {
life,
- n_whole_chunks,
+ partial_chunk_index,
}
iter,
- last_dim,
+ partial_chunk_dim,
}
);
@@ -1233,12 +1250,9 @@ fn chunk_iter_parts(
let mut inner_dim = v.dim.clone();
inner_dim[axis] = size;
- let mut last_dim = v.dim;
- last_dim[axis] = if chunk_remainder == 0 {
- size
- } else {
- chunk_remainder
- };
+ let mut partial_chunk_dim = v.dim;
+ partial_chunk_dim[axis] = chunk_remainder;
+ let partial_chunk_index = n_whole_chunks;
let iter = AxisIterCore {
index: 0,
@@ -1249,16 +1263,16 @@ fn chunk_iter_parts(
ptr: v.ptr,
};
- (iter, n_whole_chunks, last_dim)
+ (iter, partial_chunk_index, partial_chunk_dim)
}
impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> {
pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self {
- let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size);
+ let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size);
AxisChunksIter {
iter,
- n_whole_chunks,
- last_dim,
+ partial_chunk_index,
+ partial_chunk_dim,
life: PhantomData,
}
}
@@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl {
where
D: Dimension,
{
- fn get_subview(
- &self,
- iter_item: Option<*mut A>,
- is_uneven: bool,
- ) -> Option<$array<'a, A, D>> {
- iter_item.map(|ptr| {
- if !is_uneven {
- unsafe {
- $array::new_(
- ptr,
- self.iter.inner_dim.clone(),
- self.iter.inner_strides.clone(),
- )
- }
- } else {
- unsafe {
- $array::new_(
- ptr,
- self.last_dim.clone(),
- self.iter.inner_strides.clone(),
- )
- }
+ fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> {
+ if index != self.partial_chunk_index {
+ unsafe {
+ $array::new_(
+ ptr,
+ self.iter.inner_dim.clone(),
+ self.iter.inner_strides.clone(),
+ )
+ }
+ } else {
+ unsafe {
+ $array::new_(
+ ptr,
+ self.partial_chunk_dim.clone(),
+ self.iter.inner_strides.clone(),
+ )
}
- })
+ }
+ }
+
+ /// Splits the iterator at index, yielding two disjoint iterators.
+ ///
+ /// `index` is relative to the current state of the iterator (which is not
+ /// necessarily the start of the axis).
+ ///
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
+ /// length.
+ pub fn split_at(self, index: usize) -> (Self, Self) {
+ let (left, right) = self.iter.split_at(index);
+ (
+ Self {
+ iter: left,
+ partial_chunk_index: self.partial_chunk_index,
+ partial_chunk_dim: self.partial_chunk_dim.clone(),
+ life: self.life,
+ },
+ Self {
+ iter: right,
+ partial_chunk_index: self.partial_chunk_index,
+ partial_chunk_dim: self.partial_chunk_dim,
+ life: self.life,
+ },
+ )
}
}
@@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl {
type Item = $array<'a, A, D>;
fn next(&mut self) -> Option {
- let res = self.iter.next();
- let is_uneven = self.iter.index > self.n_whole_chunks;
- self.get_subview(res, is_uneven)
+ self.iter
+ .next_with_index()
+ .map(|(index, ptr)| self.get_subview(index, ptr))
}
fn size_hint(&self) -> (usize, Option) {
@@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl {
D: Dimension,
{
fn next_back(&mut self) -> Option {
- let is_uneven = self.iter.end > self.n_whole_chunks;
- let res = self.iter.next_back();
- self.get_subview(res, is_uneven)
+ self.iter
+ .next_back_with_index()
+ .map(|(index, ptr)| self.get_subview(index, ptr))
}
}
@@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl {
/// for more information.
pub struct AxisChunksIterMut<'a, A, D> {
iter: AxisIterCore,
- n_whole_chunks: usize,
- last_dim: D,
+ partial_chunk_index: usize,
+ partial_chunk_dim: D,
life: PhantomData<&'a mut A>,
}
impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> {
pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self {
- let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size);
+ let (iter, partial_chunk_index, partial_chunk_dim) =
+ chunk_iter_parts(v.into_view(), axis, size);
AxisChunksIterMut {
iter,
- n_whole_chunks: len,
- last_dim,
+ partial_chunk_index,
+ partial_chunk_dim,
life: PhantomData,
}
}
diff --git a/tests/iterators.rs b/tests/iterators.rs
index 6408b2f8d..325aa9797 100644
--- a/tests/iterators.rs
+++ b/tests/iterators.rs
@@ -13,6 +13,20 @@ use itertools::assert_equal;
use itertools::{enumerate, rev};
use std::iter::FromIterator;
+macro_rules! assert_panics {
+ ($body:expr) => {
+ if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
+ panic!("assertion failed: should_panic; \
+ non-panicking result: {:?}", v);
+ }
+ };
+ ($body:expr, $($arg:tt)*) => {
+ if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
+ panic!($($arg)*);
+ }
+ };
+}
+
#[test]
fn double_ended() {
let a = ArcArray::linspace(0., 7., 8);
@@ -585,6 +599,33 @@ fn axis_chunks_iter_zero_axis_len() {
assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none());
}
+#[test]
+fn axis_chunks_iter_split_at() {
+ let mut a = Array2::::zeros((11, 3));
+ a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i);
+ for source in &[
+ a.slice(s![..0, ..]),
+ a.slice(s![..1, ..]),
+ a.slice(s![..5, ..]),
+ a.slice(s![..10, ..]),
+ a.slice(s![..11, ..]),
+ a.slice(s![.., ..0]),
+ ] {
+ let chunks_iter = source.axis_chunks_iter(Axis(0), 5);
+ let all_chunks: Vec<_> = chunks_iter.clone().collect();
+ let n_chunks = chunks_iter.len();
+ assert_eq!(n_chunks, all_chunks.len());
+ for index in 0..=n_chunks {
+ let (left, right) = chunks_iter.clone().split_at(index);
+ assert_eq!(&all_chunks[..index], &left.collect::>()[..]);
+ assert_eq!(&all_chunks[index..], &right.collect::>()[..]);
+ }
+ assert_panics!({
+ chunks_iter.split_at(n_chunks + 1);
+ });
+ }
+}
+
#[test]
fn axis_chunks_iter_mut() {
let a = ArcArray::from_iter(0..24);