Skip to content

Commit dba696e

Browse files
authored
Rename WeightedError → WeightError; add IndexedRandom, IndexedMutRandom (#1382)
* Remove deprecated module rand::distributions::weighted * WeightedTree: return InvalidWeight on not-a-number * WeightedTree::try_sample return AllWeightsZero given no weights * Rename WeightedError -> WeightError and revise variants * Re-export WeightError from rand::seq * Revise errors of rand::index::sample_weighted * Split SliceRandom into IndexedRandom, IndexedMutRandom and SliceRandom
1 parent ef245fd commit dba696e

File tree

11 files changed

+329
-396
lines changed

11 files changed

+329
-396
lines changed

examples/monty-hall.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ fn simulate<R: Rng>(random_door: &Uniform<u32>, rng: &mut R) -> SimulationResult
6161
// Returns the door the game host opens given our choice and knowledge of
6262
// where the car is. The game host will never open the door with the car.
6363
fn game_host_open<R: Rng>(car: u32, choice: u32, rng: &mut R) -> u32 {
64-
use rand::seq::SliceRandom;
64+
use rand::seq::IndexedRandom;
6565
*free_doors(&[car, choice]).choose(rng).unwrap()
6666
}
6767

rand_distr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ pub use self::weibull::{Error as WeibullError, Weibull};
130130
pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError};
131131
#[cfg(feature = "alloc")]
132132
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
133-
pub use rand::distributions::{WeightedError, WeightedIndex};
133+
pub use rand::distributions::{WeightError, WeightedIndex};
134134
#[cfg(feature = "alloc")]
135135
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
136136
pub use weighted_alias::WeightedAliasIndex;

rand_distr/src/weighted_alias.rs

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//! This module contains an implementation of alias method for sampling random
1010
//! indices with probabilities proportional to a collection of weights.
1111
12-
use super::WeightedError;
12+
use super::WeightError;
1313
use crate::{uniform::SampleUniform, Distribution, Uniform};
1414
use core::fmt;
1515
use core::iter::Sum;
@@ -79,18 +79,15 @@ pub struct WeightedAliasIndex<W: AliasableWeight> {
7979
impl<W: AliasableWeight> WeightedAliasIndex<W> {
8080
/// Creates a new [`WeightedAliasIndex`].
8181
///
82-
/// Returns an error if:
83-
/// - The vector is empty.
84-
/// - The vector is longer than `u32::MAX`.
85-
/// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
86-
/// weights.len()`.
87-
/// - The sum of weights is zero.
88-
pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
82+
/// Error cases:
83+
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
84+
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
85+
/// negative or greater than `max = W::MAX / weights.len()`.
86+
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
87+
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
8988
let n = weights.len();
90-
if n == 0 {
91-
return Err(WeightedError::NoItem);
92-
} else if n > ::core::u32::MAX as usize {
93-
return Err(WeightedError::TooMany);
89+
if n == 0 || n > ::core::u32::MAX as usize {
90+
return Err(WeightError::InvalidInput);
9491
}
9592
let n = n as u32;
9693

@@ -101,7 +98,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
10198
.iter()
10299
.all(|&w| W::ZERO <= w && w <= max_weight_size)
103100
{
104-
return Err(WeightedError::InvalidWeight);
101+
return Err(WeightError::InvalidWeight);
105102
}
106103

107104
// The sum of weights will represent 100% of no alias odds.
@@ -113,7 +110,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
113110
weight_sum
114111
};
115112
if weight_sum == W::ZERO {
116-
return Err(WeightedError::AllWeightsZero);
113+
return Err(WeightError::InsufficientNonZero);
117114
}
118115

119116
// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
@@ -382,23 +379,23 @@ mod test {
382379
// Floating point special cases
383380
assert_eq!(
384381
WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
385-
WeightedError::InvalidWeight
382+
WeightError::InvalidWeight
386383
);
387384
assert_eq!(
388385
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
389-
WeightedError::AllWeightsZero
386+
WeightError::InsufficientNonZero
390387
);
391388
assert_eq!(
392389
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
393-
WeightedError::InvalidWeight
390+
WeightError::InvalidWeight
394391
);
395392
assert_eq!(
396393
WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
397-
WeightedError::InvalidWeight
394+
WeightError::InvalidWeight
398395
);
399396
assert_eq!(
400397
WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(),
401-
WeightedError::InvalidWeight
398+
WeightError::InvalidWeight
402399
);
403400
}
404401

@@ -416,11 +413,11 @@ mod test {
416413
// Signed integer special cases
417414
assert_eq!(
418415
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
419-
WeightedError::InvalidWeight
416+
WeightError::InvalidWeight
420417
);
421418
assert_eq!(
422419
WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(),
423-
WeightedError::InvalidWeight
420+
WeightError::InvalidWeight
424421
);
425422
}
426423

@@ -438,11 +435,11 @@ mod test {
438435
// Signed integer special cases
439436
assert_eq!(
440437
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
441-
WeightedError::InvalidWeight
438+
WeightError::InvalidWeight
442439
);
443440
assert_eq!(
444441
WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(),
445-
WeightedError::InvalidWeight
442+
WeightError::InvalidWeight
446443
);
447444
}
448445

@@ -486,15 +483,15 @@ mod test {
486483

487484
assert_eq!(
488485
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
489-
WeightedError::NoItem
486+
WeightError::InvalidInput
490487
);
491488
assert_eq!(
492489
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
493-
WeightedError::AllWeightsZero
490+
WeightError::InsufficientNonZero
494491
);
495492
assert_eq!(
496493
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
497-
WeightedError::InvalidWeight
494+
WeightError::InvalidWeight
498495
);
499496
}
500497

rand_distr/src/weighted_tree.rs

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
1212
use core::ops::SubAssign;
1313

14-
use super::WeightedError;
14+
use super::WeightError;
1515
use crate::Distribution;
1616
use alloc::vec::Vec;
1717
use rand::distributions::uniform::{SampleBorrow, SampleUniform};
@@ -98,15 +98,19 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
9898
WeightedTreeIndex<W>
9999
{
100100
/// Creates a new [`WeightedTreeIndex`] from a slice of weights.
101-
pub fn new<I>(weights: I) -> Result<Self, WeightedError>
101+
///
102+
/// Error cases:
103+
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
104+
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
105+
pub fn new<I>(weights: I) -> Result<Self, WeightError>
102106
where
103107
I: IntoIterator,
104108
I::Item: SampleBorrow<W>,
105109
{
106110
let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
107111
for weight in subtotals.iter() {
108-
if *weight < W::ZERO {
109-
return Err(WeightedError::InvalidWeight);
112+
if !(*weight >= W::ZERO) {
113+
return Err(WeightError::InvalidWeight);
110114
}
111115
}
112116
let n = subtotals.len();
@@ -115,7 +119,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
115119
let parent = (i - 1) / 2;
116120
subtotals[parent]
117121
.checked_add_assign(&w)
118-
.map_err(|()| WeightedError::Overflow)?;
122+
.map_err(|()| WeightError::Overflow)?;
119123
}
120124
Ok(Self { subtotals })
121125
}
@@ -164,14 +168,18 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
164168
}
165169

166170
/// Appends a new weight at the end.
167-
pub fn push(&mut self, weight: W) -> Result<(), WeightedError> {
168-
if weight < W::ZERO {
169-
return Err(WeightedError::InvalidWeight);
171+
///
172+
/// Error cases:
173+
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
174+
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
175+
pub fn push(&mut self, weight: W) -> Result<(), WeightError> {
176+
if !(weight >= W::ZERO) {
177+
return Err(WeightError::InvalidWeight);
170178
}
171179
if let Some(total) = self.subtotals.first() {
172180
let mut total = total.clone();
173181
if total.checked_add_assign(&weight).is_err() {
174-
return Err(WeightedError::Overflow);
182+
return Err(WeightError::Overflow);
175183
}
176184
}
177185
let mut index = self.len();
@@ -184,9 +192,13 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
184192
}
185193

186194
/// Updates the weight at an index.
187-
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> {
188-
if weight < W::ZERO {
189-
return Err(WeightedError::InvalidWeight);
195+
///
196+
/// Error cases:
197+
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
198+
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
199+
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightError> {
200+
if !(weight >= W::ZERO) {
201+
return Err(WeightError::InvalidWeight);
190202
}
191203
let old_weight = self.get(index);
192204
if weight > old_weight {
@@ -195,7 +207,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
195207
if let Some(total) = self.subtotals.first() {
196208
let mut total = total.clone();
197209
if total.checked_add_assign(&difference).is_err() {
198-
return Err(WeightedError::Overflow);
210+
return Err(WeightError::Overflow);
199211
}
200212
}
201213
self.subtotals[index]
@@ -235,13 +247,10 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
235247
///
236248
/// Returns an error if there are no elements or all weights are zero. This
237249
/// is unlike [`Distribution::sample`], which panics in those cases.
238-
fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
239-
if self.subtotals.is_empty() {
240-
return Err(WeightedError::NoItem);
241-
}
242-
let total_weight = self.subtotals[0].clone();
250+
fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightError> {
251+
let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
243252
if total_weight == W::ZERO {
244-
return Err(WeightedError::AllWeightsZero);
253+
return Err(WeightError::InsufficientNonZero);
245254
}
246255
let mut target_weight = rng.gen_range(W::ZERO..total_weight);
247256
let mut index = 0;
@@ -296,19 +305,19 @@ mod test {
296305
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
297306
assert_eq!(
298307
tree.try_sample(&mut rng).unwrap_err(),
299-
WeightedError::NoItem
308+
WeightError::InsufficientNonZero
300309
);
301310
}
302311

303312
#[test]
304313
fn test_overflow_error() {
305314
assert_eq!(
306315
WeightedTreeIndex::new(&[i32::MAX, 2]),
307-
Err(WeightedError::Overflow)
316+
Err(WeightError::Overflow)
308317
);
309318
let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();
310-
assert_eq!(tree.push(3), Err(WeightedError::Overflow));
311-
assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow));
319+
assert_eq!(tree.push(3), Err(WeightError::Overflow));
320+
assert_eq!(tree.update(1, 4), Err(WeightError::Overflow));
312321
tree.update(1, 2).unwrap();
313322
}
314323

@@ -318,22 +327,22 @@ mod test {
318327
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
319328
assert_eq!(
320329
tree.try_sample(&mut rng).unwrap_err(),
321-
WeightedError::AllWeightsZero
330+
WeightError::InsufficientNonZero
322331
);
323332
}
324333

325334
#[test]
326335
fn test_invalid_weight_error() {
327336
assert_eq!(
328337
WeightedTreeIndex::<i32>::new(&[1, -1]).unwrap_err(),
329-
WeightedError::InvalidWeight
338+
WeightError::InvalidWeight
330339
);
331340
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
332-
assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight);
341+
assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight);
333342
tree.push(1).unwrap();
334343
assert_eq!(
335344
tree.update(0, -1).unwrap_err(),
336-
WeightedError::InvalidWeight
345+
WeightError::InvalidWeight
337346
);
338347
}
339348

src/distributions/mod.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,6 @@ pub mod hidden_export {
108108
pub use super::float::IntoFloat; // used by rand_distr
109109
}
110110
pub mod uniform;
111-
#[deprecated(
112-
since = "0.8.0",
113-
note = "use rand::distributions::{WeightedIndex, WeightedError} instead"
114-
)]
115-
#[cfg(feature = "alloc")]
116-
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
117-
pub mod weighted;
118111

119112
pub use self::bernoulli::{Bernoulli, BernoulliError};
120113
pub use self::distribution::{Distribution, DistIter, DistMap};
@@ -126,7 +119,7 @@ pub use self::slice::Slice;
126119
#[doc(inline)]
127120
pub use self::uniform::Uniform;
128121
#[cfg(feature = "alloc")]
129-
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};
122+
pub use self::weighted_index::{Weight, WeightError, WeightedIndex};
130123

131124
#[allow(unused)]
132125
use crate::Rng;

src/distributions/slice.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use alloc::string::String;
1515
/// [`Slice::new`] constructs a distribution referencing a slice and uniformly
1616
/// samples references from the items in the slice. It may do extra work up
1717
/// front to make sampling of multiple values faster; if only one sample from
18-
/// the slice is required, [`SliceRandom::choose`] can be more efficient.
18+
/// the slice is required, [`IndexedRandom::choose`] can be more efficient.
1919
///
2020
/// Steps are taken to avoid bias which might be present in naive
2121
/// implementations; for example `slice[rng.gen() % slice.len()]` samples from
@@ -25,7 +25,7 @@ use alloc::string::String;
2525
/// This distribution samples with replacement; each sample is independent.
2626
/// Sampling without replacement requires state to be retained, and therefore
2727
/// cannot be handled by a distribution; you should instead consider methods
28-
/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`].
28+
/// on [`IndexedRandom`], such as [`IndexedRandom::choose_multiple`].
2929
///
3030
/// # Example
3131
///
@@ -48,21 +48,21 @@ use alloc::string::String;
4848
/// assert!(vowel_string.chars().all(|c| vowels.contains(&c)));
4949
/// ```
5050
///
51-
/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose]
51+
/// For a single sample, [`IndexedRandom::choose`][crate::seq::IndexedRandom::choose]
5252
/// may be preferred:
5353
///
5454
/// ```
55-
/// use rand::seq::SliceRandom;
55+
/// use rand::seq::IndexedRandom;
5656
///
5757
/// let vowels = ['a', 'e', 'i', 'o', 'u'];
5858
/// let mut rng = rand::thread_rng();
5959
///
6060
/// println!("{}", vowels.choose(&mut rng).unwrap())
6161
/// ```
6262
///
63-
/// [`SliceRandom`]: crate::seq::SliceRandom
64-
/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose
65-
/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple
63+
/// [`IndexedRandom`]: crate::seq::IndexedRandom
64+
/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose
65+
/// [`IndexedRandom::choose_multiple`]: crate::seq::IndexedRandom::choose_multiple
6666
#[derive(Debug, Clone, Copy)]
6767
pub struct Slice<'a, T> {
6868
slice: &'a [T],

0 commit comments

Comments
 (0)