Skip to content

Commit 831353f

Browse files
committed
fn sample_efraimidis_spirakis: use BinaryHeap
1 parent a039a7f commit 831353f

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

src/seq/index.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ where
354354
N: UInt,
355355
IndexVec: From<Vec<N>>,
356356
{
357-
use std::cmp::Ordering;
357+
use std::{cmp::Ordering, collections::BinaryHeap};
358358

359359
if amount == N::zero() {
360360
return Ok(IndexVec::U32(Vec::new()));
@@ -373,9 +373,9 @@ where
373373

374374
impl<N> Ord for Element<N> {
375375
fn cmp(&self, other: &Self) -> Ordering {
376-
// partial_cmp will always produce a value,
377-
// because we check that the weights are not nan
378-
self.key.partial_cmp(&other.key).unwrap()
376+
// unwrap() should not panic since weights should not be NaN
377+
// We reverse so that BinaryHeap::peek shows the smallest item
378+
self.key.partial_cmp(&other.key).unwrap().reverse()
379379
}
380380
}
381381

@@ -387,7 +387,7 @@ where
387387

388388
impl<N> Eq for Element<N> {}
389389

390-
let mut candidates = Vec::with_capacity(amount.as_usize());
390+
let mut candidates = BinaryHeap::with_capacity(amount.as_usize());
391391
let mut index = N::zero();
392392
while index < length && candidates.len() < amount.as_usize() {
393393
let weight = weight(index.as_usize()).into();
@@ -402,26 +402,23 @@ where
402402

403403
index += N::one();
404404
}
405-
candidates.sort_unstable();
406405

407406
if candidates.len() < amount.as_usize() {
408407
return Err(WeightError::InsufficientNonZero);
409408
}
410409

411-
let mut x = rng.random::<f64>().ln() / candidates[0].key;
410+
let mut x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
412411
while index < length {
413412
let weight = weight(index.as_usize()).into();
414413
if weight > 0.0 {
415414
x -= weight;
416415
if x <= 0.0 {
417-
let t = (candidates[0].key * weight).exp();
416+
let min_candidate = candidates.pop().unwrap();
417+
let t = (min_candidate.key * weight).exp();
418418
let key = rng.random_range(t..1.0).ln() / weight;
419-
candidates[0] = Element { index, key };
420-
// TODO: consider using a binary tree instead of sorting at each
421-
// step. This should be faster for some THRESHOLD < amount.
422-
candidates.sort_unstable();
419+
candidates.push(Element { index, key });
423420

424-
x = rng.random::<f64>().ln() / candidates[0].key;
421+
x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
425422
}
426423
} else if !(weight >= 0.0) {
427424
return Err(WeightError::InvalidWeight);

0 commit comments

Comments
 (0)