diff --git a/src/arith/mod.rs b/src/arith/mod.rs index a6c867636..78baea368 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -671,6 +671,7 @@ where macro_rules! arith_scalar_func { ($rust_type: ty, $op_name:ident, $fn_name: ident) => { + // Implement (&Array op_name rust_type) impl<'f, T> $op_name<$rust_type> for &'f Array where T: HasAfEnum + ImplicitPromote<$rust_type>, @@ -685,6 +686,7 @@ macro_rules! arith_scalar_func { } } + // Implement (Array op_name rust_type) impl $op_name<$rust_type> for Array where T: HasAfEnum + ImplicitPromote<$rust_type>, @@ -698,6 +700,34 @@ macro_rules! arith_scalar_func { $fn_name(&self, &temp, false) } } + + // Implement (rust_type op_name &Array) + impl<'f, T> $op_name<&'f Array> for $rust_type + where + T: HasAfEnum + ImplicitPromote<$rust_type>, + $rust_type: HasAfEnum + ImplicitPromote, + <$rust_type as ImplicitPromote>::Output: HasAfEnum, + { + type Output = Array<<$rust_type as ImplicitPromote>::Output>; + + fn $fn_name(self, rhs: &'f Array) -> Self::Output { + $fn_name(&self, rhs, false) + } + } + + // Implement (rust_type op_name Array) + impl $op_name> for $rust_type + where + T: HasAfEnum + ImplicitPromote<$rust_type>, + $rust_type: HasAfEnum + ImplicitPromote, + <$rust_type as ImplicitPromote>::Output: HasAfEnum, + { + type Output = Array<<$rust_type as ImplicitPromote>::Output>; + + fn $fn_name(self, rhs: Array) -> Self::Output { + $fn_name(&self, &rhs, false) + } + } }; } diff --git a/tests/lib.rs b/tests/error_handler.rs similarity index 100% rename from tests/lib.rs rename to tests/error_handler.rs diff --git a/tests/scalar_arith.rs b/tests/scalar_arith.rs new file mode 100644 index 000000000..cd9903ee6 --- /dev/null +++ b/tests/scalar_arith.rs @@ -0,0 +1,19 @@ +use ::arrayfire::*; + +#[allow(non_snake_case)] +#[test] +fn check_scalar_arith() { + let dims = Dim4::new(&[5, 5, 1, 1]); + let A = randu::(dims); + let s: f32 = 2.0; + let scalar_as_lhs = s * &A; + let scalar_as_rhs = &A * s; + let C = constant(s, dims); + let no_scalars = A * C; + let scalar_res_comp = eq(&scalar_as_lhs, &scalar_as_rhs, false); + let res_comp = eq(&scalar_as_lhs, &no_scalars, false); + let scalar_res = all_true_all(&scalar_res_comp); + let res = all_true_all(&res_comp); + + assert_eq!(scalar_res.0, res.0); +}