diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 0f15c4b50..882eaaa77 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -825,6 +825,19 @@ impl AxisIterCore { }; (left, right) } + + /// Does the same thing as `.next()` but also returns the index of the item + /// relative to the start of the axis. + fn next_with_index(&mut self) -> Option<(usize, *mut A)> { + let index = self.index; + self.next().map(|ptr| (index, ptr)) + } + + /// Does the same thing as `.next_back()` but also returns the index of the + /// item relative to the start of the axis. + fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> { + self.next_back().map(|ptr| (self.end, ptr)) + } } impl Iterator for AxisIterCore @@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { /// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information. pub struct AxisChunksIter<'a, A, D> { iter: AxisIterCore, - n_whole_chunks: usize, - /// Dimension of the last (and possibly uneven) chunk - last_dim: D, + /// Index of the partial chunk (the chunk smaller than the specified chunk + /// size due to the axis length not being evenly divisible). If the axis + /// length is evenly divisible by the chunk size, this index is larger than + /// the maximum valid index. + partial_chunk_index: usize, + /// Dimension of the partial chunk. + partial_chunk_dim: D, life: PhantomData<&'a A>, } @@ -1193,10 +1210,10 @@ clone_bounds!( AxisChunksIter['a, A, D] { @copy { life, - n_whole_chunks, + partial_chunk_index, } iter, - last_dim, + partial_chunk_dim, } ); @@ -1233,12 +1250,9 @@ fn chunk_iter_parts( let mut inner_dim = v.dim.clone(); inner_dim[axis] = size; - let mut last_dim = v.dim; - last_dim[axis] = if chunk_remainder == 0 { - size - } else { - chunk_remainder - }; + let mut partial_chunk_dim = v.dim; + partial_chunk_dim[axis] = chunk_remainder; + let partial_chunk_index = n_whole_chunks; let iter = AxisIterCore { index: 0, @@ -1249,16 +1263,16 @@ fn chunk_iter_parts( ptr: v.ptr, }; - (iter, n_whole_chunks, last_dim) + (iter, partial_chunk_index, partial_chunk_dim) } impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> { pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self { - let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size); + let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size); AxisChunksIter { iter, - n_whole_chunks, - last_dim, + partial_chunk_index, + partial_chunk_dim, life: PhantomData, } } @@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl { where D: Dimension, { - fn get_subview( - &self, - iter_item: Option<*mut A>, - is_uneven: bool, - ) -> Option<$array<'a, A, D>> { - iter_item.map(|ptr| { - if !is_uneven { - unsafe { - $array::new_( - ptr, - self.iter.inner_dim.clone(), - self.iter.inner_strides.clone(), - ) - } - } else { - unsafe { - $array::new_( - ptr, - self.last_dim.clone(), - self.iter.inner_strides.clone(), - ) - } + fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> { + if index != self.partial_chunk_index { + unsafe { + $array::new_( + ptr, + self.iter.inner_dim.clone(), + self.iter.inner_strides.clone(), + ) + } + } else { + unsafe { + $array::new_( + ptr, + self.partial_chunk_dim.clone(), + self.iter.inner_strides.clone(), + ) } - }) + } + } + + /// Splits the iterator at index, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. + pub fn split_at(self, index: usize) -> (Self, Self) { + let (left, right) = self.iter.split_at(index); + ( + Self { + iter: left, + partial_chunk_index: self.partial_chunk_index, + partial_chunk_dim: self.partial_chunk_dim.clone(), + life: self.life, + }, + Self { + iter: right, + partial_chunk_index: self.partial_chunk_index, + partial_chunk_dim: self.partial_chunk_dim, + life: self.life, + }, + ) } } @@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl { type Item = $array<'a, A, D>; fn next(&mut self) -> Option { - let res = self.iter.next(); - let is_uneven = self.iter.index > self.n_whole_chunks; - self.get_subview(res, is_uneven) + self.iter + .next_with_index() + .map(|(index, ptr)| self.get_subview(index, ptr)) } fn size_hint(&self) -> (usize, Option) { @@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl { D: Dimension, { fn next_back(&mut self) -> Option { - let is_uneven = self.iter.end > self.n_whole_chunks; - let res = self.iter.next_back(); - self.get_subview(res, is_uneven) + self.iter + .next_back_with_index() + .map(|(index, ptr)| self.get_subview(index, ptr)) } } @@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl { /// for more information. pub struct AxisChunksIterMut<'a, A, D> { iter: AxisIterCore, - n_whole_chunks: usize, - last_dim: D, + partial_chunk_index: usize, + partial_chunk_dim: D, life: PhantomData<&'a mut A>, } impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self { - let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size); + let (iter, partial_chunk_index, partial_chunk_dim) = + chunk_iter_parts(v.into_view(), axis, size); AxisChunksIterMut { iter, - n_whole_chunks: len, - last_dim, + partial_chunk_index, + partial_chunk_dim, life: PhantomData, } } diff --git a/tests/iterators.rs b/tests/iterators.rs index 6408b2f8d..325aa9797 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -13,6 +13,20 @@ use itertools::assert_equal; use itertools::{enumerate, rev}; use std::iter::FromIterator; +macro_rules! assert_panics { + ($body:expr) => { + if let Ok(v) = ::std::panic::catch_unwind(|| $body) { + panic!("assertion failed: should_panic; \ + non-panicking result: {:?}", v); + } + }; + ($body:expr, $($arg:tt)*) => { + if let Ok(_) = ::std::panic::catch_unwind(|| $body) { + panic!($($arg)*); + } + }; +} + #[test] fn double_ended() { let a = ArcArray::linspace(0., 7., 8); @@ -585,6 +599,33 @@ fn axis_chunks_iter_zero_axis_len() { assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none()); } +#[test] +fn axis_chunks_iter_split_at() { + let mut a = Array2::::zeros((11, 3)); + a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i); + for source in &[ + a.slice(s![..0, ..]), + a.slice(s![..1, ..]), + a.slice(s![..5, ..]), + a.slice(s![..10, ..]), + a.slice(s![..11, ..]), + a.slice(s![.., ..0]), + ] { + let chunks_iter = source.axis_chunks_iter(Axis(0), 5); + let all_chunks: Vec<_> = chunks_iter.clone().collect(); + let n_chunks = chunks_iter.len(); + assert_eq!(n_chunks, all_chunks.len()); + for index in 0..=n_chunks { + let (left, right) = chunks_iter.clone().split_at(index); + assert_eq!(&all_chunks[..index], &left.collect::>()[..]); + assert_eq!(&all_chunks[index..], &right.collect::>()[..]); + } + assert_panics!({ + chunks_iter.split_at(n_chunks + 1); + }); + } +} + #[test] fn axis_chunks_iter_mut() { let a = ArcArray::from_iter(0..24);