Skip to content

Commit 928f3b8

Browse files
Merge pull request #1 from jturner314/bulk-quantiles
Various improvements to bunk quantiles
2 parents dc22814 + 355bb93 commit 928f3b8

File tree

3 files changed

+130
-99
lines changed

3 files changed

+130
-99
lines changed

src/quantile/interpolate.rs

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,46 @@ use ndarray::prelude::*;
44
use noisy_float::types::N64;
55
use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive};
66

7+
fn float_quantile_index(q: N64, len: usize) -> N64 {
8+
q * ((len - 1) as f64)
9+
}
10+
11+
/// Returns the fraction that the quantile is between the lower and higher indices.
12+
///
13+
/// This ranges from 0, where the quantile exactly corresponds the lower index,
14+
/// to 1, where the quantile exactly corresponds to the higher index.
15+
fn float_quantile_index_fraction(q: N64, len: usize) -> N64 {
16+
float_quantile_index(q, len).fract()
17+
}
18+
19+
/// Returns the index of the value on the lower side of the quantile.
20+
pub(crate) fn lower_index(q: N64, len: usize) -> usize {
21+
float_quantile_index(q, len).floor().to_usize().unwrap()
22+
}
23+
24+
/// Returns the index of the value on the higher side of the quantile.
25+
pub(crate) fn higher_index(q: N64, len: usize) -> usize {
26+
float_quantile_index(q, len).ceil().to_usize().unwrap()
27+
}
28+
729
/// Used to provide an interpolation strategy to [`quantile_axis_mut`].
830
///
931
/// [`quantile_axis_mut`]: ../trait.QuantileExt.html#tymethod.quantile_axis_mut
1032
pub trait Interpolate<T> {
11-
#[doc(hidden)]
12-
fn float_quantile_index(q: N64, len: usize) -> N64 {
13-
q * ((len - 1) as f64)
14-
}
15-
#[doc(hidden)]
16-
fn lower_index(q: N64, len: usize) -> usize {
17-
Self::float_quantile_index(q, len)
18-
.floor()
19-
.to_usize()
20-
.unwrap()
21-
}
22-
#[doc(hidden)]
23-
fn higher_index(q: N64, len: usize) -> usize {
24-
Self::float_quantile_index(q, len)
25-
.ceil()
26-
.to_usize()
27-
.unwrap()
28-
}
29-
#[doc(hidden)]
30-
fn float_quantile_index_fraction(q: N64, len: usize) -> N64 {
31-
Self::float_quantile_index(q, len).fract()
32-
}
33+
/// Returns `true` iff the lower value is needed to compute the
34+
/// interpolated value.
3335
#[doc(hidden)]
3436
fn needs_lower(q: N64, len: usize) -> bool;
37+
38+
/// Returns `true` iff the higher value is needed to compute the
39+
/// interpolated value.
3540
#[doc(hidden)]
3641
fn needs_higher(q: N64, len: usize) -> bool;
42+
43+
/// Computes the interpolated value.
44+
///
45+
/// **Panics** if `None` is provided for the lower value when it's needed
46+
/// or if `None` is provided for the higher value when it's needed.
3747
#[doc(hidden)]
3848
fn interpolate<D>(
3949
lower: Option<Array<T, D>>,
@@ -94,7 +104,7 @@ impl<T> Interpolate<T> for Lower {
94104

95105
impl<T> Interpolate<T> for Nearest {
96106
fn needs_lower(q: N64, len: usize) -> bool {
97-
<Self as Interpolate<T>>::float_quantile_index_fraction(q, len) < 0.5
107+
float_quantile_index_fraction(q, len) < 0.5
98108
}
99109
fn needs_higher(q: N64, len: usize) -> bool {
100110
!<Self as Interpolate<T>>::needs_lower(q, len)
@@ -163,9 +173,7 @@ where
163173
where
164174
D: Dimension,
165175
{
166-
let fraction = <Self as Interpolate<T>>::float_quantile_index_fraction(q, len)
167-
.to_f64()
168-
.unwrap();
176+
let fraction = float_quantile_index_fraction(q, len).to_f64().unwrap();
169177
let mut a = lower.unwrap();
170178
let b = higher.unwrap();
171179
azip!(mut a, ref b in {

src/quantile/mod.rs

Lines changed: 57 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use self::interpolate::Interpolate;
1+
use self::interpolate::{higher_index, lower_index, Interpolate};
22
use super::sort::get_many_from_sorted_mut_unchecked;
33
use indexmap::{IndexMap, IndexSet};
44
use ndarray::prelude::*;
@@ -204,57 +204,51 @@ where
204204

205205
let axis_len = self.len_of(axis);
206206
if axis_len == 0 {
207-
None
208-
} else {
209-
let mut deduped_qs: Vec<N64> = qs.to_vec();
210-
deduped_qs.sort_by(|a, b| a.partial_cmp(b).unwrap());
211-
deduped_qs.dedup();
207+
return None;
208+
}
209+
210+
let mut deduped_qs: Vec<N64> = qs.to_vec();
211+
deduped_qs.sort_by(|a, b| a.partial_cmp(b).unwrap());
212+
deduped_qs.dedup();
212213

213-
// IndexSet preserves insertion order:
214-
// - indexes will stay sorted;
215-
// - we avoid index duplication.
216-
let mut searched_indexes = IndexSet::new();
217-
for q in deduped_qs.iter() {
218-
if I::needs_lower(*q, axis_len) {
219-
searched_indexes.insert(I::lower_index(*q, axis_len));
220-
}
221-
if I::needs_higher(*q, axis_len) {
222-
searched_indexes.insert(I::higher_index(*q, axis_len));
223-
}
214+
// IndexSet preserves insertion order:
215+
// - indexes will stay sorted;
216+
// - we avoid index duplication.
217+
let mut searched_indexes = IndexSet::new();
218+
for q in deduped_qs.iter() {
219+
if I::needs_lower(*q, axis_len) {
220+
searched_indexes.insert(lower_index(*q, axis_len));
221+
}
222+
if I::needs_higher(*q, axis_len) {
223+
searched_indexes.insert(higher_index(*q, axis_len));
224224
}
225-
let searched_indexes: Vec<usize> = searched_indexes.into_iter().collect();
225+
}
226+
let searched_indexes: Vec<usize> = searched_indexes.into_iter().collect();
226227

227-
// Retrieve the values corresponding to each index for each slice along the specified axis
228-
let values = self.map_axis_mut(axis, |mut x| {
229-
get_many_from_sorted_mut_unchecked(&mut x, &searched_indexes)
230-
});
228+
// Retrieve the values corresponding to each index for each slice along the specified axis
229+
let values = self.map_axis_mut(
230+
axis,
231+
|mut x| get_many_from_sorted_mut_unchecked(&mut x, &searched_indexes)
232+
);
231233

232-
// Combine the retrieved values according to specified interpolation strategy to
233-
// get the desired quantiles
234-
let mut results = IndexMap::new();
235-
for q in qs {
236-
let result = I::interpolate(
237-
match I::needs_lower(*q, axis_len) {
238-
true => {
239-
let lower_index = &I::lower_index(*q, axis_len);
240-
Some(values.map(|x| x.get(lower_index).unwrap().clone()))
241-
}
242-
false => None,
243-
},
244-
match I::needs_higher(*q, axis_len) {
245-
true => {
246-
let higher_index = &I::higher_index(*q, axis_len);
247-
Some(values.map(|x| x.get(higher_index).unwrap().clone()))
248-
}
249-
false => None,
250-
},
251-
*q,
252-
axis_len,
253-
);
254-
results.insert(*q, result);
255-
}
256-
Some(results)
234+
// Combine the retrieved values according to specified interpolation strategy to
235+
// get the desired quantiles
236+
let mut results = IndexMap::new();
237+
for q in qs {
238+
let lower = if I::needs_lower(*q, axis_len) {
239+
Some(values.map(|x| x[&lower_index(*q, axis_len)].clone()))
240+
} else {
241+
None
242+
};
243+
let higher = if I::needs_higher(*q, axis_len) {
244+
Some(values.map(|x| x[&higher_index(*q, axis_len)].clone()))
245+
} else {
246+
None
247+
};
248+
let interpolated = I::interpolate(lower, higher, *q, axis_len);
249+
results.insert(*q, interpolated);
257250
}
251+
Some(results)
258252
}
259253

260254
fn quantile_axis_mut<I>(&mut self, axis: Axis, q: N64) -> Option<Array<A, D::Smaller>>
@@ -276,24 +270,23 @@ where
276270
S: DataMut,
277271
I: Interpolate<A::NotNan>,
278272
{
279-
if self.len_of(axis) > 0 {
280-
Some(self.map_axis_mut(axis, |lane| {
281-
let mut not_nan = A::remove_nan_mut(lane);
282-
A::from_not_nan_opt(if not_nan.is_empty() {
283-
None
284-
} else {
285-
Some(
286-
not_nan
287-
.quantile_axis_mut::<I>(Axis(0), q)
288-
.unwrap()
289-
.into_raw_vec()
290-
.remove(0),
291-
)
292-
})
293-
}))
294-
} else {
295-
None
273+
if self.len_of(axis) == 0 {
274+
return None;
296275
}
276+
let quantile = self.map_axis_mut(axis, |lane| {
277+
let mut not_nan = A::remove_nan_mut(lane);
278+
A::from_not_nan_opt(if not_nan.is_empty() {
279+
None
280+
} else {
281+
Some(
282+
not_nan
283+
.quantile_axis_mut::<I>(Axis(0), q)
284+
.unwrap()
285+
.into_scalar(),
286+
)
287+
})
288+
});
289+
Some(quantile)
297290
}
298291
}
299292

src/sort.rs

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,16 @@ where
4949
A: Ord + Clone,
5050
S: DataMut;
5151

52-
/// Return the index of `self[partition_index]` if `self` were to be sorted
53-
/// in increasing order.
52+
/// Partitions the array in increasing order based on the value initially
53+
/// located at `pivot_index` and returns the new index of the value.
5454
///
55-
/// `self` elements are rearranged in such a way that `self[partition_index]`
56-
/// is in the position it would be in an array sorted in increasing order.
57-
/// All elements smaller than `self[partition_index]` are moved to its
58-
/// left and all elements equal or greater than `self[partition_index]`
59-
/// are moved to its right.
60-
/// The ordering of the elements in the two partitions is undefined.
55+
/// The elements are rearranged in such a way that the value initially
56+
/// located at `pivot_index` is moved to the position it would be in an
57+
/// array sorted in increasing order. The return value is the new index of
58+
/// the value after rearrangement. All elements smaller than the value are
59+
/// moved to its left and all elements equal or greater than the value are
60+
/// moved to its right. The ordering of the elements in the two partitions
61+
/// is undefined.
6162
///
6263
/// `self` is shuffled **in place** to operate the desired partition:
6364
/// no copy of the array is allocated.
@@ -67,7 +68,36 @@ where
6768
/// Average number of element swaps: n/6 - 1/3 (see
6869
/// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
6970
///
70-
/// **Panics** if `partition_index` is greater than or equal to `n`.
71+
/// **Panics** if `pivot_index` is greater than or equal to `n`.
72+
///
73+
/// # Example
74+
///
75+
/// ```
76+
/// extern crate ndarray;
77+
/// extern crate ndarray_stats;
78+
///
79+
/// use ndarray::array;
80+
/// use ndarray_stats::Sort1dExt;
81+
///
82+
/// # fn main() {
83+
/// let mut data = array![3, 1, 4, 5, 2];
84+
/// let pivot_index = 2;
85+
/// let pivot_value = data[pivot_index];
86+
///
87+
/// // Partition by the value located at `pivot_index`.
88+
/// let new_index = data.partition_mut(pivot_index);
89+
/// // The pivot value is now located at `new_index`.
90+
/// assert_eq!(data[new_index], pivot_value);
91+
/// // Elements less than that value are moved to the left.
92+
/// for i in 0..new_index {
93+
/// assert!(data[i] < pivot_value);
94+
/// }
95+
/// // Elements greater than or equal to that value are moved to the right.
96+
/// for i in (new_index + 1)..data.len() {
97+
/// assert!(data[i] >= pivot_value);
98+
/// }
99+
/// # }
100+
/// ```
71101
fn partition_mut(&mut self, pivot_index: usize) -> usize
72102
where
73103
A: Ord + Clone,

0 commit comments

Comments
 (0)