diff --git a/src/lib.rs b/src/lib.rs index f840aaa5f..fc5fd459d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,8 @@ pub use dimension::{ RemoveAxis, }; +use dimension::stride_offset; + pub use dimension::NdIndex; pub use indexes::Indexes; pub use shape_error::ShapeError; @@ -923,6 +925,40 @@ impl<'a, A, D> ArrayView<'a, A, D> { iterators::new_outer_iter(self) } + + /// Split the array along `axis` and return one view strictly before the + /// split and one view after the split. + /// + /// **Panics** if `axis` is out of bounds. + pub fn axis_split_at(self, axis: usize, index: Ix) + -> (Self, Self) + { + assert!(index <= self.shape()[axis]); + let left_ptr = self.ptr; + let right_ptr = if index == self.shape()[axis] { + self.ptr + } else { + let offset = stride_offset(index, self.strides.slice()[axis]); + unsafe { + self.ptr.offset(offset) + } + }; + + let mut dim_left = self.dim.clone(); + dim_left.slice_mut()[axis] = index; + let left = unsafe { + Self::new_(left_ptr, dim_left, self.strides.clone()) + }; + + let mut dim_right = self.dim.clone(); + dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index; + let right = unsafe { + Self::new_(right_ptr, dim_right, self.strides.clone()) + }; + + (left, right) + } + } impl<'a, A, D> ArrayViewMut<'a, A, D> @@ -1018,6 +1054,41 @@ impl<'a, A, D> ArrayViewMut<'a, A, D> { iterators::new_outer_iter_mut(self) } + + /// Split the array along `axis` and return one mutable view strictly + /// before the split and one mutable view after the split. + /// + /// **Panics** if `axis` is out of bounds. + pub fn axis_split_at(self, axis: usize, index: Ix) + -> (Self, Self) + { + assert!(index <= self.shape()[axis]); + let left_ptr = self.ptr; + let right_ptr = if index == self.shape()[axis] { + self.ptr + } + else { + let offset = stride_offset(index, self.strides.slice()[axis]); + unsafe { + self.ptr.offset(offset) + } + }; + + let mut dim_left = self.dim.clone(); + dim_left.slice_mut()[axis] = index; + let left = unsafe { + Self::new_(left_ptr, dim_left, self.strides.clone()) + }; + + let mut dim_right = self.dim.clone(); + dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index; + let right = unsafe { + Self::new_(right_ptr, dim_right, self.strides.clone()) + }; + + (left, right) + } + } impl ArrayBase where S: Data, D: Dimension @@ -1439,6 +1510,7 @@ impl ArrayBase where S: Data, D: Dimension iterators::new_axis_iter_mut(self.view_mut(), axis) } + /// Return an iterator that traverses over `axis` by chunks of `size`, /// yielding non-overlapping views along that axis. /// diff --git a/tests/array.rs b/tests/array.rs index 0274dd428..63f1e9496 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -7,7 +7,7 @@ extern crate ndarray; use ndarray::{RcArray, S, Si, OwnedArray, }; -use ndarray::{arr0, arr1, arr2, +use ndarray::{arr0, arr1, arr2, arr3, aview0, aview1, aview2, @@ -658,3 +658,50 @@ fn deny_wraparound_reshape() { let five = OwnedArray::::zeros(5); let _five_large = five.into_shape((3, 7, 29, 36760123, 823996703)).unwrap(); } + +#[test] +fn split_at() { + let mut a = arr2(&[[1., 2.], [3., 4.]]); + + { + let (c0, c1) = a.view().axis_split_at(1, 1); + + assert_eq!(c0, arr2(&[[1.], [3.]])); + assert_eq!(c1, arr2(&[[2.], [4.]])); + } + + { + let (mut r0, mut r1) = a.view_mut().axis_split_at(0, 1); + r0[[0, 1]] = 5.; + r1[[0, 0]] = 8.; + } + assert_eq!(a, arr2(&[[1., 5.], [8., 4.]])); + + + let b = RcArray::linspace(0., 59., 60).reshape((3, 4, 5)); + + let (left, right) = b.view().axis_split_at(2, 2); + assert_eq!(left.shape(), [3, 4, 2]); + assert_eq!(right.shape(), [3, 4, 3]); + assert_eq!(left, arr3(&[[[0., 1.], [5., 6.], [10., 11.], [15., 16.]], + [[20., 21.], [25., 26.], [30., 31.], [35., 36.]], + [[40., 41.], [45., 46.], [50., 51.], [55., 56.]]])); + + // we allow for an empty right view when index == dim[axis] + let (_, right) = b.view().axis_split_at(1, 4); + assert_eq!(right.shape(), [3, 0, 5]); +} + +#[test] +#[should_panic] +fn deny_split_at_axis_out_of_bounds() { + let a = arr2(&[[1., 2.], [3., 4.]]); + a.view().axis_split_at(2, 0); +} + +#[test] +#[should_panic] +fn deny_split_at_index_out_of_bounds() { + let a = arr2(&[[1., 2.], [3., 4.]]); + a.view().axis_split_at(1, 3); +}