@@ -279,18 +279,17 @@ def sample(
         The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded
         by default. See ``discard_tuned_samples``.
     init : str
-        Initialization method to use for auto-assigned NUTS samplers.
-        See `pm.init_nuts` for a list of all options.
+        Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
+        of all options. This argument is ignored when manually passing the NUTS step method.
     step : function or iterable of functions
         A step function or collection of functions. If there are variables without step methods,
-        step methods for those variables will be assigned automatically. By default the NUTS step
-        method will be used, if appropriate to the model; this is a good default for beginning
-        users.
+        step methods for those variables will be assigned automatically. By default the NUTS step
+        method will be used, if appropriate to the model.
     n_init : int
         Number of iterations of initializer. Only works for 'ADVI' init methods.
     initvals : optional, dict, array of dict
-        Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`.
-        The keys should be names of transformed random variables.
+        Dict or list of dicts with initial value strategies to use instead of the defaults from
+        `Model.initial_values`. The keys should be names of transformed random variables.
         Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
     trace : backend or list
         This should be a backend instance, or a list of variables to track.
@@ -317,8 +316,8 @@ def sample(
     model : Model (optional if in ``with`` context)
         Model to sample from. The model needs to have free random variables.
     random_seed : int or list of ints
-        Random seed(s) used by the sampling steps. A list is accepted if
-        ``cores`` is greater than one.
+        Random seed(s) used by the sampling steps. A list is accepted if ``cores`` is greater than
+        one.
     discard_tuned_samples : bool
         Whether to discard posterior samples of the tune interval.
     compute_convergence_checks : bool, default=True
@@ -330,17 +329,17 @@ def sample(
         is drawn from.
         Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
     jitter_max_retries : int
-        Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
-        that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
-        init methods.
+        Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
+        jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
+        ``jitter+adapt_full`` init methods.
     return_inferencedata : bool
-        Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
-        Defaults to `True`.
+        Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
+        `MultiTrace` (False). Defaults to `True`.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`pymc.to_inference_data`
     mp_ctx : multiprocessing.context.BaseContext
-        A multiprocessing context for parallel sampling. See multiprocessing
-        documentation for details.
+        A multiprocessing context for parallel sampling.
+        See multiprocessing documentation for details.

     Returns
     -------
@@ -352,37 +351,28 @@ def sample(
     Optional keyword arguments can be passed to ``sample`` to be delivered to the
     ``step_method``\ s used during sampling.

-    If your model uses only one step method, you can address step method kwargs
-    directly. In particular, the NUTS step method has several options including:
+    For example:
+
+       1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
+       2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
+
+    Note that available step names are:
+
+    ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
+    ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
+    ``DEMetropolis``, ``DEMetropolisZ``, ``slice``
+
+    The NUTS step method has several options including:

     * target_accept : float in [0, 1]. The step size is tuned such that we
       approximate this acceptance rate. Higher values like 0.9 or 0.95 often
-      work better for problematic posteriors
+      work better for problematic posteriors. This argument can be passed directly to sample.

     * max_treedepth : The maximum depth of the trajectory tree
     * step_scale : float, default 0.25
       The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
       where n is the dimensionality of the parameter space

-    If your model uses multiple step methods, aka a Compound Step, then you have
-    two ways to address arguments to each step method:
-
-    A. If you let ``sample()`` automatically assign the ``step_method``\ s,
-       and you can correctly anticipate what they will be, then you can wrap
-       step method kwargs in a dict and pass that to sample() with a kwarg set
-       to the name of the step method.
-       e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
-       you could send:
-
-       1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
-       2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
-
-       Note that available names are:
-
-       ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
-       ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
-       ``DEMetropolis``, ``DEMetropolisZ``, ``slice``
-
-    B. If you manually declare the ``step_method``\ s, within the ``step``
+    Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
        kwarg, then you can address the ``step_method`` kwargs directly.
        e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
        you could send ::
@@ -422,6 +412,8 @@ def sample(
             stacklevel=2,
         )
         initvals = kwargs.pop("start")
+    if "target_accept" in kwargs:
+        kwargs.setdefault("nuts", {"target_accept": kwargs.pop("target_accept")})

     model = modelcontext(model)
     if not model.free_RVs:
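The shim added in this hunk can be exercised in isolation. The sketch below (plain Python with an illustrative function name, not the PyMC API) mirrors the kwarg handling and shows one consequence of using ``setdefault``: when both an explicit ``nuts`` dict and a top-level ``target_accept`` are passed, the explicit dict wins and ``target_accept`` is silently dropped.

```python
def forward_target_accept(**kwargs):
    # Mirror the shim: lift a top-level ``target_accept`` into the ``nuts`` dict.
    # ``setdefault`` leaves an existing ``nuts`` dict untouched, while ``pop``
    # removes ``target_accept`` from the top level either way.
    if "target_accept" in kwargs:
        kwargs.setdefault("nuts", {"target_accept": kwargs.pop("target_accept")})
    return kwargs

print(forward_target_accept(target_accept=0.9))
print(forward_target_accept(nuts={"target_accept": 0.8}, target_accept=0.9))
```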
@@ -466,11 +458,37 @@ def sample(
     draws += tune

+    auto_nuts_init = True
+    if step is not None:
+        if isinstance(step, CompoundStep):
+            for method in step.methods:
+                if isinstance(method, NUTS):
+                    auto_nuts_init = False
+        elif isinstance(step, NUTS):
+            auto_nuts_init = False
+
     initial_points = None
     step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)

     if isinstance(step, list):
         step = CompoundStep(step)
+    elif isinstance(step, NUTS) and auto_nuts_init:
+        if "nuts" in kwargs:
+            nuts_kwargs = kwargs.pop("nuts")
+            [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
+        _log.info("Auto-assigning NUTS sampler...")
+        initial_points, step = init_nuts(
+            init=init,
+            chains=chains,
+            n_init=n_init,
+            model=model,
+            seeds=random_seed,
+            progressbar=progressbar,
+            jitter_max_retries=jitter_max_retries,
+            tune=tune,
+            initvals=initvals,
+            **kwargs,
+        )

     if initial_points is None:
         # Time to draw/evaluate numeric start points for each chain.
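The NUTS-detection walk added above can be sketched with stand-in classes (the ``NUTS``, ``Metropolis``, and ``CompoundStep`` below are dummies for illustration, not the PyMC types, and the helper name is invented):

```python
class NUTS: ...
class Metropolis: ...

class CompoundStep:
    def __init__(self, methods):
        self.methods = methods

def needs_auto_nuts_init(step):
    # Auto-initialization only applies when the user did NOT already supply
    # a NUTS step method, either directly or inside a CompoundStep.
    if step is None:
        return True
    if isinstance(step, CompoundStep):
        return not any(isinstance(m, NUTS) for m in step.methods)
    return not isinstance(step, NUTS)

print(needs_auto_nuts_init(None))                               # no step given
print(needs_auto_nuts_init(NUTS()))                             # user-supplied NUTS
print(needs_auto_nuts_init(CompoundStep([Metropolis(), NUTS()])))
```

This is also why the reworded ``init`` docstring says the argument is ignored when a NUTS step method is passed manually.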
@@ -2129,7 +2147,7 @@ def draw(
 def _init_jitter(
     model: Model,
     initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
-    seeds: Sequence[int],
+    seeds: Union[List[Any], Tuple[Any, ...], np.ndarray],
     jitter: bool,
     jitter_max_retries: int,
 ) -> List[PointType]:
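The loosened ``seeds`` annotation likely accommodates NumPy arrays, which support indexing and ``len()`` but are not registered as ``collections.abc.Sequence``, so a strict ``Sequence[int]`` annotation would mislead type checkers about valid inputs. A quick check (assuming NumPy is installed):

```python
from collections.abc import Sequence

import numpy as np

seeds = np.array([1, 2, 3])
# ndarray quacks like a sequence but is not an abc.Sequence
print(isinstance(seeds, Sequence), isinstance([1, 2, 3], Sequence))
```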
@@ -2186,7 +2204,7 @@ def init_nuts(
     chains: int = 1,
     n_init: int = 500_000,
     model=None,
-    seeds: Sequence[int] = None,
+    seeds: Iterable[Any] = None,
     progressbar=True,
     jitter_max_retries: int = 10,
     tune: Optional[int] = None,
@@ -2262,8 +2280,7 @@ def init_nuts(
     if not isinstance(init, str):
         raise TypeError("init must be a string.")

-    if init is not None:
-        init = init.lower()
+    init = init.lower()

     if init == "auto":
         init = "jitter+adapt_diag"
@@ -2333,7 +2350,8 @@ def init_nuts(
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
         )
-        initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
+        approx_sample = approx.sample(draws=chains, return_inferencedata=False)
+        initial_points = [approx_sample[i] for i in range(chains)]
         std_apoint = approx.std.eval()
         cov = std_apoint**2
         mean = approx.mean.get_value()
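The switch from ``list(approx.sample(...))`` to explicit indexing (applied identically in the two following hunks) matters whenever an object's iteration protocol and ``__getitem__`` disagree. This stand-in class is a dummy for illustration, not the real trace type returned by ``approx.sample``:

```python
class FakeTrace:
    """Dummy trace: indexing returns one point dict per draw,
    while iteration yields something else entirely (e.g. variable names)."""

    def __init__(self, points):
        self._points = points

    def __getitem__(self, i):
        return self._points[i]

    def __iter__(self):
        return iter(["x", "y"])  # variable names, not draws

trace = FakeTrace([{"x": 0.1}, {"x": 0.2}])
chains = 2
# Explicit indexing, as in the new code: one start point per chain
initial_points = [trace[i] for i in range(chains)]
print(initial_points)
print(list(trace))  # list() would follow __iter__ instead
```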
@@ -2350,7 +2368,8 @@ def init_nuts(
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
         )
-        initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
+        approx_sample = approx.sample(draws=chains, return_inferencedata=False)
+        initial_points = [approx_sample[i] for i in range(chains)]
         cov = approx.std.eval() ** 2
         potential = quadpotential.QuadPotentialDiag(cov)
     elif init == "advi_map":
@@ -2364,7 +2383,8 @@ def init_nuts(
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
         )
-        initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
+        approx_sample = approx.sample(draws=chains, return_inferencedata=False)
+        initial_points = [approx_sample[i] for i in range(chains)]
         cov = approx.std.eval() ** 2
         potential = quadpotential.QuadPotentialDiag(cov)
     elif init == "map":