Skip to content

split_at for Array #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 29, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ pub use dimension::{
RemoveAxis,
};

use dimension::stride_offset;

pub use dimension::NdIndex;
pub use indexes::Indexes;
pub use shape_error::ShapeError;
Expand Down Expand Up @@ -923,6 +925,40 @@ impl<'a, A, D> ArrayView<'a, A, D>
{
iterators::new_outer_iter(self)
}

/// Split the array along `axis` and return one view strictly before the
/// split and one view after the split.
///
/// **Panics** if `axis` is out of bounds.
pub fn axis_split_at(self, axis: usize, index: Ix)
-> (Self, Self)
{
assert!(index <= self.shape()[axis]);
let left_ptr = self.ptr;
let right_ptr = if index == self.shape()[axis] {
self.ptr
} else {
let offset = stride_offset(index, self.strides.slice()[axis]);
unsafe {
self.ptr.offset(offset)
}
};

let mut dim_left = self.dim.clone();
dim_left.slice_mut()[axis] = index;
let left = unsafe {
Self::new_(left_ptr, dim_left, self.strides.clone())
};

let mut dim_right = self.dim.clone();
dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index;
let right = unsafe {
Self::new_(right_ptr, dim_right, self.strides.clone())
};

(left, right)
}

}

impl<'a, A, D> ArrayViewMut<'a, A, D>
Expand Down Expand Up @@ -1018,6 +1054,41 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
{
iterators::new_outer_iter_mut(self)
}

/// Split the array along `axis` and return one mutable view strictly
/// before the split and one mutable view after the split.
///
/// **Panics** if `axis` is out of bounds.
pub fn axis_split_at(self, axis: usize, index: Ix)
-> (Self, Self)
{
assert!(index <= self.shape()[axis]);
let left_ptr = self.ptr;
let right_ptr = if index == self.shape()[axis] {
self.ptr
}
else {
let offset = stride_offset(index, self.strides.slice()[axis]);
unsafe {
self.ptr.offset(offset)
}
};

let mut dim_left = self.dim.clone();
dim_left.slice_mut()[axis] = index;
let left = unsafe {
Self::new_(left_ptr, dim_left, self.strides.clone())
};

let mut dim_right = self.dim.clone();
dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index;
let right = unsafe {
Self::new_(right_ptr, dim_right, self.strides.clone())
};

(left, right)
}

}

impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
Expand Down Expand Up @@ -1439,6 +1510,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
iterators::new_axis_iter_mut(self.view_mut(), axis)
}


/// Return an iterator that traverses over `axis` by chunks of `size`,
/// yielding non-overlapping views along that axis.
///
Expand Down
49 changes: 48 additions & 1 deletion tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ extern crate ndarray;
use ndarray::{RcArray, S, Si,
OwnedArray,
};
use ndarray::{arr0, arr1, arr2,
use ndarray::{arr0, arr1, arr2, arr3,
aview0,
aview1,
aview2,
Expand Down Expand Up @@ -658,3 +658,50 @@ fn deny_wraparound_reshape() {
let five = OwnedArray::<f32, _>::zeros(5);
let _five_large = five.into_shape((3, 7, 29, 36760123, 823996703)).unwrap();
}

#[test]
fn split_at() {
let mut a = arr2(&[[1., 2.], [3., 4.]]);

{
let (c0, c1) = a.view().axis_split_at(1, 1);

assert_eq!(c0, arr2(&[[1.], [3.]]));
assert_eq!(c1, arr2(&[[2.], [4.]]));
}

{
let (mut r0, mut r1) = a.view_mut().axis_split_at(0, 1);
r0[[0, 1]] = 5.;
r1[[0, 0]] = 8.;
}
assert_eq!(a, arr2(&[[1., 5.], [8., 4.]]));


let b = RcArray::linspace(0., 59., 60).reshape((3, 4, 5));

let (left, right) = b.view().axis_split_at(2, 2);
assert_eq!(left.shape(), [3, 4, 2]);
assert_eq!(right.shape(), [3, 4, 3]);
assert_eq!(left, arr3(&[[[0., 1.], [5., 6.], [10., 11.], [15., 16.]],
[[20., 21.], [25., 26.], [30., 31.], [35., 36.]],
[[40., 41.], [45., 46.], [50., 51.], [55., 56.]]]));

// we allow for an empty right view when index == dim[axis]
let (_, right) = b.view().axis_split_at(1, 4);
assert_eq!(right.shape(), [3, 0, 5]);
}

#[test]
#[should_panic]
fn deny_split_at_axis_out_of_bounds() {
let a = arr2(&[[1., 2.], [3., 4.]]);
a.view().axis_split_at(2, 0);
}

#[test]
#[should_panic]
fn deny_split_at_index_out_of_bounds() {
let a = arr2(&[[1., 2.], [3., 4.]]);
a.view().axis_split_at(1, 3);
}