Skip to content

Commit acdd6d7

Browse files
committed
Add Unix fork protection
1 parent 5953334 commit acdd6d7

File tree

1 file changed

+60
-9
lines changed

1 file changed

+60
-9
lines changed

src/rngs/adapter/reseeding.rs

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ use core::mem::size_of;
1616
use rand_core::{RngCore, CryptoRng, SeedableRng, Error, ErrorKind};
1717
use rand_core::block::{BlockRngCore, BlockRng};
1818

19+
#[cfg(all(feature="std", unix))]
20+
extern crate libc;
21+
1922
/// A wrapper around any PRNG which reseeds the underlying PRNG after it has
2023
/// generated a certain number of random bytes.
2124
///
@@ -126,6 +129,7 @@ struct ReseedingCore<R, Rsdr> {
126129
reseeder: Rsdr,
127130
threshold: i64,
128131
bytes_until_reseed: i64,
132+
fork_counter: u64,
129133
}
130134

131135
impl<R, Rsdr> BlockRngCore for ReseedingCore<R, Rsdr>
@@ -136,8 +140,9 @@ where R: BlockRngCore + SeedableRng,
136140
type Results = <R as BlockRngCore>::Results;
137141

138142
fn generate(&mut self, results: &mut Self::Results) {
139-
if self.bytes_until_reseed <= 0 {
140-
// We get better performance by not calling only `auto_reseed` here
143+
if self.bytes_until_reseed <= 0 ||
144+
self.fork_counter < get_fork_counter() {
145+
// We get better performance by not calling only `reseed` here
141146
// and continuing with the rest of the function, but by directly
142147
// returning from a non-inlined function.
143148
return self.reseed_and_generate(results);
@@ -161,11 +166,14 @@ where R: BlockRngCore + SeedableRng,
161166
/// * `reseeder`: the RNG to use for reseeding.
162167
pub fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self {
163168
assert!(threshold <= ::core::i64::MAX as u64);
169+
register_fork_handler();
170+
164171
ReseedingCore {
165172
inner: rng,
166173
reseeder,
167174
threshold: threshold as i64,
168175
bytes_until_reseed: threshold as i64,
176+
fork_counter: 0,
169177
}
170178
}
171179

@@ -181,9 +189,15 @@ where R: BlockRngCore + SeedableRng,
181189
fn reseed_and_generate(&mut self,
182190
results: &mut <Self as BlockRngCore>::Results)
183191
{
184-
trace!("Reseeding RNG after {} generated bytes",
185-
self.threshold - self.bytes_until_reseed);
186-
let threshold = if let Err(e) = self.reseed() {
192+
let fork_counter = get_fork_counter();
193+
if self.fork_counter < fork_counter {
194+
warn!("Fork detected, reseeding RNG");
195+
} else {
196+
trace!("Reseeding RNG after {} generated bytes",
197+
self.threshold - self.bytes_until_reseed);
198+
}
199+
200+
let threshold = if let Err(e) = self.reseed() {
187201
let delay = match e.kind {
188202
ErrorKind::Transient => 0,
189203
kind @ _ if kind.should_retry() => self.threshold >> 8,
@@ -193,11 +207,13 @@ where R: BlockRngCore + SeedableRng,
193207
error from source: {}", delay, e);
194208
delay
195209
} else {
196-
self.threshold
210+
let num_bytes =
211+
results.as_ref().len() * size_of::<<R as BlockRngCore>::Item>();
212+
self.fork_counter = fork_counter;
213+
self.threshold - num_bytes as i64
197214
};
198-
199-
let num_bytes = results.as_ref().len() * size_of::<<R as BlockRngCore>::Item>();
200-
self.bytes_until_reseed = threshold - num_bytes as i64;
215+
216+
self.bytes_until_reseed = threshold;
201217
self.inner.generate(results);
202218
}
203219
}
@@ -212,6 +228,7 @@ where R: BlockRngCore + SeedableRng + Clone,
212228
reseeder: self.reseeder.clone(),
213229
threshold: self.threshold,
214230
bytes_until_reseed: 0, // reseed clone on first use
231+
fork_counter: self.fork_counter,
215232
}
216233
}
217234
}
@@ -220,6 +237,40 @@ impl<R, Rsdr> CryptoRng for ReseedingCore<R, Rsdr>
220237
where R: BlockRngCore + SeedableRng + CryptoRng,
221238
Rsdr: RngCore + CryptoRng {}
222239

240+
241+
// Fork protection
242+
//
243+
// We implement fork protection on Unix using `pthread_atfork`.
244+
// When the process is forked, we increment `RESEEDING_RNG_FORK_COUNTER`.
245+
// Every `ReseedingRng` stores the last known value of the static in
246+
// `fork_counter`. If the cached `fork_counter` is less than
247+
// `RESEEDING_RNG_FORK_COUNTER`, it is time to reseed this RNG.
248+
//
249+
// If reseeding fails, we don't deal with this by setting a delay, but just
250+
// don't update `fork_counter`, so a reseed is attempted a soon as possible.
251+
#[cfg(all(feature="std", unix))]
252+
static mut RESEEDING_RNG_FORK_COUNTER: u64 = 0;
253+
254+
#[cfg(all(feature="std", unix))]
255+
extern fn forkhandler() {
256+
unsafe { RESEEDING_RNG_FORK_COUNTER += 1; }
257+
}
258+
259+
#[cfg(all(feature="std", unix))]
260+
fn register_fork_handler() {
261+
unsafe {
262+
libc::pthread_atfork(None, None, Some(forkhandler));
263+
}
264+
}
265+
#[cfg(not(all(feature="std", unix)))]
266+
fn register_fork_handler() {}
267+
268+
#[cfg(all(feature="std", unix))]
269+
fn get_fork_counter() -> u64 { unsafe { RESEEDING_RNG_FORK_COUNTER } }
270+
#[cfg(not(all(feature="std", unix)))]
271+
fn get_fork_counter() -> u64 { 0 }
272+
273+
223274
#[cfg(test)]
224275
mod test {
225276
use {Rng, SeedableRng};

0 commit comments

Comments
 (0)