Skip to content

Commit e55dd18

Browse files
committed
Update crate with new APIs added for upstream 3.8 release
- max ragged API, docs and example - API for cov, var & stdev fns that accept bias enum - bitwise complement
1 parent 00067c8 commit e55dd18

File tree

4 files changed

+305
-28
lines changed

4 files changed

+305
-28
lines changed

src/algorithm/mod.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ extern "C" {
133133
dim: c_int,
134134
nan_val: c_double,
135135
) -> c_int;
136+
fn af_max_ragged(
137+
val_out: *mut af_array,
138+
idx_out: *mut af_array,
139+
input: af_array,
140+
ragged_len: af_array,
141+
dim: c_int,
142+
) -> c_int;
136143
}
137144

138145
macro_rules! dim_reduce_func_def {
@@ -1440,6 +1447,66 @@ dim_reduce_by_key_nan_func_def!(
14401447
ValueType::ProductOutType
14411448
);
14421449

1450+
/// Max reduction along given axis as per ragged lengths provided
1451+
///
1452+
/// # Parameters
1453+
///
1454+
/// - `input` contains the input values to be reduced
1455+
/// - `ragged_len` array containing number of elements to use when reducing along `dim`
1456+
/// - `dim` is the dimension along which the max operation occurs
1457+
///
1458+
/// # Return Values
1459+
///
1460+
/// Tuple of Arrays:
1461+
/// - First element: An Array containing the maximum ragged values in `input` along `dim`
1462+
/// according to `ragged_len`
1463+
/// - Second Element: An Array containing the locations of the maximum ragged values in
1464+
/// `input` along `dim` according to `ragged_len`
1465+
///
1466+
/// # Examples
1467+
/// ```rust
1468+
/// use arrayfire::{Array, dim4, print, randu, max_ragged};
1469+
/// let vals: [f32; 6] = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1470+
/// let rlens: [u32; 2] = [9, 2];
1471+
/// let varr = Array::new(&vals, dim4![3, 2]);
1472+
/// let rarr = Array::new(&rlens, dim4![1, 2]);
1473+
/// print(&varr);
1474+
/// // 1 4
1475+
/// // 2 5
1476+
/// // 3 6
1477+
/// print(&rarr); // numbers of elements to participate in reduction along given axis
1478+
/// // 9 2
1479+
/// let (out, idx) = max_ragged(&varr, &rarr, 0);
1480+
/// print(&out);
1481+
/// // 3 5
1482+
/// print(&idx);
1483+
/// // 2 1 //Since 3 is max element for given length 9 along first column
1484+
/// //Since 5 is max element for given length 2 along second column
1485+
/// ```
1486+
pub fn max_ragged<T>(
1487+
input: &Array<T>,
1488+
ragged_len: &Array<u32>,
1489+
dim: i32,
1490+
) -> (Array<T::InType>, Array<u32>)
1491+
where
1492+
T: HasAfEnum,
1493+
T::InType: HasAfEnum,
1494+
{
1495+
unsafe {
1496+
let mut out_vals: af_array = std::ptr::null_mut();
1497+
let mut out_idxs: af_array = std::ptr::null_mut();
1498+
let err_val = af_max_ragged(
1499+
&mut out_vals as *mut af_array,
1500+
&mut out_idxs as *mut af_array,
1501+
input.get(),
1502+
ragged_len.get(),
1503+
dim,
1504+
);
1505+
HANDLE_ERROR(AfError::from(err_val));
1506+
(out_vals.into(), out_idxs.into())
1507+
}
1508+
}
1509+
14431510
#[cfg(test)]
14441511
mod tests {
14451512
use super::super::core::c32;

src/core/arith.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use super::data::{constant, tile, ConstGenerator};
33
use super::defines::AfError;
44
use super::dim4::Dim4;
55
use super::error::HANDLE_ERROR;
6-
use super::util::{af_array, HasAfEnum, ImplicitPromote};
6+
use super::util::{af_array, HasAfEnum, ImplicitPromote, IntegralType};
77
use num::Zero;
88

99
use libc::c_int;
@@ -97,6 +97,7 @@ extern "C" {
9797
fn af_iszero(out: *mut af_array, arr: af_array) -> c_int;
9898
fn af_isinf(out: *mut af_array, arr: af_array) -> c_int;
9999
fn af_isnan(out: *mut af_array, arr: af_array) -> c_int;
100+
fn af_bitnot(out: *mut af_array, arr: af_array) -> c_int;
100101
}
101102

102103
/// Enables use of `!` on objects of type [Array](./struct.Array.html)
@@ -1008,3 +1009,16 @@ where
10081009
sub(&cnst, &self, true)
10091010
}
10101011
}
1012+
1013+
/// Perform bitwise complement on all values of Array
1014+
pub fn bitnot<T: HasAfEnum>(input: &Array<T>) -> Array<T>
1015+
where
1016+
T: HasAfEnum + IntegralType,
1017+
{
1018+
unsafe {
1019+
let mut temp: af_array = std::ptr::null_mut();
1020+
let err_val = af_bitnot(&mut temp as *mut af_array, input.get());
1021+
HANDLE_ERROR(AfError::from(err_val));
1022+
temp.into()
1023+
}
1024+
}

src/core/util.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,3 +827,15 @@ impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }}
827827
impl Fromf64 for u8 { fn fromf64(value: f64) -> Self { value as Self }}
828828
#[rustfmt::skip]
829829
impl Fromf64 for bool { fn fromf64(value: f64) -> Self { value > 0.0 }}
830+
831+
/// Trait qualifier for given type indicating computability of covariance
832+
pub trait IntegralType {}
833+
834+
impl IntegralType for i64 {}
835+
impl IntegralType for u64 {}
836+
impl IntegralType for i32 {}
837+
impl IntegralType for u32 {}
838+
impl IntegralType for i16 {}
839+
impl IntegralType for u16 {}
840+
impl IntegralType for u8 {}
841+
impl IntegralType for bool {}

0 commit comments

Comments
 (0)