# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from contextlib import ExitStack, contextmanager
import functools

from jax import lax, random
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import Distribution
from numpyro.util import identity

_PYRO_STACK = []

CondIndepStackFrame = namedtuple('CondIndepStackFrame', ['name', 'dim', 'size'])

def apply_stack(msg):
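    """
    Send ``msg`` through the global handler stack: ``process_message`` is
    applied by each Messenger from the innermost (most recently entered)
    outward; after the site's value is computed, ``postprocess_message`` is
    applied to the same Messengers in reverse order.
    """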
    pointer = 0
    for pointer, handler in enumerate(reversed(_PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break
    if msg['value'] is None:
        if msg['type'] == 'sample':
            msg['value'], msg['intermediates'] = msg['fn'](*msg['args'],
                                                           sample_intermediates=True,
                                                           **msg['kwargs'])
        else:
            msg['value'] = msg['fn'](*msg['args'], **msg['kwargs'])

    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack
    # via the pointer variable from the process_message loop.
    for handler in _PYRO_STACK[-pointer - 1:]:
        handler.postprocess_message(msg)
    return msg

class Messenger(object):
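    """
    Base class for effect handlers. A ``Messenger`` pushes itself onto the
    global handler stack on ``__enter__`` and modifies messages flowing
    through the stack via ``process_message`` and ``postprocess_message``.

    **Example** (a minimal sketch of a custom handler; ``rescale`` and
    ``scale_factor`` are illustrative names, not part of the library, similar
    in spirit to :class:`~numpyro.handlers.scale`)::

        class rescale(Messenger):
            def __init__(self, fn=None, scale_factor=1.):
                self.scale_factor = scale_factor
                super().__init__(fn)

            def process_message(self, msg):
                if msg['type'] == 'sample':
                    scale = 1. if msg['scale'] is None else msg['scale']
                    msg['scale'] = scale * self.scale_factor
    """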
    def __init__(self, fn=None):
        if fn is not None and not callable(fn):
            raise ValueError("Expected `fn` to be a Python callable object; "
                             "instead found type(fn) = {}.".format(type(fn)))
        self.fn = fn
        functools.update_wrapper(self, fn, updated=[])

    def __enter__(self):
        _PYRO_STACK.append(self)

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            assert _PYRO_STACK[-1] is self
            _PYRO_STACK.pop()
        else:
            # NB: this mimics Pyro's exception handling: when the wrapped
            # function or enclosed block raises an exception, find this
            # handler's position in the stack, then remove it and everything
            # below it in the stack.
            if self in _PYRO_STACK:
                loc = _PYRO_STACK.index(self)
                del _PYRO_STACK[loc:]

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)

def sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None):
"""
Returns a random sample from the stochastic function `fn`. This can have
additional side effects when wrapped inside effect handlers like
:class:`~numpyro.handlers.substitute`.
.. note::
By design, `sample` primitive is meant to be used inside a NumPyro model.
Then :class:`~numpyro.handlers.seed` handler is used to inject a random
state to `fn`. In those situations, `rng_key` keyword will take no
effect.
:param str name: name of the sample site.
:param fn: a stochastic function that returns a sample.
:param numpy.ndarray obs: observed value
:param jax.random.PRNGKey rng_key: an optional random key for `fn`.
:param sample_shape: Shape of samples to be drawn.
:param dict infer: an optional dictionary containing additional information
for inference algorithms. For example, if `fn` is a discrete distribution,
setting `infer={'enumerate': 'parallel'}` to tell MCMC marginalize
this discrete latent site.
:return: sample from the stochastic `fn`.
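
    **Example** (a minimal sketch; when no effect handlers are active, the key
    can be passed directly via `rng_key`, as in the fast path below)::

        import jax.random as random
        import numpyro
        import numpyro.distributions as dist

        x = numpyro.sample('x', dist.Normal(0., 1.), rng_key=random.PRNGKey(0))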
"""
    # if there are no active Messengers, we just draw a sample and return it as expected:
    if not _PYRO_STACK:
        return fn(rng_key=rng_key, sample_shape=sample_shape)

    # Otherwise, we initialize a message...
    initial_msg = {
        'type': 'sample',
        'name': name,
        'fn': fn,
        'args': (),
        'kwargs': {'rng_key': rng_key, 'sample_shape': sample_shape},
        'value': obs,
        'scale': None,
        'is_observed': obs is not None,
        'intermediates': [],
        'cond_indep_stack': [],
        'infer': {} if infer is None else infer,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg['value']

def param(name, init_value=None, **kwargs):
"""
Annotate the given site as an optimizable parameter for use with
:mod:`jax.experimental.optimizers`. For an example of how `param` statements
can be used in inference algorithms, refer to :func:`~numpyro.svi.svi`.
:param str name: name of site.
:param numpy.ndarray init_value: initial value specified by the user. Note that
the onus of using this to initialize the optimizer is on the user /
inference algorithm, since there is no global parameter store in
NumPyro.
:param constraint: NumPyro constraint, defaults to ``constraints.real``.
:type constraint: numpyro.distributions.constraints.Constraint
:param int event_dim: (optional) number of rightmost dimensions unrelated
to batching. Dimension to the left of this will be considered batch
dimensions; if the param statement is inside a subsampled plate, then
corresponding batch dimensions of the parameter will be correspondingly
subsampled. If unspecified, all dimensions will be considered event
dims and no subsampling will be performed.
:return: value for the parameter. Unless wrapped inside a
handler like :class:`~numpyro.handlers.substitute`, this will simply
return the initial value.
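
    **Example** (a minimal sketch; outside of any effect handler, `param`
    simply returns its initial value)::

        import jax.numpy as jnp
        import numpyro

        w = numpyro.param('w', jnp.zeros(3))
        assert (w == 0).all()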
"""
    # if there are no active Messengers, we just return the initial value:
    if not _PYRO_STACK:
        return init_value

    # Otherwise, we initialize a message...
    initial_msg = {
        'type': 'param',
        'name': name,
        'fn': identity,
        'args': (init_value,),
        'kwargs': kwargs,
        'value': None,
        'scale': None,
        'cond_indep_stack': [],
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg['value']

def deterministic(name, value):
"""
Used to designate deterministic sites in the model. Note that most effect
handlers will not operate on deterministic sites (except
:func:`~numpyro.handlers.trace`), so deterministic sites should be
side-effect free. The use case for deterministic nodes is to record any
values in the model execution trace.
:param str name: name of the deterministic site.
:param numpy.ndarray value: deterministic value to record in the trace.
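
    **Example** (a minimal sketch; the recorded value appears in the trace
    under the site name)::

        import numpyro
        import numpyro.distributions as dist
        from numpyro.handlers import seed, trace

        def model():
            x = numpyro.sample('x', dist.Normal(0., 1.))
            numpyro.deterministic('x2', x ** 2)

        tr = trace(seed(model, rng_seed=0)).get_trace()
        # tr['x2']['value'] holds the recorded squared sample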
"""
    if not _PYRO_STACK:
        return value

    initial_msg = {
        'type': 'deterministic',
        'name': name,
        'value': value,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg['value']

def module(name, nn, input_shape=None):
"""
Declare a :mod:`~jax.experimental.stax` style neural network inside a
model so that its parameters are registered for optimization via
:func:`~numpyro.primitives.param` statements.
:param str name: name of the module to be registered.
:param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a :mod:`~jax.experimental.stax`
constructor function.
:param tuple input_shape: shape of the input taken by the
neural network.
:return: a `apply_fn` with bound parameters that takes an array
as an input and returns the neural network transformed output
array.
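
    **Example** (a minimal sketch; must run under a
    :class:`~numpyro.handlers.seed` handler so that :func:`prng_key` can
    provide a key to initialize the network parameters)::

        from jax.experimental import stax
        import numpyro

        def model(data):
            net = numpyro.module('net', stax.serial(stax.Dense(4), stax.Relu),
                                 input_shape=(-1, data.shape[-1]))
            return net(data)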
"""
    module_key = name + '$params'
    nn_init, nn_apply = nn
    nn_params = param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('A valid `input_shape` is needed to initialize the network.')
        rng_key = prng_key()
        _, nn_params = nn_init(rng_key, input_shape)
        param(module_key, nn_params)
    return functools.partial(nn_apply, nn_params)

def _subsample_fn(size, subsample_size, rng_key=None):
    assert rng_key is not None, "Missing random key to generate subsample indices."
    return random.permutation(rng_key, size)[:subsample_size]

class plate(Messenger):
"""
Construct for annotating conditionally independent variables. Within a
`plate` context manager, `sample` sites will be automatically broadcasted to
the size of the plate. Additionally, a scale factor might be applied by
certain inference algorithms if `subsample_size` is specified.
.. note:: This can be used to subsample minibatches of data:
.. code-block:: python
with plate("data", len(data), subsample_size=100) as ind:
batch = data[ind]
assert len(batch) == 100
:param str name: Name of the plate.
:param int size: Size of the plate.
:param int subsample_size: Optional argument denoting the size of the mini-batch.
This can be used to apply a scaling factor by inference algorithms. e.g.
when computing ELBO using a mini-batch.
:param int dim: Optional argument to specify which dimension in the tensor
is used as the plate dim. If `None` (default), the leftmost available dim
is allocated.
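
    **Example** (a minimal sketch of automatic broadcasting under nested
    plates; assumes ``import numpyro``, ``import numpyro.distributions as dist``,
    and an active :class:`~numpyro.handlers.seed` handler)::

        with numpyro.plate('outer', 5, dim=-2):
            with numpyro.plate('inner', 10):
                x = numpyro.sample('x', dist.Normal(0., 1.))
        # x has batch shape (5, 10)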
"""
    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError('`dim` arg must be negative.')
        self.dim, self._indices = self._subsample(
            self.name, self.size, subsample_size, dim)
        self.subsample_size = self._indices.shape[0]
        super(plate, self).__init__()

    # XXX: different from Pyro, this method returns dim and indices
    @staticmethod
    def _subsample(name, size, subsample_size, dim):
        msg = {
            'type': 'plate',
            'fn': _subsample_fn,
            'name': name,
            'args': (size, subsample_size),
            'kwargs': {'rng_key': None},
            'value': (None
                      if (subsample_size is not None and size != subsample_size)
                      else jnp.arange(size)),
            'scale': 1.0,
            'cond_indep_stack': [],
        }
        apply_stack(msg)
        subsample = msg['value']
        if subsample_size is not None and subsample_size != subsample.shape[0]:
            raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format(
                subsample_size, len(subsample)) +
                " Did you accidentally use different subsample_size in the model and guide?")
        cond_indep_stack = msg['cond_indep_stack']
        occupied_dims = {f.dim for f in cond_indep_stack}
        if dim is None:
            new_dim = -1
            while new_dim in occupied_dims:
                new_dim -= 1
            dim = new_dim
        else:
            assert dim not in occupied_dims
        return dim, subsample

    def __enter__(self):
        super().__enter__()
        return self._indices

    @staticmethod
    def _get_batch_shape(cond_indep_stack):
        n_dims = max(-f.dim for f in cond_indep_stack)
        batch_shape = [1] * n_dims
        for f in cond_indep_stack:
            batch_shape[f.dim] = f.size
        return tuple(batch_shape)

    def process_message(self, msg):
        if msg['type'] not in ('param', 'sample', 'plate'):
            if msg['type'] == 'control_flow':
                raise NotImplementedError('Cannot use control flow primitive under a `plate` primitive.'
                                          ' Please move those `plate` statements into the control flow'
                                          ' body function. See `scan` documentation for more information.')
            return

        cond_indep_stack = msg['cond_indep_stack']
        frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
        cond_indep_stack.append(frame)
        # only expand if fn is a Distribution, not a Funsor
        if msg['type'] == 'sample' and isinstance(msg['fn'], Distribution):
            expected_shape = self._get_batch_shape(cond_indep_stack)
            dist_batch_shape = msg['fn'].batch_shape
            if 'sample_shape' in msg['kwargs']:
                dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape
                msg['kwargs']['sample_shape'] = ()
            overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
            trailing_shape = expected_shape[overlap_idx:]
            broadcast_shape = lax.broadcast_shapes(trailing_shape, tuple(dist_batch_shape))
            batch_shape = expected_shape[:overlap_idx] + broadcast_shape
            msg['fn'] = msg['fn'].expand(batch_shape)
        if self.size != self.subsample_size:
            scale = 1. if msg['scale'] is None else msg['scale']
            msg['scale'] = scale * self.size / self.subsample_size

    def postprocess_message(self, msg):
        if msg["type"] in ("subsample", "param") and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                assert event_dim >= 0
                dim = self.dim - event_dim
                shape = jnp.shape(msg["value"])
                if len(shape) >= -dim and shape[dim] != 1:
                    if shape[dim] != self.size:
                        if msg["type"] == "param":
                            statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
                        else:
                            statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
                        raise ValueError(
                            "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                            .format(self.name, self.size, self.dim, statement, shape))
                    if self.subsample_size < self.size:
                        value = msg["value"]
                        new_value = jnp.take(value, self._indices, dim)
                        msg["value"] = new_value

@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
"""
Create a contiguous stack of :class:`plate` s with dimensions::
rightmost_dim - len(sizes), ..., rightmost_dim
:param str prefix: Name prefix for plates.
:param iterable sizes: An iterable of plate sizes.
:param int rightmost_dim: The rightmost dim, counting from the right.
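
    **Example** (a minimal sketch; assumes ``import numpyro``,
    ``import numpyro.distributions as dist``, and an active
    :class:`~numpyro.handlers.seed` handler)::

        with numpyro.plate_stack('p', [2, 3]):
            x = numpyro.sample('x', dist.Normal(0., 1.))
        # x has batch shape (2, 3)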
"""
    assert rightmost_dim < 0
    with ExitStack() as stack:
        for i, size in enumerate(reversed(sizes)):
            plate_i = plate("{}_{}".format(prefix, i), size, dim=rightmost_dim - i)
            stack.enter_context(plate_i)
        yield

def factor(name, log_factor):
"""
Factor statement to add arbitrary log probability factor to a
probabilistic model.
:param str name: Name of the trivial sample.
:param numpy.ndarray log_factor: A possibly batched log probability factor.
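
    **Example** (a minimal sketch of adding a custom log-density term)::

        def model(x):
            # down-weight large values of x in the joint density
            numpyro.factor('penalty', -x ** 2)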
"""
    unit_dist = numpyro.distributions.distribution.Unit(log_factor)
    unit_value = unit_dist.sample(None)
    sample(name, unit_dist, obs=unit_value)

def prng_key():
"""
A statement to draw a pseudo-random number generator key
:func:`~jax.random.PRNGKey` under :class:`~numpyro.handlers.seed` handler.
:return: a PRNG key of shape (2,) and dtype unit32.
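
    **Example** (a minimal sketch; outside of a :class:`~numpyro.handlers.seed`
    handler, `prng_key` returns `None`)::

        from numpyro.handlers import seed

        with seed(rng_seed=0):
            key = numpyro.prng_key()  # a fresh key split from the seeded state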
"""
    if not _PYRO_STACK:
        return

    initial_msg = {
        'type': 'prng_key',
        'fn': lambda rng_key: rng_key,
        'args': (),
        'kwargs': {'rng_key': None},
        'value': None,
    }

    msg = apply_stack(initial_msg)
    return msg['value']

def subsample(data, event_dim):
"""
EXPERIMENTAL Subsampling statement to subsample data based on enclosing
:class:`~numpyro.primitives.plate` s.
This is typically called on arguments to ``model()`` when subsampling is
performed automatically by :class:`~numpyro.primitives.plate` s by passing
``subsample_size`` kwarg. For example the following are equivalent::
# Version 1. using indexing
def model(data):
with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
data = data[ind]
# ...
# Version 2. using numpyro.subsample()
def model(data):
with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
data = numpyro.subsample(data, event_dim=0)
# ...
:param numpy.ndarray data: A tensor of batched data.
:param int event_dim: The event dimension of the data tensor. Dimensions to
the left are considered batch dimensions.
:returns: A subsampled version of ``data``
:rtype: ~numpy.ndarray
"""
    if not _PYRO_STACK:
        return data

    assert isinstance(event_dim, int) and event_dim >= 0
    initial_msg = {
        'type': 'subsample',
        'value': data,
        'kwargs': {'event_dim': event_dim},
    }

    msg = apply_stack(initial_msg)
    return msg['value']