From fc7e7d8cd6158c127e459ccc55fc9454aa6a5417 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Sun, 28 Feb 2016 18:49:16 +0100 Subject: [PATCH 1/6] add axis_split_at to split an array along an axis --- src/lib.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ tests/array.rs | 10 ++++++++++ 2 files changed, 61 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index f840aaa5f..da11bd72c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1439,6 +1439,57 @@ impl ArrayBase where S: Data, D: Dimension iterators::new_axis_iter_mut(self.view_mut(), axis) } + /// Split the array along `axis` and return one view strictly before the + /// split and one view after the split. + /// + /// **Panics** if `axis` is out of bounds. + pub fn axis_split_at(&self, axis: usize, index: Ix) + -> (ArrayView, ArrayView) + { + assert!(index <= self.shape()[axis]); + let left_ptr = self.ptr; + let right_ptr = if index == self.shape()[axis] { + self.ptr + } + else { + let mut indices = self.dim.clone(); + for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { + if ax != axis { + *ind = 0; + } + else { + *ind = index; + } + } + let offset = self.dim.stride_offset_checked(&self.strides, + &indices).unwrap(); + unsafe { + self.ptr.offset(offset) + } + }; + + let mut dim_left = self.dim.clone(); + dim_left.slice_mut()[axis] = index; + let left = ArrayView { + data: ViewRepr::new(), + ptr: left_ptr, + dim: dim_left, + strides: self.strides.clone() + }; + + let mut dim_right = self.dim.clone(); + dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index; + let right = ArrayView { + data: ViewRepr::new(), + ptr: right_ptr, + dim: dim_right, + strides: self.strides.clone() + }; + + (left, right) + } + + /// Return an iterator that traverses over `axis` by chunks of `size`, /// yielding non-overlapping views along that axis. /// diff --git a/tests/array.rs b/tests/array.rs index 0274dd428..10310099b 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -658,3 +658,13 @@ fn deny_wraparound_reshape() { let five = OwnedArray::::zeros(5); let _five_large = five.into_shape((3, 7, 29, 36760123, 823996703)).unwrap(); } + +#[test] +fn split_at() { + let a = arr2(&[[1., 2.], [3., 4.]]); + + let (c0, c1) = a.axis_split_at(1, 1); + + assert_eq!(c0, arr2(&[[1.], [3.]])); + assert_eq!(c1, arr2(&[[2.], [4.]])); +} From f79135d54e35419d48510705238fc30043918468 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Sun, 28 Feb 2016 19:02:47 +0100 Subject: [PATCH 2/6] make axis_split_at methods take views by value --- src/lib.rs | 140 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 90 insertions(+), 50 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index da11bd72c..58df9e437 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -923,6 +923,51 @@ impl<'a, A, D> ArrayView<'a, A, D> { iterators::new_outer_iter(self) } + + /// Split the array along `axis` and return one view strictly before the + /// split and one view after the split. + /// + /// **Panics** if `axis` is out of bounds. + pub fn axis_split_at(self, axis: usize, index: Ix) + -> (Self, Self) + { + assert!(index <= self.shape()[axis]); + let left_ptr = self.ptr; + let right_ptr = if index == self.shape()[axis] { + self.ptr + } + else { + let mut indices = self.dim.clone(); + for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { + if ax != axis { + *ind = 0; + } + else { + *ind = index; + } + } + let offset = self.dim.stride_offset_checked(&self.strides, + &indices).unwrap(); + unsafe { + self.ptr.offset(offset) + } + }; + + let mut dim_left = self.dim.clone(); + dim_left.slice_mut()[axis] = index; + let left = unsafe { + Self::new_(left_ptr, dim_left, self.strides.clone()) + }; + + let mut dim_right = self.dim.clone(); + dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index; + let right = unsafe { + Self::new_(right_ptr, dim_right, self.strides.clone()) + }; + + (left, right) + } + } impl<'a, A, D> ArrayViewMut<'a, A, D> @@ -1018,6 +1063,51 @@ impl<'a, A, D> ArrayViewMut<'a, A, D> { iterators::new_outer_iter_mut(self) } + + /// Split the array along `axis` and return one mutable view strictly + /// before the split and one mutable view after the split. + /// + /// **Panics** if `axis` is out of bounds. + pub fn axis_split_at(self, axis: usize, index: Ix) + -> (Self, Self) + { + assert!(index <= self.shape()[axis]); + let left_ptr = self.ptr; + let right_ptr = if index == self.shape()[axis] { + self.ptr + } + else { + let mut indices = self.dim.clone(); + for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { + if ax != axis { + *ind = 0; + } + else { + *ind = index; + } + } + let offset = self.dim.stride_offset_checked(&self.strides, + &indices).unwrap(); + unsafe { + self.ptr.offset(offset) + } + }; + + let mut dim_left = self.dim.clone(); + dim_left.slice_mut()[axis] = index; + let left = unsafe { + Self::new_(left_ptr, dim_left, self.strides.clone()) + }; + + let mut dim_right = self.dim.clone(); + dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index; + let right = unsafe { + Self::new_(right_ptr, dim_right, self.strides.clone()) + }; + + (left, right) + } + } impl ArrayBase where S: Data, D: Dimension @@ -1439,56 +1529,6 @@ impl ArrayBase where S: Data, D: Dimension iterators::new_axis_iter_mut(self.view_mut(), axis) } - /// Split the array along `axis` and return one view strictly before the - /// split and one view after the split. - /// - /// **Panics** if `axis` is out of bounds. - pub fn axis_split_at(&self, axis: usize, index: Ix) - -> (ArrayView, ArrayView) - { - assert!(index <= self.shape()[axis]); - let left_ptr = self.ptr; - let right_ptr = if index == self.shape()[axis] { - self.ptr - } - else { - let mut indices = self.dim.clone(); - for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { - if ax != axis { - *ind = 0; - } - else { - *ind = index; - } - } - let offset = self.dim.stride_offset_checked(&self.strides, - &indices).unwrap(); - unsafe { - self.ptr.offset(offset) - } - }; - - let mut dim_left = self.dim.clone(); - dim_left.slice_mut()[axis] = index; - let left = ArrayView { - data: ViewRepr::new(), - ptr: left_ptr, - dim: dim_left, - strides: self.strides.clone() - }; - - let mut dim_right = self.dim.clone(); - dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index; - let right = ArrayView { - data: ViewRepr::new(), - ptr: right_ptr, - dim: dim_right, - strides: self.strides.clone() - }; - - (left, right) - } - /// Return an iterator that traverses over `axis` by chunks of `size`, /// yielding non-overlapping views along that axis. From 4f50238718ffdd48ab592c2bb17212e2b4fe3ab0 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Sun, 28 Feb 2016 19:37:59 +0100 Subject: [PATCH 3/6] add tests for mutable cases and more complex shapes --- tests/array.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/array.rs b/tests/array.rs index 10310099b..fa82f38e2 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -661,10 +661,26 @@ fn deny_wraparound_reshape() { #[test] fn split_at() { - let a = arr2(&[[1., 2.], [3., 4.]]); + let mut a = arr2(&[[1., 2.], [3., 4.]]); + + { + let (c0, c1) = a.view().axis_split_at(1, 1); + + assert_eq!(c0, arr2(&[[1.], [3.]])); + assert_eq!(c1, arr2(&[[2.], [4.]])); + } + + { + let (mut r0, mut r1) = a.view_mut().axis_split_at(0, 1); + r0[[0, 1]] = 5.; + r1[[0, 0]] = 8.; + } + assert_eq!(a, arr2(&[[1., 5.], [8., 4.]])); + - let (c0, c1) = a.axis_split_at(1, 1); + let b = RcArray::linspace(0., 59., 60).reshape((3, 4, 5)); - assert_eq!(c0, arr2(&[[1.], [3.]])); - assert_eq!(c1, arr2(&[[2.], [4.]])); + let (left, right) = b.view().axis_split_at(2, 2); + assert_eq!(left.shape(), [3, 4, 2]); + assert_eq!(right.shape(), [3, 4, 3]); } From e0527a74ff885c80c6c79825fe54ba6ab99ca69c Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Mon, 29 Feb 2016 09:43:48 +0100 Subject: [PATCH 4/6] formatting fix --- src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 58df9e437..444c9c6a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -935,8 +935,7 @@ impl<'a, A, D> ArrayView<'a, A, D> let left_ptr = self.ptr; let right_ptr = if index == self.shape()[axis] { self.ptr - } - else { + } else { let mut indices = self.dim.clone(); for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { if ax != axis { From a188497892f438ede141dac6142a234526267274 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Mon, 29 Feb 2016 13:03:30 +0100 Subject: [PATCH 5/6] simplify split_at pointer offseting --- src/lib.rs | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 444c9c6a3..fc5fd459d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,8 @@ pub use dimension::{ RemoveAxis, }; +use dimension::stride_offset; + pub use dimension::NdIndex; pub use indexes::Indexes; pub use shape_error::ShapeError; @@ -936,17 +938,7 @@ impl<'a, A, D> ArrayView<'a, A, D> let right_ptr = if index == self.shape()[axis] { self.ptr } else { - let mut indices = self.dim.clone(); - for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { - if ax != axis { - *ind = 0; - } - else { - *ind = index; - } - } - let offset = self.dim.stride_offset_checked(&self.strides, - &indices).unwrap(); + let offset = stride_offset(index, self.strides.slice()[axis]); unsafe { self.ptr.offset(offset) } @@ -1076,17 +1068,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D> self.ptr } else { - let mut indices = self.dim.clone(); - for (ax, ind) in indices.slice_mut().iter_mut().enumerate() { - if ax != axis { - *ind = 0; - } - else { - *ind = index; - } - } - let offset = self.dim.stride_offset_checked(&self.strides, - &indices).unwrap(); + let offset = stride_offset(index, self.strides.slice()[axis]); unsafe { self.ptr.offset(offset) } From 53a7dc514cf2e68aaae501bcf9067ed3cb4593fd Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Mon, 29 Feb 2016 13:24:40 +0100 Subject: [PATCH 6/6] add more split_at tests --- tests/array.rs | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/array.rs b/tests/array.rs index fa82f38e2..63f1e9496 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -7,7 +7,7 @@ extern crate ndarray; use ndarray::{RcArray, S, Si, OwnedArray, }; -use ndarray::{arr0, arr1, arr2, +use ndarray::{arr0, arr1, arr2, arr3, aview0, aview1, aview2, @@ -683,4 +683,25 @@ fn split_at() { let (left, right) = b.view().axis_split_at(2, 2); assert_eq!(left.shape(), [3, 4, 2]); assert_eq!(right.shape(), [3, 4, 3]); + assert_eq!(left, arr3(&[[[0., 1.], [5., 6.], [10., 11.], [15., 16.]], + [[20., 21.], [25., 26.], [30., 31.], [35., 36.]], + [[40., 41.], [45., 46.], [50., 51.], [55., 56.]]])); + + // we allow for an empty right view when index == dim[axis] + let (_, right) = b.view().axis_split_at(1, 4); + assert_eq!(right.shape(), [3, 0, 5]); +} + +#[test] +#[should_panic] +fn deny_split_at_axis_out_of_bounds() { + let a = arr2(&[[1., 2.], [3., 4.]]); + a.view().axis_split_at(2, 0); +} + +#[test] +#[should_panic] +fn deny_split_at_index_out_of_bounds() { + let a = arr2(&[[1., 2.], [3., 4.]]); + a.view().axis_split_at(1, 3); }