@@ -354,7 +354,7 @@ where
354
354
N : UInt ,
355
355
IndexVec : From < Vec < N > > ,
356
356
{
357
- use std:: cmp:: Ordering ;
357
+ use std:: { cmp:: Ordering , collections :: BinaryHeap } ;
358
358
359
359
if amount == N :: zero ( ) {
360
360
return Ok ( IndexVec :: U32 ( Vec :: new ( ) ) ) ;
@@ -373,9 +373,9 @@ where
373
373
374
374
impl < N > Ord for Element < N > {
375
375
fn cmp ( & self , other : & Self ) -> Ordering {
376
- // partial_cmp will always produce a value,
377
- // because we check that the weights are not nan
378
- self . key . partial_cmp ( & other. key ) . unwrap ( )
376
+ // unwrap() should not panic since weights should not be NaN
377
+ // We reverse so that BinaryHeap::peek shows the smallest item
378
+ self . key . partial_cmp ( & other. key ) . unwrap ( ) . reverse ( )
379
379
}
380
380
}
381
381
@@ -387,7 +387,7 @@ where
387
387
388
388
impl < N > Eq for Element < N > { }
389
389
390
- let mut candidates = Vec :: with_capacity ( amount. as_usize ( ) ) ;
390
+ let mut candidates = BinaryHeap :: with_capacity ( amount. as_usize ( ) ) ;
391
391
let mut index = N :: zero ( ) ;
392
392
while index < length && candidates. len ( ) < amount. as_usize ( ) {
393
393
let weight = weight ( index. as_usize ( ) ) . into ( ) ;
@@ -402,26 +402,23 @@ where
402
402
403
403
index += N :: one ( ) ;
404
404
}
405
- candidates. sort_unstable ( ) ;
406
405
407
406
if candidates. len ( ) < amount. as_usize ( ) {
408
407
return Err ( WeightError :: InsufficientNonZero ) ;
409
408
}
410
409
411
- let mut x = rng. random :: < f64 > ( ) . ln ( ) / candidates[ 0 ] . key ;
410
+ let mut x = rng. random :: < f64 > ( ) . ln ( ) / candidates. peek ( ) . unwrap ( ) . key ;
412
411
while index < length {
413
412
let weight = weight ( index. as_usize ( ) ) . into ( ) ;
414
413
if weight > 0.0 {
415
414
x -= weight;
416
415
if x <= 0.0 {
417
- let t = ( candidates[ 0 ] . key * weight) . exp ( ) ;
416
+ let min_candidate = candidates. pop ( ) . unwrap ( ) ;
417
+ let t = ( min_candidate. key * weight) . exp ( ) ;
418
418
let key = rng. random_range ( t..1.0 ) . ln ( ) / weight;
419
- candidates[ 0 ] = Element { index, key } ;
420
- // TODO: consider using a binary tree instead of sorting at each
421
- // step. This should be faster for some THRESHOLD < amount.
422
- candidates. sort_unstable ( ) ;
419
+ candidates. push ( Element { index, key } ) ;
423
420
424
- x = rng. random :: < f64 > ( ) . ln ( ) / candidates[ 0 ] . key ;
421
+ x = rng. random :: < f64 > ( ) . ln ( ) / candidates. peek ( ) . unwrap ( ) . key ;
425
422
}
426
423
} else if !( weight >= 0.0 ) {
427
424
return Err ( WeightError :: InvalidWeight ) ;
0 commit comments