@@ -133,6 +133,13 @@ extern "C" {
133
133
dim : c_int ,
134
134
nan_val : c_double ,
135
135
) -> c_int ;
136
+ fn af_max_ragged (
137
+ val_out : * mut af_array ,
138
+ idx_out : * mut af_array ,
139
+ input : af_array ,
140
+ ragged_len : af_array ,
141
+ dim : c_int ,
142
+ ) -> c_int ;
136
143
}
137
144
138
145
macro_rules! dim_reduce_func_def {
@@ -1440,6 +1447,66 @@ dim_reduce_by_key_nan_func_def!(
1440
1447
ValueType :: ProductOutType
1441
1448
) ;
1442
1449
1450
+ /// Max reduction along given axis as per ragged lengths provided
1451
+ ///
1452
+ /// # Parameters
1453
+ ///
1454
+ /// - `input` contains the input values to be reduced
1455
+ /// - `ragged_len` array containing number of elements to use when reducing along `dim`
1456
+ /// - `dim` is the dimension along which the max operation occurs
1457
+ ///
1458
+ /// # Return Values
1459
+ ///
1460
+ /// Tuple of Arrays:
1461
+ /// - First element: An Array containing the maximum ragged values in `input` along `dim`
1462
+ /// according to `ragged_len`
1463
+ /// - Second Element: An Array containing the locations of the maximum ragged values in
1464
+ /// `input` along `dim` according to `ragged_len`
1465
+ ///
1466
+ /// # Examples
1467
+ /// ```rust
1468
+ /// use arrayfire::{Array, dim4, print, randu, max_ragged};
1469
+ /// let vals: [f32; 6] = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1470
+ /// let rlens: [u32; 2] = [9, 2];
1471
+ /// let varr = Array::new(&vals, dim4![3, 2]);
1472
+ /// let rarr = Array::new(&rlens, dim4![1, 2]);
1473
+ /// print(&varr);
1474
+ /// // 1 4
1475
+ /// // 2 5
1476
+ /// // 3 6
1477
+ /// print(&rarr); // numbers of elements to participate in reduction along given axis
1478
+ /// // 9 2
1479
+ /// let (out, idx) = max_ragged(&varr, &rarr, 0);
1480
+ /// print(&out);
1481
+ /// // 3 5
1482
+ /// print(&idx);
1483
+ /// // 2 1 //Since 3 is max element for given length 9 along first column
1484
+ /// //Since 5 is max element for given length 2 along second column
1485
+ /// ```
1486
+ pub fn max_ragged < T > (
1487
+ input : & Array < T > ,
1488
+ ragged_len : & Array < u32 > ,
1489
+ dim : i32 ,
1490
+ ) -> ( Array < T :: InType > , Array < u32 > )
1491
+ where
1492
+ T : HasAfEnum ,
1493
+ T :: InType : HasAfEnum ,
1494
+ {
1495
+ unsafe {
1496
+ let mut out_vals: af_array = std:: ptr:: null_mut ( ) ;
1497
+ let mut out_idxs: af_array = std:: ptr:: null_mut ( ) ;
1498
+ let err_val = af_max_ragged (
1499
+ & mut out_vals as * mut af_array ,
1500
+ & mut out_idxs as * mut af_array ,
1501
+ input. get ( ) ,
1502
+ ragged_len. get ( ) ,
1503
+ dim,
1504
+ ) ;
1505
+ HANDLE_ERROR ( AfError :: from ( err_val) ) ;
1506
+ ( out_vals. into ( ) , out_idxs. into ( ) )
1507
+ }
1508
+ }
1509
+
1443
1510
#[ cfg( test) ]
1444
1511
mod tests {
1445
1512
use super :: super :: core:: c32;
0 commit comments