14
14
15
15
16
16
from functools import partial
17
+ from typing import Any
17
18
18
19
import numpy as np
19
20
30
31
from jax ._src .util import prod
31
32
32
33
34
+ PRNG = Any
33
35
UINT_DTYPES = {8 : jnp .uint8 , 16 : jnp .uint16 , 32 : jnp .uint32 , 64 : jnp .uint64 }
34
36
35
37
36
- def PRNGKey (seed : int ) -> jnp .ndarray :
37
- """Create a pseudo-random number generator (PRNG) key given an integer seed.
38
-
39
- Args:
40
- seed: a 64- or 32-bit integer used as the value of the key.
41
-
42
- Returns:
43
- A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The
44
- key is constructed from a 64-bit seed by effectively bit-casting to a pair
45
- of uint32 values (or from a 32-bit seed by first padding out with zeros).
46
- """
38
+ def threefry_init (seed : int ) -> jnp .ndarray :
47
39
# Avoid overflowerror in X32 mode by first converting ints to int64.
48
- # This breaks JIT invariance of PRNGKey for large ints, but supports the
49
- # common use-case of instantiating PRNGKey with Python hashes in X32 mode.
40
+ # This breaks JIT invariance of this init function for large ints,
41
+ # but supports the common use-case of calling it with Python hashes
42
+ # in X32 mode.
50
43
if isinstance (seed , int ):
51
44
seed_arr = jnp .asarray (np .int64 (seed ))
52
45
else :
53
46
seed_arr = jnp .asarray (seed )
54
47
if seed_arr .shape :
55
- raise TypeError (f"PRNGKey seed must be a scalar; got { seed !r} ." )
48
+ raise TypeError (f"PRNG seed must be a scalar; got { seed !r} ." )
56
49
if not np .issubdtype (seed_arr .dtype , np .integer ):
57
- raise TypeError (f"PRNGKey seed must be an integer; got { seed !r} " )
50
+ raise TypeError (f"PRNG seed must be an integer; got { seed !r} " )
58
51
59
52
convert = lambda k : lax .reshape (lax .convert_element_type (k , np .uint32 ), [1 ])
60
53
k1 = convert (lax .shift_right_logical (seed_arr , lax ._const (seed_arr , 32 )))
@@ -238,50 +231,30 @@ def threefry_2x32(keypair, count):
238
231
return lax .reshape (out [:- 1 ] if odd_size else out , count .shape )
239
232
240
233
241
- def split (key : jnp .ndarray , num : int = 2 ) -> jnp .ndarray :
242
- """Splits a PRNG key into `num` new keys by adding a leading axis.
243
-
244
- Args:
245
- key: a PRNGKey (an array with shape (2,) and dtype uint32).
246
- num: optional, a positive integer indicating the number of keys to produce
247
- (default 2).
248
-
249
- Returns:
250
- An array with shape (num, 2) and dtype uint32 representing `num` new keys.
251
- """
252
- return _split (key , int (num )) # type: ignore
234
+ def threefry_split (key : jnp .ndarray , num : int = 2 ) -> jnp .ndarray :
235
+ return _threefry_split (key , int (num )) # type: ignore
253
236
254
237
255
238
@partial (jit , static_argnums = (1 ,))
256
- def _split (key , num ) -> jnp .ndarray :
239
+ def _threefry_split (key , num ) -> jnp .ndarray :
257
240
counts = lax .iota (np .uint32 , num * 2 )
258
241
return lax .reshape (threefry_2x32 (key , counts ), (num , 2 ))
259
242
260
243
261
- def fold_in (key : jnp .ndarray , data : int ) -> jnp .ndarray :
262
- """Folds in data to a PRNG key to form a new PRNG key.
263
-
264
- Args:
265
- key: a PRNGKey (an array with shape (2,) and dtype uint32).
266
- data: a 32bit integer representing data to be folded in to the key.
267
-
268
- Returns:
269
- A new PRNGKey that is a deterministic function of the inputs and is
270
- statistically safe for producing a stream of new pseudo-random values.
271
- """
272
- return _fold_in (key , jnp .uint32 (data ))
244
+ def threefry_fold_in (key : jnp .ndarray , data : int ) -> jnp .ndarray :
245
+ return _threefry_fold_in (key , jnp .uint32 (data ))
273
246
274
247
275
248
@jit
276
- def _fold_in (key , data ):
277
- return threefry_2x32 (key , PRNGKey (data ))
249
+ def _threefry_fold_in (key , data ):
250
+ return threefry_2x32 (key , threefry_init (data ))
278
251
279
252
280
253
@partial (jit , static_argnums = (1 , 2 ))
281
- def _random_bits (key , bit_width , shape ):
254
+ def threefry_random_bits (key , bit_width , shape ):
282
255
"""Sample uniform random bits of given width and shape using PRNG key."""
283
256
if not _is_prng_key (key ):
284
- raise TypeError ("_random_bits got invalid prng key." )
257
+ raise TypeError ("random_bits got invalid prng key." )
285
258
if bit_width not in (8 , 16 , 32 , 64 ):
286
259
raise TypeError ("requires 8-, 16-, 32- or 64-bit field width." )
287
260
shape = core .as_named_shape (shape )
@@ -291,15 +264,15 @@ def _random_bits(key, bit_width, shape):
291
264
raise ValueError (f"The shape of axis { name } was specified as { size } , "
292
265
f"but it really is { real_size } " )
293
266
axis_index = lax .axis_index (name )
294
- key = fold_in (key , axis_index )
267
+ key = threefry_fold_in (key , axis_index )
295
268
size = prod (shape .positional )
296
269
max_count = int (np .ceil (bit_width * size / 32 ))
297
270
298
271
nblocks , rem = divmod (max_count , jnp .iinfo (np .uint32 ).max )
299
272
if not nblocks :
300
273
bits = threefry_2x32 (key , lax .iota (np .uint32 , rem ))
301
274
else :
302
- * subkeys , last_key = split (key , nblocks + 1 )
275
+ * subkeys , last_key = threefry_split (key , nblocks + 1 )
303
276
blocks = [threefry_2x32 (k , lax .iota (np .uint32 , jnp .iinfo (np .uint32 ).max ))
304
277
for k in subkeys ]
305
278
last = threefry_2x32 (last_key , lax .iota (np .uint32 , rem ))
@@ -324,3 +297,10 @@ def _random_bits(key, bit_width, shape):
324
297
bits = lax .reshape (bits , (np .uint32 (max_count * 32 // bit_width ),), (1 , 0 ))
325
298
bits = lax .convert_element_type (bits , dtype )[:size ]
326
299
return lax .reshape (bits , shape )
300
+
301
+
302
+ class threefry_prng :
303
+ init = threefry_init
304
+ fold_in = threefry_fold_in
305
+ random_bits = threefry_random_bits
306
+ split = threefry_split
0 commit comments