Skip to content

Commit 339baaf

Browse files
committed
parameterize random samplers by PRNG implementation
1 parent db0ccc7 commit 339baaf

File tree

4 files changed

+265
-197
lines changed

4 files changed

+265
-197
lines changed

jax/_src/prng.py

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from functools import partial
17+
from typing import Any
1718

1819
import numpy as np
1920

@@ -30,31 +31,23 @@
3031
from jax._src.util import prod
3132

3233

34+
PRNG = Any
3335
UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
3436

3537

36-
def PRNGKey(seed: int) -> jnp.ndarray:
37-
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
38-
39-
Args:
40-
seed: a 64- or 32-bit integer used as the value of the key.
41-
42-
Returns:
43-
A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The
44-
key is constructed from a 64-bit seed by effectively bit-casting to a pair
45-
of uint32 values (or from a 32-bit seed by first padding out with zeros).
46-
"""
38+
def threefry_init(seed: int) -> jnp.ndarray:
4739
# Avoid overflowerror in X32 mode by first converting ints to int64.
48-
# This breaks JIT invariance of PRNGKey for large ints, but supports the
49-
# common use-case of instantiating PRNGKey with Python hashes in X32 mode.
40+
# This breaks JIT invariance of this init function for large ints,
41+
# but supports the common use-case of calling it with Python hashes
42+
# in X32 mode.
5043
if isinstance(seed, int):
5144
seed_arr = jnp.asarray(np.int64(seed))
5245
else:
5346
seed_arr = jnp.asarray(seed)
5447
if seed_arr.shape:
55-
raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
48+
raise TypeError(f"PRNG seed must be a scalar; got {seed!r}.")
5649
if not np.issubdtype(seed_arr.dtype, np.integer):
57-
raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}")
50+
raise TypeError(f"PRNG seed must be an integer; got {seed!r}")
5851

5952
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
6053
k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
@@ -238,50 +231,30 @@ def threefry_2x32(keypair, count):
238231
return lax.reshape(out[:-1] if odd_size else out, count.shape)
239232

240233

241-
def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
242-
"""Splits a PRNG key into `num` new keys by adding a leading axis.
243-
244-
Args:
245-
key: a PRNGKey (an array with shape (2,) and dtype uint32).
246-
num: optional, a positive integer indicating the number of keys to produce
247-
(default 2).
248-
249-
Returns:
250-
An array with shape (num, 2) and dtype uint32 representing `num` new keys.
251-
"""
252-
return _split(key, int(num)) # type: ignore
234+
def threefry_split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
235+
return _threefry_split(key, int(num)) # type: ignore
253236

254237

255238
@partial(jit, static_argnums=(1,))
256-
def _split(key, num) -> jnp.ndarray:
239+
def _threefry_split(key, num) -> jnp.ndarray:
257240
counts = lax.iota(np.uint32, num * 2)
258241
return lax.reshape(threefry_2x32(key, counts), (num, 2))
259242

260243

261-
def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
262-
"""Folds in data to a PRNG key to form a new PRNG key.
263-
264-
Args:
265-
key: a PRNGKey (an array with shape (2,) and dtype uint32).
266-
data: a 32bit integer representing data to be folded in to the key.
267-
268-
Returns:
269-
A new PRNGKey that is a deterministic function of the inputs and is
270-
statistically safe for producing a stream of new pseudo-random values.
271-
"""
272-
return _fold_in(key, jnp.uint32(data))
244+
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
245+
return _threefry_fold_in(key, jnp.uint32(data))
273246

274247

275248
@jit
276-
def _fold_in(key, data):
277-
return threefry_2x32(key, PRNGKey(data))
249+
def _threefry_fold_in(key, data):
250+
return threefry_2x32(key, threefry_init(data))
278251

279252

280253
@partial(jit, static_argnums=(1, 2))
281-
def _random_bits(key, bit_width, shape):
254+
def threefry_random_bits(key, bit_width, shape):
282255
"""Sample uniform random bits of given width and shape using PRNG key."""
283256
if not _is_prng_key(key):
284-
raise TypeError("_random_bits got invalid prng key.")
257+
raise TypeError("random_bits got invalid prng key.")
285258
if bit_width not in (8, 16, 32, 64):
286259
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
287260
shape = core.as_named_shape(shape)
@@ -291,15 +264,15 @@ def _random_bits(key, bit_width, shape):
291264
raise ValueError(f"The shape of axis {name} was specified as {size}, "
292265
f"but it really is {real_size}")
293266
axis_index = lax.axis_index(name)
294-
key = fold_in(key, axis_index)
267+
key = threefry_fold_in(key, axis_index)
295268
size = prod(shape.positional)
296269
max_count = int(np.ceil(bit_width * size / 32))
297270

298271
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
299272
if not nblocks:
300273
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
301274
else:
302-
*subkeys, last_key = split(key, nblocks + 1)
275+
*subkeys, last_key = threefry_split(key, nblocks + 1)
303276
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
304277
for k in subkeys]
305278
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
@@ -324,3 +297,10 @@ def _random_bits(key, bit_width, shape):
324297
bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
325298
bits = lax.convert_element_type(bits, dtype)[:size]
326299
return lax.reshape(bits, shape)
300+
301+
302+
class threefry_prng:
303+
init = threefry_init
304+
fold_in = threefry_fold_in
305+
random_bits = threefry_random_bits
306+
split = threefry_split

0 commit comments

Comments
 (0)