Skip to content

Commit 5b31ec7

Browse files
authored
Fix bug with default init_nuts (#5622)
1 parent 436f94a commit 5b31ec7

File tree

6 files changed

+116
-74
lines changed

6 files changed

+116
-74
lines changed

pymc/sampling.py

Lines changed: 66 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -279,18 +279,17 @@ def sample(
279279
The number of samples to draw. Defaults to 1000. Tuned samples are discarded
280280
by default. See ``discard_tuned_samples``.
281281
init : str
282-
Initialization method to use for auto-assigned NUTS samplers.
283-
See `pm.init_nuts` for a list of all options.
282+
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
283+
of all options. This argument is ignored when manually passing the NUTS step method.
284284
step : function or iterable of functions
285285
A step function or collection of functions. If there are variables without step methods,
286-
step methods for those variables will be assigned automatically. By default the NUTS step
287-
method will be used, if appropriate to the model; this is a good default for beginning
288-
users.
286+
step methods for those variables will be assigned automatically. By default the NUTS step
287+
method will be used, if appropriate to the model.
289288
n_init : int
290289
Number of iterations of initializer. Only works for 'ADVI' init methods.
291290
initvals : optional, dict, array of dict
292-
Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`.
293-
The keys should be names of transformed random variables.
291+
Dict or list of dicts with initial value strategies to use instead of the defaults from
292+
`Model.initial_values`. The keys should be names of transformed random variables.
294293
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
295294
trace : backend or list
296295
This should be a backend instance, or a list of variables to track.
@@ -317,8 +316,8 @@ def sample(
317316
model : Model (optional if in ``with`` context)
318317
Model to sample from. The model needs to have free random variables.
319318
random_seed : int or list of ints
320-
Random seed(s) used by the sampling steps. A list is accepted if
321-
``cores`` is greater than one.
319+
Random seed(s) used by the sampling steps. A list is accepted if ``cores`` is greater than
320+
one.
322321
discard_tuned_samples : bool
323322
Whether to discard posterior samples of the tune interval.
324323
compute_convergence_checks : bool, default=True
@@ -330,17 +329,17 @@ def sample(
330329
is drawn from.
331330
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
332331
jitter_max_retries : int
333-
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
334-
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
335-
init methods.
332+
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
333+
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
334+
``jitter+adapt_full`` init methods.
336335
return_inferencedata : bool
337-
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
338-
Defaults to `True`.
336+
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
337+
`MultiTrace` (False). Defaults to `True`.
339338
idata_kwargs : dict, optional
340339
Keyword arguments for :func:`pymc.to_inference_data`
341340
mp_ctx : multiprocessing.context.BaseContext
342-
A multiprocessing context for parallel sampling. See multiprocessing
343-
documentation for details.
341+
A multiprocessing context for parallel sampling.
342+
See multiprocessing documentation for details.
344343
345344
Returns
346345
-------
@@ -352,37 +351,28 @@ def sample(
352351
Optional keyword arguments can be passed to ``sample`` to be delivered to the
353352
``step_method``\ s used during sampling.
354353
355-
If your model uses only one step method, you can address step method kwargs
356-
directly. In particular, the NUTS step method has several options including:
354+
For example:
355+
356+
1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
357+
2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
358+
359+
Note that available step names are:
360+
361+
``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
362+
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
363+
``DEMetropolis``, ``DEMetropolisZ``, ``slice``
364+
365+
The NUTS step method has several options including:
357366
358367
* target_accept : float in [0, 1]. The step size is tuned such that we
359368
approximate this acceptance rate. Higher values like 0.9 or 0.95 often
360-
work better for problematic posteriors
369+
work better for problematic posteriors. This argument can be passed directly to sample.
361370
* max_treedepth : The maximum depth of the trajectory tree
362371
* step_scale : float, default 0.25
363372
The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
364373
where n is the dimensionality of the parameter space
365374
366-
If your model uses multiple step methods, aka a Compound Step, then you have
367-
two ways to address arguments to each step method:
368-
369-
A. If you let ``sample()`` automatically assign the ``step_method``\ s,
370-
and you can correctly anticipate what they will be, then you can wrap
371-
step method kwargs in a dict and pass that to sample() with a kwarg set
372-
to the name of the step method.
373-
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
374-
you could send:
375-
376-
1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
377-
2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
378-
379-
Note that available names are:
380-
381-
``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
382-
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
383-
``DEMetropolis``, ``DEMetropolisZ``, ``slice``
384-
385-
B. If you manually declare the ``step_method``\ s, within the ``step``
375+
Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
386376
kwarg, then you can address the ``step_method`` kwargs directly.
387377
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
388378
you could send ::
@@ -422,6 +412,8 @@ def sample(
422412
stacklevel=2,
423413
)
424414
initvals = kwargs.pop("start")
415+
if "target_accept" in kwargs:
416+
kwargs.setdefault("nuts", {"target_accept": kwargs.pop("target_accept")})
425417

426418
model = modelcontext(model)
427419
if not model.free_RVs:
@@ -466,11 +458,37 @@ def sample(
466458

467459
draws += tune
468460

461+
auto_nuts_init = True
462+
if step is not None:
463+
if isinstance(step, CompoundStep):
464+
for method in step.methods:
465+
if isinstance(method, NUTS):
466+
auto_nuts_init = False
467+
elif isinstance(step, NUTS):
468+
auto_nuts_init = False
469+
469470
initial_points = None
470471
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
471472

472473
if isinstance(step, list):
473474
step = CompoundStep(step)
475+
elif isinstance(step, NUTS) and auto_nuts_init:
476+
if "nuts" in kwargs:
477+
nuts_kwargs = kwargs.pop("nuts")
478+
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
479+
_log.info("Auto-assigning NUTS sampler...")
480+
initial_points, step = init_nuts(
481+
init=init,
482+
chains=chains,
483+
n_init=n_init,
484+
model=model,
485+
seeds=random_seed,
486+
progressbar=progressbar,
487+
jitter_max_retries=jitter_max_retries,
488+
tune=tune,
489+
initvals=initvals,
490+
**kwargs,
491+
)
474492

475493
if initial_points is None:
476494
# Time to draw/evaluate numeric start points for each chain.
@@ -2129,7 +2147,7 @@ def draw(
21292147
def _init_jitter(
21302148
model: Model,
21312149
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
2132-
seeds: Sequence[int],
2150+
seeds: Union[List[Any], Tuple[Any, ...], np.ndarray],
21332151
jitter: bool,
21342152
jitter_max_retries: int,
21352153
) -> List[PointType]:
@@ -2186,7 +2204,7 @@ def init_nuts(
21862204
chains: int = 1,
21872205
n_init: int = 500_000,
21882206
model=None,
2189-
seeds: Sequence[int] = None,
2207+
seeds: Iterable[Any] = None,
21902208
progressbar=True,
21912209
jitter_max_retries: int = 10,
21922210
tune: Optional[int] = None,
@@ -2262,8 +2280,7 @@ def init_nuts(
22622280
if not isinstance(init, str):
22632281
raise TypeError("init must be a string.")
22642282

2265-
if init is not None:
2266-
init = init.lower()
2283+
init = init.lower()
22672284

22682285
if init == "auto":
22692286
init = "jitter+adapt_diag"
@@ -2333,7 +2350,8 @@ def init_nuts(
23332350
progressbar=progressbar,
23342351
obj_optimizer=pm.adagrad_window,
23352352
)
2336-
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
2353+
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
2354+
initial_points = [approx_sample[i] for i in range(chains)]
23372355
std_apoint = approx.std.eval()
23382356
cov = std_apoint**2
23392357
mean = approx.mean.get_value()
@@ -2350,7 +2368,8 @@ def init_nuts(
23502368
progressbar=progressbar,
23512369
obj_optimizer=pm.adagrad_window,
23522370
)
2353-
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
2371+
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
2372+
initial_points = [approx_sample[i] for i in range(chains)]
23542373
cov = approx.std.eval() ** 2
23552374
potential = quadpotential.QuadPotentialDiag(cov)
23562375
elif init == "advi_map":
@@ -2364,7 +2383,8 @@ def init_nuts(
23642383
progressbar=progressbar,
23652384
obj_optimizer=pm.adagrad_window,
23662385
)
2367-
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
2386+
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
2387+
initial_points = [approx_sample[i] for i in range(chains)]
23682388
cov = approx.std.eval() ** 2
23692389
potential = quadpotential.QuadPotentialDiag(cov)
23702390
elif init == "map":

pymc/tests/test_data_container.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_sample(self):
5050
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)
5151

5252
prior_trace0 = pm.sample_prior_predictive(1000)
53-
idata = pm.sample(1000, init=None, tune=1000, chains=1)
53+
idata = pm.sample(1000, tune=1000, chains=1)
5454
pp_trace0 = pm.sample_posterior_predictive(idata)
5555

5656
x_shared.set_value(x_pred)
@@ -103,7 +103,6 @@ def test_sample_after_set_data(self):
103103
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
104104
pm.sample(
105105
1000,
106-
init=None,
107106
tune=1000,
108107
chains=1,
109108
compute_convergence_checks=False,
@@ -115,7 +114,6 @@ def test_sample_after_set_data(self):
115114
pm.set_data(new_data={"x": new_x, "y": new_y})
116115
new_idata = pm.sample(
117116
1000,
118-
init=None,
119117
tune=1000,
120118
chains=1,
121119
compute_convergence_checks=False,
@@ -141,7 +139,6 @@ def test_shared_data_as_index(self):
141139
prior_trace = pm.sample_prior_predictive(1000)
142140
idata = pm.sample(
143141
1000,
144-
init=None,
145142
tune=1000,
146143
chains=1,
147144
compute_convergence_checks=False,
@@ -236,7 +233,6 @@ def test_set_data_to_non_data_container_variables(self):
236233
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
237234
pm.sample(
238235
1000,
239-
init=None,
240236
tune=1000,
241237
chains=1,
242238
compute_convergence_checks=False,
@@ -255,7 +251,6 @@ def test_model_to_graphviz_for_model_with_data_container(self):
255251
pm.Normal("obs", beta * x, obs_sigma, observed=y)
256252
pm.sample(
257253
1000,
258-
init=None,
259254
tune=1000,
260255
chains=1,
261256
compute_convergence_checks=False,

pymc/tests/test_quadpotential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def energy(self, x, velocity=None):
150150
pot = Potential(floatX([1]))
151151
with model:
152152
step = pymc.NUTS(potential=pot)
153-
pymc.sample(10, init=None, step=step, chains=1)
153+
pymc.sample(10, step=step, chains=1)
154154
assert called
155155

156156

pymc/tests/test_sampling.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,13 @@ def test_sample_init(self):
123123
def test_sample_args(self):
124124
with self.model:
125125
with pytest.raises(ValueError) as excinfo:
126-
pm.sample(50, tune=0, init=None, foo=1)
126+
pm.sample(50, tune=0, foo=1)
127127
assert "'foo'" in str(excinfo.value)
128128

129129
with pytest.raises(ValueError) as excinfo:
130-
pm.sample(50, tune=0, init=None, foo={})
130+
pm.sample(50, tune=0, foo={})
131131
assert "foo" in str(excinfo.value)
132132

133-
with pytest.raises(ValueError) as excinfo:
134-
pm.sample(10, tune=0, init=None, target_accept=0.9)
135-
assert "target_accept" in str(excinfo.value)
136-
137133
def test_iter_sample(self):
138134
with self.model:
139135
samps = pm.sampling.iter_sample(
@@ -918,23 +914,12 @@ def check_exec_nuts_init(method):
918914
assert model.b.tag.value_var.name in start[0]
919915

920916

921-
@pytest.mark.xfail(reason="ADVI not refactored for v4")
922917
@pytest.mark.parametrize(
923918
"method",
924919
[
925920
"advi",
926921
"ADVI+adapt_diag",
927-
"advi+adapt_diag_grad",
928922
"advi_map",
929-
],
930-
)
931-
def test_exec_nuts_advi_init(method):
932-
check_exec_nuts_init(method)
933-
934-
935-
@pytest.mark.parametrize(
936-
"method",
937-
[
938923
"jitter+adapt_diag",
939924
"adapt_diag",
940925
"map",
@@ -1302,3 +1287,45 @@ def test_draw_different_samples(self):
13021287
x_draws_1 = pm.draw(x, 100)
13031288
x_draws_2 = pm.draw(x, 100)
13041289
assert not np.all(np.isclose(x_draws_1, x_draws_2))
1290+
1291+
1292+
class test_step_args(SeededTest):
1293+
with pm.Model() as model:
1294+
a = pm.Normal("a")
1295+
idata0 = pm.sample(target_accept=0.5)
1296+
idata1 = pm.sample(nuts={"target_accept": 0.5})
1297+
1298+
npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
1299+
npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
1300+
1301+
with pm.Model() as model:
1302+
a = pm.Normal("a")
1303+
b = pm.Poisson("b", 1)
1304+
idata0 = pm.sample(target_accept=0.5)
1305+
idata1 = pm.sample(nuts={"target_accept": 0.5}, metropolis={"scaling": 0})
1306+
1307+
npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
1308+
npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
1309+
npt.assert_allclose(idata1.sample_stats.scaling, 0)
1310+
1311+
1312+
def test_init_nuts(caplog):
1313+
with pm.Model() as model:
1314+
a = pm.Normal("a")
1315+
pm.sample(10, tune=10)
1316+
assert "Initializing NUTS" in caplog.text
1317+
1318+
1319+
def test_no_init_nuts_step(caplog):
1320+
with pm.Model() as model:
1321+
a = pm.Normal("a")
1322+
pm.sample(10, tune=10, step=pm.NUTS([a]))
1323+
assert "Initializing NUTS" not in caplog.text
1324+
1325+
1326+
def test_no_init_nuts_compound(caplog):
1327+
with pm.Model() as model:
1328+
a = pm.Normal("a")
1329+
b = pm.Poisson("b", 1)
1330+
pm.sample(10, tune=10)
1331+
assert "Initializing NUTS" not in caplog.text

pymc/tests/test_shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_sample(self):
4444
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)
4545
prior_trace0 = pm.sample_prior_predictive(1000)
4646

47-
idata = pm.sample(1000, init=None, tune=1000, chains=1)
47+
idata = pm.sample(1000, tune=1000, chains=1)
4848
pp_trace0 = pm.sample_posterior_predictive(idata)
4949

5050
x_shared.set_value(x_pred)

0 commit comments

Comments
 (0)