From c78d97ce936905cef10bcae205208147f05ebb76 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 7 Oct 2020 14:10:28 +0530 Subject: [PATCH] Fix return type trait bound on reduce all functions --- src/algorithm/mod.rs | 140 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 25 deletions(-) diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index 35d286404..3d23530c4 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -518,12 +518,17 @@ where } macro_rules! all_reduce_func_def { - ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => { + ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) -> ($out_type, $out_type) + pub fn $fn_name(input: &Array) + -> ( + <::$assoc_type as HasAfEnum>::BaseType, + <::$assoc_type as HasAfEnum>::BaseType + ) where T: HasAfEnum, - $out_type: HasAfEnum + Fromf64 + ::$assoc_type: HasAfEnum, + <::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64, { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; @@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def { ); HANDLE_ERROR(AfError::from(err_val)); } - (<$out_type>::fromf64(real), <$out_type>::fromf64(imag)) + ( + <::$assoc_type as HasAfEnum>::BaseType::fromf64(real), + <::$assoc_type as HasAfEnum>::BaseType::fromf64(imag), + ) } }; } @@ -564,7 +572,7 @@ all_reduce_func_def!( ", sum_all, af_sum_all, - T::AggregateOutType + AggregateOutType ); all_reduce_func_def!( @@ -594,7 +602,7 @@ all_reduce_func_def!( ", product_all, af_product_all, - T::ProductOutType + ProductOutType ); all_reduce_func_def!( @@ -623,7 +631,7 @@ all_reduce_func_def!( ", min_all, af_min_all, - T::InType + InType ); all_reduce_func_def!( @@ -652,10 +660,31 @@ all_reduce_func_def!( ", max_all, af_max_all, - T::InType + InType ); -all_reduce_func_def!( +macro_rules! all_reduce_func_def2 { + ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => { + #[doc=$doc_str] + pub fn $fn_name(input: &Array) -> ($out_type, $out_type) + where + T: HasAfEnum, + $out_type: HasAfEnum + Fromf64 + { + let mut real: f64 = 0.0; + let mut imag: f64 = 0.0; + unsafe { + let err_val = $ffi_name( + &mut real as *mut c_double, &mut imag as *mut c_double, input.get(), + ); + HANDLE_ERROR(AfError::from(err_val)); + } + (<$out_type>::fromf64(real), <$out_type>::fromf64(imag)) + } + }; +} + +all_reduce_func_def2!( " Find if all values of Array are non-zero @@ -682,7 +711,7 @@ all_reduce_func_def!( bool ); -all_reduce_func_def!( +all_reduce_func_def2!( " Find if any value of Array is non-zero @@ -709,7 +738,7 @@ all_reduce_func_def!( bool ); -all_reduce_func_def!( +all_reduce_func_def2!( " Count number of non-zero values in the Array @@ -751,10 +780,17 @@ all_reduce_func_def!( /// A tuple of summation result. /// /// Note: For non-complex data type Arrays, second value of tuple is zero. -pub fn sum_nan_all(input: &Array, val: f64) -> (T::AggregateOutType, T::AggregateOutType) +pub fn sum_nan_all( + input: &Array, + val: f64, +) -> ( + <::AggregateOutType as HasAfEnum>::BaseType, + <::AggregateOutType as HasAfEnum>::BaseType, +) where T: HasAfEnum, - T::AggregateOutType: HasAfEnum + Fromf64, + ::AggregateOutType: HasAfEnum, + <::AggregateOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64, { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; @@ -768,8 +804,8 @@ where HANDLE_ERROR(AfError::from(err_val)); } ( - ::fromf64(real), - ::fromf64(imag), + <::AggregateOutType as HasAfEnum>::BaseType::fromf64(real), + <::AggregateOutType as HasAfEnum>::BaseType::fromf64(imag), ) } @@ -788,10 +824,17 @@ where /// A tuple of product result. /// /// Note: For non-complex data type Arrays, second value of tuple is zero. -pub fn product_nan_all(input: &Array, val: f64) -> (T::ProductOutType, T::ProductOutType) +pub fn product_nan_all( + input: &Array, + val: f64, +) -> ( + <::ProductOutType as HasAfEnum>::BaseType, + <::ProductOutType as HasAfEnum>::BaseType, +) where T: HasAfEnum, - T::ProductOutType: HasAfEnum + Fromf64, + ::ProductOutType: HasAfEnum, + <::ProductOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64, { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; @@ -805,8 +848,8 @@ where HANDLE_ERROR(AfError::from(err_val)); } ( - ::fromf64(real), - ::fromf64(imag), + <::ProductOutType as HasAfEnum>::BaseType::fromf64(real), + <::ProductOutType as HasAfEnum>::BaseType::fromf64(imag), ) } @@ -858,12 +901,18 @@ dim_ireduce_func_def!(" ", imax, af_imax, InType); macro_rules! all_ireduce_func_def { - ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => { + ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) -> ($out_type, $out_type, u32) + pub fn $fn_name(input: &Array) + -> ( + <::$assoc_type as HasAfEnum>::BaseType, + <::$assoc_type as HasAfEnum>::BaseType, + u32 + ) where T: HasAfEnum, - $out_type: HasAfEnum + Fromf64 + ::$assoc_type: HasAfEnum, + <::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64, { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; @@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def { ); HANDLE_ERROR(AfError::from(err_val)); } - (<$out_type>::fromf64(real), <$out_type>::fromf64(imag), temp) + ( + <::$assoc_type as HasAfEnum>::BaseType::fromf64(real), + <::$assoc_type as HasAfEnum>::BaseType::fromf64(imag), + temp, + ) } }; } @@ -898,7 +951,7 @@ all_ireduce_func_def!( ", imin_all, af_imin_all, - T::InType + InType ); all_ireduce_func_def!( " @@ -918,7 +971,7 @@ all_ireduce_func_def!( ", imax_all, af_imax_all, - T::InType + InType ); /// Locate the indices of non-zero elements. @@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!( af_product_by_key_nan, ValueType::ProductOutType ); + +#[cfg(test)] +mod tests { + use super::super::core::c32; + use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all}; + use crate::randu; + + #[test] + fn all_reduce_api() { + let a = randu!(c32; 10, 10); + println!("Reduction of complex f32 matrix: {:?}", sum_all(&a)); + + let b = randu!(bool; 10, 10); + println!("reduction of bool matrix: {:?}", sum_all(&b)); + + println!( + "reduction of complex f32 matrix after replacing nan with {}: {:?}", + 1.0, + product_nan_all(&a, 1.0) + ); + + println!( + "reduction of bool matrix after replacing nan with {}: {:?}", + 0.0, + sum_nan_all(&b, 0.0) + ); + } + + #[test] + fn all_ireduce_api() { + let a = randu!(c32; 10); + println!("Reduction of complex f32 matrix: {:?}", imin_all(&a)); + + let b = randu!(u32; 10); + println!("reduction of bool matrix: {:?}", imax_all(&b)); + } +}