From f8617321a20261490d38b6e5cf0136164108f608 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 28 Jun 2017 21:21:10 +0900 Subject: [PATCH 1/6] Rename shape converter --- src/layout.rs | 6 +++--- src/triangular.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/layout.rs b/src/layout.rs index e4a57725..9f29409b 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -126,7 +126,7 @@ where } } -pub fn into_col_vec(a: ArrayBase) -> ArrayBase +pub fn into_col(a: ArrayBase) -> ArrayBase where S: Data, { @@ -134,7 +134,7 @@ where a.into_shape((n, 1)).unwrap() } -pub fn into_row_vec(a: ArrayBase) -> ArrayBase +pub fn into_row(a: ArrayBase) -> ArrayBase where S: Data, { @@ -142,7 +142,7 @@ where a.into_shape((1, n)).unwrap() } -pub fn into_vec(a: ArrayBase) -> ArrayBase +pub fn flatten(a: ArrayBase) -> ArrayBase where S: Data, { diff --git a/src/triangular.rs b/src/triangular.rs index d3626b93..523c73c3 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -73,9 +73,9 @@ where type Output = ArrayBase; fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: ArrayBase) -> Result { - let b = into_col_vec(b); + let b = into_col(b); let b = self.solve_triangular(uplo, diag, b)?; - Ok(into_vec(b)) + Ok(flatten(b)) } } From 1a283c66534a526f65f208175dd4d461b2595de1 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 28 Jun 2017 21:26:06 +0900 Subject: [PATCH 2/6] Split converter into convert sub module --- src/cholesky.rs | 1 + src/convert.rs | 77 +++++++++++++++++++++++++++++++++++++++++++++++ src/eigh.rs | 1 + src/generate.rs | 1 + src/layout.rs | 75 +-------------------------------------------- src/lib.rs | 1 + src/qr.rs | 1 + src/solve.rs | 5 ++- src/svd.rs | 3 +- src/triangular.rs | 1 + 10 files changed, 90 insertions(+), 76 deletions(-) create mode 100644 src/convert.rs diff --git a/src/cholesky.rs b/src/cholesky.rs index 93f79bd6..5694ddfc 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -3,6 +3,7 @@ use ndarray::*; use num_traits::Zero; +use super::convert::*; use super::error::*; use super::layout::*; use super::triangular::IntoTriangular; diff --git a/src/convert.rs b/src/convert.rs new file mode 100644 index 00000000..a560f1fc --- /dev/null +++ b/src/convert.rs @@ -0,0 +1,77 @@ +use ndarray::*; + +use super::error::*; +use super::layout::*; + +pub fn into_col(a: ArrayBase) -> ArrayBase +where + S: Data, +{ + let n = a.len(); + a.into_shape((n, 1)).unwrap() +} + +pub fn into_row(a: ArrayBase) -> ArrayBase +where + S: Data, +{ + let n = a.len(); + a.into_shape((1, n)).unwrap() +} + +pub fn flatten(a: ArrayBase) -> ArrayBase +where + S: Data, +{ + let n = a.len(); + a.into_shape((n)).unwrap() +} + +pub fn reconstruct(l: Layout, a: Vec) -> Result> +where + S: DataOwned, +{ + Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?) +} + +pub fn uninitialized(l: Layout) -> ArrayBase +where + A: Copy, + S: DataOwned, +{ + unsafe { ArrayBase::uninitialized(l.as_shape()) } +} + +pub fn replicate(a: &ArrayBase) -> ArrayBase +where + A: Copy, + Sv: Data, + So: DataOwned + DataMut, + D: Dimension, +{ + let mut b = unsafe { ArrayBase::uninitialized(a.dim()) }; + b.assign(a); + b +} + +pub fn clone_with_layout(l: Layout, a: &ArrayBase) -> ArrayBase +where + A: Copy, + Si: Data, + So: DataOwned + DataMut, +{ + let mut b = uninitialized(l); + b.assign(a); + b +} + +pub fn data_transpose(a: &mut ArrayBase) -> Result<&mut ArrayBase> +where + A: Copy, + S: DataOwned + DataMut, +{ + let l = a.layout()?.toggle_order(); + let new = clone_with_layout(l, a); + ::std::mem::replace(a, new); + Ok(a) +} diff --git a/src/eigh.rs b/src/eigh.rs index 4a768138..8229385e 100644 --- a/src/eigh.rs +++ b/src/eigh.rs @@ -2,6 +2,7 @@ use ndarray::*; +use super::convert::*; use super::error::*; use super::layout::*; diff --git a/src/generate.rs b/src/generate.rs index a83ef1bf..73258462 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -4,6 +4,7 @@ use ndarray::*; use rand::*; use std::ops::*; +use super::convert::*; use super::error::*; use super::layout::*; use super::types::*; diff --git a/src/layout.rs b/src/layout.rs index 9f29409b..b0f4aa1e 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -100,7 +100,7 @@ where Err(StrideError::new(strides[0], strides[1]).into()) } - fn square_layout(&self) -> Result { + fn square_layout(&self) -> Result { let l = self.layout()?; let (n, m) = l.size(); if n == m { @@ -125,76 +125,3 @@ where )?) } } - -pub fn into_col(a: ArrayBase) -> ArrayBase -where - S: Data, -{ - let n = a.len(); - a.into_shape((n, 1)).unwrap() -} - -pub fn into_row(a: ArrayBase) -> ArrayBase -where - S: Data, -{ - let n = a.len(); - a.into_shape((1, n)).unwrap() -} - -pub fn flatten(a: ArrayBase) -> ArrayBase -where - S: Data, -{ - let n = a.len(); - a.into_shape((n)).unwrap() -} - -pub fn reconstruct(l: Layout, a: Vec) -> Result> -where - S: DataOwned, -{ - Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?) -} - -pub fn uninitialized(l: Layout) -> ArrayBase -where - A: Copy, - S: DataOwned, -{ - unsafe { ArrayBase::uninitialized(l.as_shape()) } -} - -pub fn replicate(a: &ArrayBase) -> ArrayBase -where - A: Copy, - Sv: Data, - So: DataOwned + DataMut, - D: Dimension, -{ - let mut b = unsafe { ArrayBase::uninitialized(a.dim()) }; - b.assign(a); - b -} - -pub fn clone_with_layout(l: Layout, a: &ArrayBase) -> ArrayBase -where - A: Copy, - Si: Data, - So: DataOwned + DataMut, -{ - let mut b = uninitialized(l); - b.assign(a); - b -} - -pub fn data_transpose(a: &mut ArrayBase) -> Result<&mut ArrayBase> -where - A: Copy, - S: DataOwned + DataMut, -{ - let l = a.layout()?.toggle_order(); - let new = clone_with_layout(l, a); - ::std::mem::replace(a, new); - Ok(a) -} diff --git a/src/lib.rs b/src/lib.rs index 135de75c..7c48ec4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,6 +67,7 @@ pub mod solve; pub mod svd; pub mod triangular; +pub mod convert; pub mod generate; pub mod assert; pub mod norm; diff --git a/src/qr.rs b/src/qr.rs index 7c66907c..9a526366 100644 --- a/src/qr.rs +++ b/src/qr.rs @@ -3,6 +3,7 @@ use ndarray::*; use num_traits::Zero; +use super::convert::*; use super::error::*; use super::layout::*; diff --git a/src/solve.rs b/src/solve.rs index 985c58fd..5a7dc0b2 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -1,9 +1,12 @@ //! Solve linear problems + +use ndarray::*; + +use super::convert::*; use super::error::*; use super::lapack_traits::*; use super::layout::*; -use ndarray::*; pub use lapack_traits::{Pivot, Transpose}; diff --git a/src/svd.rs b/src/svd.rs index cecc8644..d1170bb6 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -2,9 +2,10 @@ use ndarray::*; +use super::convert::*; use super::error::*; +use super::lapack_traits::LapackScalar; use super::layout::*; -use lapack_traits::LapackScalar; pub trait SVD { fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option, S, Option)>; diff --git a/src/triangular.rs b/src/triangular.rs index 523c73c3..9244a495 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -3,6 +3,7 @@ use ndarray::*; use num_traits::Zero; +use super::convert::*; use super::error::*; use super::lapack_traits::*; use super::layout::*; From 7c886853d8adf88f315e8b5685a0dec1213b1fa4 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 28 Jun 2017 21:30:09 +0900 Subject: [PATCH 3/6] Rename Layout into MatrixLayout to avoid name collision --- src/convert.rs | 6 ++--- src/generate.rs | 1 - src/lapack_traits/cholesky.rs | 6 ++--- src/lapack_traits/eigh.rs | 6 ++--- src/lapack_traits/opnorm.rs | 10 +++---- src/lapack_traits/qr.rs | 14 +++++----- src/lapack_traits/solve.rs | 14 +++++----- src/lapack_traits/svd.rs | 6 ++--- src/lapack_traits/triangular.rs | 10 +++---- src/layout.rs | 46 ++++++++++++++++----------------- tests/layout.rs | 10 +++---- 11 files changed, 64 insertions(+), 65 deletions(-) diff --git a/src/convert.rs b/src/convert.rs index a560f1fc..dcf1f64b 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -27,14 +27,14 @@ where a.into_shape((n)).unwrap() } -pub fn reconstruct(l: Layout, a: Vec) -> Result> +pub fn reconstruct(l: MatrixLayout, a: Vec) -> Result> where S: DataOwned, { Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?) } -pub fn uninitialized(l: Layout) -> ArrayBase +pub fn uninitialized(l: MatrixLayout) -> ArrayBase where A: Copy, S: DataOwned, @@ -54,7 +54,7 @@ where b } -pub fn clone_with_layout(l: Layout, a: &ArrayBase) -> ArrayBase +pub fn clone_with_layout(l: MatrixLayout, a: &ArrayBase) -> ArrayBase where A: Copy, Si: Data, diff --git a/src/generate.rs b/src/generate.rs index 73258462..dc04c8e7 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -6,7 +6,6 @@ use std::ops::*; use super::convert::*; use super::error::*; -use super::layout::*; use super::types::*; /// Hermite conjugate matrix diff --git a/src/lapack_traits/cholesky.rs b/src/lapack_traits/cholesky.rs index c4e539c3..94767d97 100644 --- a/src/lapack_traits/cholesky.rs +++ b/src/lapack_traits/cholesky.rs @@ -3,19 +3,19 @@ use lapack::c; use error::*; -use layout::Layout; +use layout::MatrixLayout; use types::*; use super::{UPLO, into_result}; pub trait Cholesky_: Sized { - fn cholesky(Layout, UPLO, a: &mut [Self]) -> Result<()>; + fn cholesky(MatrixLayout, UPLO, a: &mut [Self]) -> Result<()>; } macro_rules! impl_cholesky { ($scalar:ty, $potrf:path) => { impl Cholesky_ for $scalar { - fn cholesky(l: Layout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> { + fn cholesky(l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let info = $potrf(l.lapacke_layout(), uplo as u8, n, &mut a, n); into_result(info, ()) diff --git a/src/lapack_traits/eigh.rs b/src/lapack_traits/eigh.rs index a5e4f8c6..490660f5 100644 --- a/src/lapack_traits/eigh.rs +++ b/src/lapack_traits/eigh.rs @@ -4,20 +4,20 @@ use lapack::c; use num_traits::Zero; use error::*; -use layout::Layout; +use layout::MatrixLayout; use types::*; use super::{UPLO, into_result}; /// Wraps `*syev` for real and `*heev` for complex pub trait Eigh_: AssociatedReal { - fn eigh(calc_eigenvec: bool, Layout, UPLO, a: &mut [Self]) -> Result>; + fn eigh(calc_eigenvec: bool, MatrixLayout, UPLO, a: &mut [Self]) -> Result>; } macro_rules! impl_eigh { ($scalar:ty, $ev:path) => { impl Eigh_ for $scalar { - fn eigh(calc_v: bool, l: Layout, uplo: UPLO, mut a: &mut [Self]) -> Result> { + fn eigh(calc_v: bool, l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result> { let (n, _) = l.size(); let jobz = if calc_v { b'V' } else { b'N' }; let mut w = vec![Self::Real::zero(); n as usize]; diff --git a/src/lapack_traits/opnorm.rs b/src/lapack_traits/opnorm.rs index 294c5f8b..50c5cde1 100644 --- a/src/lapack_traits/opnorm.rs +++ b/src/lapack_traits/opnorm.rs @@ -3,7 +3,7 @@ use lapack::c; use lapack::c::Layout::ColumnMajor as cm; -use layout::Layout; +use layout::MatrixLayout; use types::*; #[repr(u8)] @@ -24,16 +24,16 @@ impl NormType { } pub trait OperatorNorm_: AssociatedReal { - fn opnorm(NormType, Layout, &[Self]) -> Self::Real; + fn opnorm(NormType, MatrixLayout, &[Self]) -> Self::Real; } macro_rules! impl_opnorm { ($scalar:ty, $lange:path) => { impl OperatorNorm_ for $scalar { - fn opnorm(t: NormType, l: Layout, a: &[Self]) -> Self::Real { + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { match l { - Layout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda), - Layout::C((row, lda)) => $lange(cm, t.transpose() as u8, lda, row, a, lda), + MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda), + MatrixLayout::C((row, lda)) => $lange(cm, t.transpose() as u8, lda, row, a, lda), } } } diff --git a/src/lapack_traits/qr.rs b/src/lapack_traits/qr.rs index 481533ca..1f06eb1d 100644 --- a/src/lapack_traits/qr.rs +++ b/src/lapack_traits/qr.rs @@ -5,22 +5,22 @@ use num_traits::Zero; use std::cmp::min; use error::*; -use layout::Layout; +use layout::MatrixLayout; use types::*; use super::into_result; /// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers) pub trait QR_: Sized { - fn householder(Layout, a: &mut [Self]) -> Result>; - fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>; - fn qr(Layout, a: &mut [Self]) -> Result>; + fn householder(MatrixLayout, a: &mut [Self]) -> Result>; + fn q(MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; + fn qr(MatrixLayout, a: &mut [Self]) -> Result>; } macro_rules! impl_qr { ($scalar:ty, $qrf:path, $gqr:path) => { impl QR_ for $scalar { - fn householder(l: Layout, mut a: &mut [Self]) -> Result> { + fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { let (row, col) = l.size(); let k = min(row, col); let mut tau = vec![Self::zero(); k as usize]; @@ -28,14 +28,14 @@ impl QR_ for $scalar { into_result(info, tau) } - fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { let (row, col) = l.size(); let k = min(row, col); let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau); into_result(info, ()) } - fn qr(l: Layout, mut a: &mut [Self]) -> Result> { + fn qr(l: MatrixLayout, mut a: &mut [Self]) -> Result> { let tau = Self::householder(l, a)?; let r = Vec::from(&*a); Self::q(l, a, &tau)?; diff --git a/src/lapack_traits/solve.rs b/src/lapack_traits/solve.rs index ac6b4032..4bda4374 100644 --- a/src/lapack_traits/solve.rs +++ b/src/lapack_traits/solve.rs @@ -3,7 +3,7 @@ use lapack::c; use error::*; -use layout::Layout; +use layout::MatrixLayout; use types::*; use super::{Transpose, into_result}; @@ -12,16 +12,16 @@ pub type Pivot = Vec; /// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Sized { - fn lu(Layout, a: &mut [Self]) -> Result; - fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; - fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; + fn lu(MatrixLayout, a: &mut [Self]) -> Result; + fn inv(MatrixLayout, a: &mut [Self], &Pivot) -> Result<()>; + fn solve(MatrixLayout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solve { ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar { - fn lu(l: Layout, a: &mut [Self]) -> Result { + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; @@ -29,13 +29,13 @@ impl Solve_ for $scalar { into_result(info, ipiv) } - fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv); into_result(info, ()) } - fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; diff --git a/src/lapack_traits/svd.rs b/src/lapack_traits/svd.rs index 2b488b4b..75233910 100644 --- a/src/lapack_traits/svd.rs +++ b/src/lapack_traits/svd.rs @@ -4,7 +4,7 @@ use lapack::c; use num_traits::Zero; use error::*; -use layout::Layout; +use layout::MatrixLayout; use types::*; use super::into_result; @@ -29,14 +29,14 @@ pub struct SVDOutput { /// Wraps `*gesvd` pub trait SVD_: AssociatedReal { - fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; + fn svd(MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; } macro_rules! impl_svd { ($scalar:ty, $gesvd:path) => { impl SVD_ for $scalar { - fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { let (m, n) = l.size(); let k = ::std::cmp::min(n, m); let lda = l.lda(); diff --git a/src/lapack_traits/triangular.rs b/src/lapack_traits/triangular.rs index 5f5e3e19..e86622e0 100644 --- a/src/lapack_traits/triangular.rs +++ b/src/lapack_traits/triangular.rs @@ -4,7 +4,7 @@ use lapack::c; use super::{Transpose, UPLO, into_result}; use error::*; -use layout::Layout; +use layout::MatrixLayout; use types::*; #[derive(Debug, Clone, Copy)] @@ -16,22 +16,22 @@ pub enum Diag { /// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Sized { - fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>; - fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>; + fn inv_triangular(l: MatrixLayout, UPLO, Diag, a: &mut [Self]) -> Result<()>; + fn solve_triangular(al: MatrixLayout, bl: MatrixLayout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_triangular { ($scalar:ty, $trtri:path, $trtrs:path) => { impl Triangular_ for $scalar { - fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> { + fn inv_triangular(l: MatrixLayout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let lda = l.lda(); let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); into_result(info, ()) } - fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> { + fn solve_triangular(al: MatrixLayout, bl: MatrixLayout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> { let (n, _) = al.size(); let lda = al.lda(); let (_, nrhs) = bl.size(); diff --git a/src/layout.rs b/src/layout.rs index b0f4aa1e..5063be13 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -11,70 +11,70 @@ pub type Col = i32; pub type Row = i32; #[derive(Debug, Clone, Copy, PartialEq)] -pub enum Layout { +pub enum MatrixLayout { C((Row, LDA)), F((Col, LDA)), } -impl Layout { +impl MatrixLayout { pub fn size(&self) -> (Row, Col) { match *self { - Layout::C((row, lda)) => (row, lda), - Layout::F((col, lda)) => (lda, col), + MatrixLayout::C((row, lda)) => (row, lda), + MatrixLayout::F((col, lda)) => (lda, col), } } - pub fn resized(&self, row: Row, col: Col) -> Layout { + pub fn resized(&self, row: Row, col: Col) -> MatrixLayout { match *self { - Layout::C(_) => Layout::C((row, col)), - Layout::F(_) => Layout::F((col, row)), + MatrixLayout::C(_) => MatrixLayout::C((row, col)), + MatrixLayout::F(_) => MatrixLayout::F((col, row)), } } pub fn lda(&self) -> LDA { match *self { - Layout::C((_, lda)) => lda, - Layout::F((_, lda)) => lda, + MatrixLayout::C((_, lda)) => lda, + MatrixLayout::F((_, lda)) => lda, } } pub fn len(&self) -> LEN { match *self { - Layout::C((row, _)) => row, - Layout::F((col, _)) => col, + MatrixLayout::C((row, _)) => row, + MatrixLayout::F((col, _)) => col, } } pub fn lapacke_layout(&self) -> c::Layout { match *self { - Layout::C(_) => c::Layout::RowMajor, - Layout::F(_) => c::Layout::ColumnMajor, + MatrixLayout::C(_) => c::Layout::RowMajor, + MatrixLayout::F(_) => c::Layout::ColumnMajor, } } - pub fn same_order(&self, other: &Layout) -> bool { + pub fn same_order(&self, other: &MatrixLayout) -> bool { self.lapacke_layout() == other.lapacke_layout() } pub fn as_shape(&self) -> Shape { match *self { - Layout::C((row, col)) => (row as usize, col as usize).into_shape(), - Layout::F((col, row)) => (row as usize, col as usize).f().into_shape(), + MatrixLayout::C((row, col)) => (row as usize, col as usize).into_shape(), + MatrixLayout::F((col, row)) => (row as usize, col as usize).f().into_shape(), } } pub fn toggle_order(&self) -> Self { match *self { - Layout::C((row, col)) => Layout::F((col, row)), - Layout::F((col, row)) => Layout::C((row, col)), + MatrixLayout::C((row, col)) => MatrixLayout::F((col, row)), + MatrixLayout::F((col, row)) => MatrixLayout::C((row, col)), } } } pub trait AllocatedArray { type Elem; - fn layout(&self) -> Result; - fn square_layout(&self) -> Result; + fn layout(&self) -> Result; + fn square_layout(&self) -> Result; fn as_allocated(&self) -> Result<&[Self::Elem]>; } @@ -88,14 +88,14 @@ where { type Elem = A; - fn layout(&self) -> Result { + fn layout(&self) -> Result { let shape = self.shape(); let strides = self.strides(); if shape[0] == strides[1] as usize { - return Ok(Layout::F((self.cols() as i32, self.rows() as i32))); + return Ok(MatrixLayout::F((self.cols() as i32, self.rows() as i32))); } if shape[1] == strides[0] as usize { - return Ok(Layout::C((self.rows() as i32, self.cols() as i32))); + return Ok(MatrixLayout::C((self.rows() as i32, self.cols() as i32))); } Err(StrideError::new(strides[0], strides[1]).into()) } diff --git a/tests/layout.rs b/tests/layout.rs index 7615d1e3..bc10838e 100644 --- a/tests/layout.rs +++ b/tests/layout.rs @@ -4,32 +4,32 @@ extern crate ndarray_linalg; use ndarray::*; use ndarray_linalg::*; -use ndarray_linalg::layout::Layout; +use ndarray_linalg::layout::MatrixLayout; #[test] fn layout_c_3x1() { let a: Array2 = Array::zeros((3, 1)); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), Layout::C((3, 1))); + assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 1))); } #[test] fn layout_f_3x1() { let a: Array2 = Array::zeros((3, 1).f()); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), Layout::F((1, 3))); + assert_eq!(a.layout().unwrap(), MatrixLayout::F((1, 3))); } #[test] fn layout_c_3x2() { let a: Array2 = Array::zeros((3, 2)); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), Layout::C((3, 2))); + assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 2))); } #[test] fn layout_f_3x2() { let a: Array2 = Array::zeros((3, 2).f()); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), Layout::F((2, 3))); + assert_eq!(a.layout().unwrap(), MatrixLayout::F((2, 3))); } From f6164375ed7bdc6ac624ddc0f20c9a6ed6337910 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 28 Jun 2017 21:34:13 +0900 Subject: [PATCH 4/6] Rename convert functions --- src/convert.rs | 8 ++++---- src/qr.rs | 4 ++-- src/svd.rs | 4 ++-- src/triangular.rs | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/convert.rs b/src/convert.rs index dcf1f64b..96e7d358 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -27,14 +27,14 @@ where a.into_shape((n)).unwrap() } -pub fn reconstruct(l: MatrixLayout, a: Vec) -> Result> +pub fn into_matrix(l: MatrixLayout, a: Vec) -> Result> where S: DataOwned, { Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?) } -pub fn uninitialized(l: MatrixLayout) -> ArrayBase +fn uninitialized(l: MatrixLayout) -> ArrayBase where A: Copy, S: DataOwned, @@ -54,7 +54,7 @@ where b } -pub fn clone_with_layout(l: MatrixLayout, a: &ArrayBase) -> ArrayBase +fn clone_with_layout(l: MatrixLayout, a: &ArrayBase) -> ArrayBase where A: Copy, Si: Data, @@ -65,7 +65,7 @@ where b } -pub fn data_transpose(a: &mut ArrayBase) -> Result<&mut ArrayBase> +pub fn transpose_data(a: &mut ArrayBase) -> Result<&mut ArrayBase> where A: Copy, S: DataOwned + DataMut, diff --git a/src/qr.rs b/src/qr.rs index 9a526366..d5062603 100644 --- a/src/qr.rs +++ b/src/qr.rs @@ -68,7 +68,7 @@ where let k = ::std::cmp::min(n, m); let l = self.layout()?; let r = A::qr(l, self.as_allocated_mut()?)?; - let r: Array2<_> = reconstruct(l, r)?; + let r: Array2<_> = into_matrix(l, r)?; let q = self; Ok((take_slice(q, n, k), take_slice_upper(&r, k, m))) } @@ -87,7 +87,7 @@ impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a ArrayB let l = self.layout()?; let mut q = self.to_owned(); let r = A::qr(l, q.as_allocated_mut()?)?; - let r: Array2<_> = reconstruct(l, r)?; + let r: Array2<_> = into_matrix(l, r)?; Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m))) } } diff --git a/src/svd.rs b/src/svd.rs index d1170bb6..b0de84c6 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -59,9 +59,9 @@ where let l = self.layout()?; let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?; let (n, m) = l.size(); - let u = svd_res.u.map(|u| reconstruct(l.resized(n, n), u).unwrap()); + let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap()); let vt = svd_res.vt.map( - |vt| reconstruct(l.resized(m, m), vt).unwrap(), + |vt| into_matrix(l.resized(m, m), vt).unwrap(), ); let s = ArrayBase::from_vec(svd_res.s); Ok((u, s, vt)) diff --git a/src/triangular.rs b/src/triangular.rs index 9244a495..ef1bc34f 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -43,7 +43,7 @@ where let a_ = self.as_allocated()?; let lb = b.layout()?; if !la.same_order(&lb) { - data_transpose(b)?; + transpose_data(b)?; } let lb = b.layout()?; A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; From 57123151bde2c9a9caae954d6ee6cf7708d13014 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 28 Jun 2017 23:08:44 +0900 Subject: [PATCH 5/6] Add generalize --- src/convert.rs | 12 ++++++++++++ src/lib.rs | 1 + tests/convert.rs | 14 ++++++++++++++ 3 files changed, 27 insertions(+) create mode 100644 tests/convert.rs diff --git a/src/convert.rs b/src/convert.rs index 96e7d358..ca899a84 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -75,3 +75,15 @@ where ::std::mem::replace(a, new); Ok(a) } + +pub fn generalize(a: Array) -> ArrayBase +where + S: DataOwned, + D: Dimension, +{ + if a.is_standard_layout() { + ArrayBase::from_shape_vec(a.dim(), a.into_raw_vec()).unwrap() + } else { + ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec()).unwrap() + } +} diff --git a/src/lib.rs b/src/lib.rs index 7c48ec4f..c9e3c043 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,6 +74,7 @@ pub mod norm; pub mod trace; pub use assert::*; +pub use convert::*; pub use generate::*; pub use layout::*; pub use types::*; diff --git a/tests/convert.rs b/tests/convert.rs new file mode 100644 index 00000000..02382c97 --- /dev/null +++ b/tests/convert.rs @@ -0,0 +1,14 @@ + +extern crate ndarray; +extern crate ndarray_linalg; + +use ndarray::*; +use ndarray_linalg::*; + +#[test] +fn generalize() { + let a: Array3 = random((3, 2, 4).f()); + let ans = a.clone(); + let a: Array3 = convert::generalize(a); + assert_eq!(a, ans); +} From ee067b2ac9b858a2dd9f7b03958fa0e9cbd2e80b Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 28 Jun 2017 23:14:04 +0900 Subject: [PATCH 6/6] Add error check and link --- src/convert.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/convert.rs b/src/convert.rs index ca899a84..54109204 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -81,9 +81,18 @@ where S: DataOwned, D: Dimension, { - if a.is_standard_layout() { + // FIXME + // https://github.com/bluss/rust-ndarray/issues/325 + let strides: Vec = a.strides().to_vec(); + let new = if a.is_standard_layout() { ArrayBase::from_shape_vec(a.dim(), a.into_raw_vec()).unwrap() } else { ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec()).unwrap() - } + }; + assert_eq!( + new.strides(), + strides.as_slice(), + "Custom stride is not supported" + ); + new }