From 2ccf1684c8529466ff142284bd7c72a53a75f2c2 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 23 Jan 2019 13:44:37 +0530 Subject: [PATCH] Fix comparison functions output type --- examples/conway.rs | 2 +- src/arith/mod.rs | 127 ++++++++++++++++++++++++++++++++------------- 2 files changed, 93 insertions(+), 36 deletions(-) diff --git a/examples/conway.rs b/examples/conway.rs index c87f875d0..75b480d0e 100644 --- a/examples/conway.rs +++ b/examples/conway.rs @@ -24,6 +24,6 @@ fn conways_game_of_life() { let c0 = &eq(&n_hood, &c0, false); let c1 = &eq(&n_hood, &c1, false); state = state * c0 + c1; - win.draw_image(&normalise(&state), None); + win.draw_image(&normalise(&state.cast::()), None); } } diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 8d7eebc6d..ffd4db08a 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -433,17 +433,6 @@ macro_rules! overloaded_binary_func { /// /// An Array with results of the binary operation. /// - /// In the case of comparison operations such as the following, the type of output - /// Array is [DType::B8](./enum.DType.html). To retrieve the results of such boolean output - /// to host, an array of 8-bit wide types(eg. u8, i8) should be used since ArrayFire's internal - /// implementation uses char for boolean. - /// - /// * [gt](./fn.gt.html) - /// * [lt](./fn.lt.html) - /// * [ge](./fn.ge.html) - /// * [le](./fn.le.html) - /// * [eq](./fn.eq.html) - /// ///# Note /// /// The trait `Convertable` essentially translates to a scalar native type on rust or Array. @@ -487,55 +476,123 @@ overloaded_binary_func!("Compute remainder from two Arrays", rem, rem_helper, af overloaded_binary_func!("Compute left shift", shiftl, shiftl_helper, af_bitshiftl); overloaded_binary_func!("Compute right shift", shiftr, shiftr_helper, af_bitshiftr); overloaded_binary_func!( + "Compute modulo of two Arrays", + modulo, + modulo_helper, + af_mod +); +overloaded_binary_func!( + "Calculate atan2 of two Arrays", + atan2, + atan2_helper, + af_atan2 +); +overloaded_binary_func!( + "Create complex array from two Arrays", + cplx2, + cplx2_helper, + af_cplx2 +); +overloaded_binary_func!("Compute root", root, root_helper, af_root); +overloaded_binary_func!("Computer power", pow, pow_helper, af_pow); + +macro_rules! overloaded_compare_func { + ($doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => { + fn $help_name(lhs: &Array, rhs: &Array, batch: bool) -> Array + where + A: HasAfEnum + ImplicitPromote, + B: HasAfEnum + ImplicitPromote, + { + let mut temp: i64 = 0; + unsafe { + let err_val = $ffi_name( + &mut temp as MutAfArray, + lhs.get() as AfArray, + rhs.get() as AfArray, + batch as c_int, + ); + HANDLE_ERROR(AfError::from(err_val)); + } + temp.into() + } + + #[doc=$doc_str] + /// + /// This is a comparison operation. + /// + ///# Parameters + /// + /// - `arg1`is an argument that implements an internal trait `Convertable`. + /// - `arg2`is an argument that implements an internal trait `Convertable`. + /// - `batch` is an boolean that indicates if the current operation is an batch operation. + /// + /// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral + /// type. + /// + ///# Return Values + /// + /// An Array with results of the comparison operation a.k.a an Array of boolean values. + ///# Note + /// + /// The trait `Convertable` essentially translates to a scalar native type on rust or Array. + pub fn $fn_name( + arg1: &T, + arg2: &U, + batch: bool, + ) -> Array + where + T: Convertable, + U: Convertable, + ::OutType: HasAfEnum + ImplicitPromote<::OutType>, + ::OutType: HasAfEnum + ImplicitPromote<::OutType>, + { + let lhs = arg1.convert(); // Convert to Array + let rhs = arg2.convert(); // Convert to Array + match (lhs.is_scalar(), rhs.is_scalar()) { + (true, false) => { + let l = tile(&lhs, rhs.dims()); + $help_name(&l, &rhs, batch) + } + (false, true) => { + let r = tile(&rhs, lhs.dims()); + $help_name(&lhs, &r, batch) + } + _ => $help_name(&lhs, &rhs, batch), + } + } + }; +} + +overloaded_compare_func!( "Perform `less than` comparison operation", lt, lt_helper, af_lt ); -overloaded_binary_func!( +overloaded_compare_func!( "Perform `greater than` comparison operation", gt, gt_helper, af_gt ); -overloaded_binary_func!( +overloaded_compare_func!( "Perform `less than equals` comparison operation", le, le_helper, af_le ); -overloaded_binary_func!( +overloaded_compare_func!( "Perform `greater than equals` comparison operation", ge, ge_helper, af_ge ); -overloaded_binary_func!( +overloaded_compare_func!( "Perform `equals` comparison operation", eq, eq_helper, af_eq ); -overloaded_binary_func!( - "Compute modulo of two Arrays", - modulo, - modulo_helper, - af_mod -); -overloaded_binary_func!( - "Calculate atan2 of two Arrays", - atan2, - atan2_helper, - af_atan2 -); -overloaded_binary_func!( - "Create complex array from two Arrays", - cplx2, - cplx2_helper, - af_cplx2 -); -overloaded_binary_func!("Compute root", root, root_helper, af_root); -overloaded_binary_func!("Computer power", pow, pow_helper, af_pow); fn clamp_helper( inp: &Array,