From 26899cb4f0d834b6119a4b2e53c2937a5e46414d Mon Sep 17 00:00:00 2001 From: Zheng Z Date: Sat, 22 Jan 2022 20:39:42 +0200 Subject: [PATCH] implement outer() for 1D array --- src/linalg/impl_linalg.rs | 15 +++++++++++++++ tests/array.rs | 7 +++++++ 2 files changed, 22 insertions(+) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 52a15f44e..f75de7c2a 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -137,6 +137,21 @@ where } self.dot_generic(rhs) } + + /// Outer product of two 1D arrays. + /// + /// The outer product of two vectors a (of dimension M) and b (of dimension N) + /// is defined as an (M*N)-dimensional matrix whose ij-th element is a_i * b_j. + /// This implementation essentially calls `dot` by reshaping the vectors. + pub fn outer(&self, b: &ArrayBase) -> Array + where + S2: Data, + A: LinalgScalar, + { + let (size_a, size_b) = (self.shape()[0], b.shape()[0]); + let b_reshaped = b.view().into_shape((1, size_b)).unwrap(); + self.view().into_shape((size_a, 1)).unwrap().dot(&b_reshaped) + } } /// Return a pointer to the starting element in BLAS's view. diff --git a/tests/array.rs b/tests/array.rs index e3922ea8d..818a0e136 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -80,6 +80,13 @@ fn test_mat_mul() { assert_eq!(c.dot(&a), a); } +#[test] +fn test_outer_product() { + let a: Array1 = array![2., 4.]; + let b: Array1 = array![3., 5.]; + assert_eq!(a.outer(&b), array![[6., 10.], [12., 20.]]); +} + #[deny(unsafe_code)] #[test] fn test_slice() {