Bring back sampler argument target_accept #5622

Merged · 1 commit · Mar 21, 2022

112 changes: 66 additions & 46 deletions pymc/sampling.py
@@ -279,18 +279,17 @@ def sample(
The number of samples to draw. Defaults to 1000. Tuned samples are discarded
by default. See ``discard_tuned_samples``.
init : str
Initialization method to use for auto-assigned NUTS samplers.
See `pm.init_nuts` for a list of all options.
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
of all options. This argument is ignored when manually passing the NUTS step method.
step : function or iterable of functions
A step function or collection of functions. If there are variables without step methods,
step methods for those variables will be assigned automatically. By default the NUTS step
method will be used, if appropriate to the model; this is a good default for beginning
users.
step methods for those variables will be assigned automatically. By default the NUTS step
method will be used, if appropriate to the model.
n_init : int
Number of iterations of initializer. Only works for 'ADVI' init methods.
initvals : optional, dict, array of dict
Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`.
The keys should be names of transformed random variables.
Dict or list of dicts with initial value strategies to use instead of the defaults from
`Model.initial_values`. The keys should be names of transformed random variables.
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
trace : backend or list
This should be a backend instance, or a list of variables to track.
@@ -317,8 +316,8 @@ def sample(
model : Model (optional if in ``with`` context)
Model to sample from. The model needs to have free random variables.
random_seed : int or list of ints
Random seed(s) used by the sampling steps. A list is accepted if
``cores`` is greater than one.
Random seed(s) used by the sampling steps. A list is accepted if ``cores`` is greater than
one.
discard_tuned_samples : bool
Whether to discard posterior samples of the tune interval.
compute_convergence_checks : bool, default=True
@@ -330,17 +329,17 @@ def sample(
is drawn from.
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
jitter_max_retries : int
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
init methods.
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
``jitter+adapt_full`` init methods.
return_inferencedata : bool
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `True`.
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
`MultiTrace` (False). Defaults to `True`.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`
mp_ctx : multiprocessing.context.BaseContext
A multiprocessing context for parallel sampling. See multiprocessing
documentation for details.
A multiprocessing context for parallel sampling.
See multiprocessing documentation for details.

Returns
-------
@@ -352,37 +351,28 @@ def sample(
Optional keyword arguments can be passed to ``sample`` to be delivered to the
``step_method``\ s used during sampling.

If your model uses only one step method, you can address step method kwargs
directly. In particular, the NUTS step method has several options including:
For example:

1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}

Note that available step names are:

``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
``DEMetropolis``, ``DEMetropolisZ``, ``slice``

The NUTS step method has several options including:

* target_accept : float in [0, 1]. The step size is tuned such that we
approximate this acceptance rate. Higher values like 0.9 or 0.95 often
work better for problematic posteriors
work better for problematic posteriors. This argument can be passed directly to sample.
* max_treedepth : The maximum depth of the trajectory tree
* step_scale : float, default 0.25
The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
where n is the dimensionality of the parameter space

If your model uses multiple step methods, aka a Compound Step, then you have
two ways to address arguments to each step method:

A. If you let ``sample()`` automatically assign the ``step_method``\ s,
and you can correctly anticipate what they will be, then you can wrap
step method kwargs in a dict and pass that to sample() with a kwarg set
to the name of the step method.
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
you could send:

1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}

Note that available names are:

``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
``DEMetropolis``, ``DEMetropolisZ``, ``slice``

B. If you manually declare the ``step_method``\ s, within the ``step``
Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
kwarg, then you can address the ``step_method`` kwargs directly.
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
you could send ::
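The literal example that follows ``you could send ::`` is cut off at the hunk boundary. As a sketch of the pattern the numbered items above describe (not text from this diff; the model and values are illustrative), a compound-step call could look like:

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x")          # continuous variable, auto-assigned to NUTS
    z = pm.Bernoulli("z", 0.5)  # binary variable, auto-assigned to BinaryGibbsMetropolis
    idata = pm.sample(
        1000,
        tune=1000,
        nuts={"target_accept": 0.9},                 # kwargs routed to the NUTS step
        binary_gibbs_metropolis={"transit_p": 0.7},  # kwargs routed to the BinaryGibbsMetropolis step
    )
```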
@@ -422,6 +412,8 @@ def sample(
stacklevel=2,
)
initvals = kwargs.pop("start")
if "target_accept" in kwargs:
kwargs.setdefault("nuts", {"target_accept": kwargs.pop("target_accept")})

model = modelcontext(model)
if not model.free_RVs:
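The two added lines above move a top-level ``target_accept`` keyword into the ``nuts`` step-kwargs dict, so the following calls configure NUTS identically (a minimal sketch mirroring the new ``test_step_args`` test further down; the model is illustrative):

```python
import pymc as pm

with pm.Model():
    a = pm.Normal("a")
    idata0 = pm.sample(target_accept=0.9)            # shorthand restored by this PR
    idata1 = pm.sample(nuts={"target_accept": 0.9})  # equivalent explicit form
```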
@@ -466,11 +458,37 @@ def sample(

draws += tune

auto_nuts_init = True
if step is not None:
if isinstance(step, CompoundStep):
for method in step.methods:
if isinstance(method, NUTS):
auto_nuts_init = False
elif isinstance(step, NUTS):
auto_nuts_init = False

initial_points = None
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)

if isinstance(step, list):
step = CompoundStep(step)
elif isinstance(step, NUTS) and auto_nuts_init:
if "nuts" in kwargs:
nuts_kwargs = kwargs.pop("nuts")
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
_log.info("Auto-assigning NUTS sampler...")
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
seeds=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)
Review comment (Member) on lines +475 to +491:
I'm not confident about this. As I commented before, the try/catch block caught not-implemented gradients, which we're probably not testing.

Also I'm sceptical about the pattern of having isinstance(step, NUTS) and if "nuts" in kwargs. It reads like a hack to account for some specific cases, and it's not as general as the previous implementation.

You're the ones who've been looking into this much more than me, so feel free to overrule this.

@ricardoV94 (Member) replied, Mar 21, 2022:
``assign_step_methods`` checks for gradients as well. If that does not do a good enough job of testing whether NUTS is appropriate, then it should be treated as a bug and fixed there.
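As a sketch of the fallback discussed here (mirroring the new ``test_no_init_nuts_compound`` test at the bottom of this diff; the model is illustrative): when a model contains a variable NUTS cannot handle, ``assign_step_methods`` returns multiple step methods, the result is wrapped in a ``CompoundStep``, and the NUTS-only ``init_nuts`` branch above is never reached.

```python
import pymc as pm

with pm.Model():
    a = pm.Normal("a")      # continuous: gradient available, gets NUTS
    b = pm.Poisson("b", 1)  # discrete: no gradient, gets a Metropolis-family step
    # step becomes a CompoundStep, so "Initializing NUTS" is never logged
    # and init_nuts is not called for this model.
    idata = pm.sample(10, tune=10)
```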

if initial_points is None:
# Time to draw/evaluate numeric start points for each chain.
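Conversely, passing a NUTS instance through ``step`` sets ``auto_nuts_init = False`` above, so ``init`` is ignored and ``init_nuts`` is skipped entirely (a minimal sketch mirroring the new ``test_no_init_nuts_step`` test; the model is illustrative):

```python
import pymc as pm

with pm.Model():
    a = pm.Normal("a")
    # A manually constructed NUTS step bypasses the auto-initialization path,
    # which is why the docstring now notes that `init` is ignored in this case.
    idata = pm.sample(10, tune=10, step=pm.NUTS([a]))
```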
@@ -2129,7 +2147,7 @@ def draw(
def _init_jitter(
model: Model,
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
seeds: Sequence[int],
seeds: Union[List[Any], Tuple[Any, ...], np.ndarray],
jitter: bool,
jitter_max_retries: int,
) -> List[PointType]:
@@ -2186,7 +2204,7 @@ def init_nuts(
chains: int = 1,
n_init: int = 500_000,
model=None,
seeds: Sequence[int] = None,
seeds: Iterable[Any] = None,
progressbar=True,
jitter_max_retries: int = 10,
tune: Optional[int] = None,
@@ -2262,8 +2280,7 @@ def init_nuts(
if not isinstance(init, str):
raise TypeError("init must be a string.")

if init is not None:
init = init.lower()
init = init.lower()

if init == "auto":
init = "jitter+adapt_diag"
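With the redundant ``None`` check removed, ``init`` is always handled as a string and ``"auto"`` resolves to ``"jitter+adapt_diag"``, which is consistent with the test updates below dropping ``init=None`` from ``pm.sample`` calls. A hedged sketch of the equivalent spellings (the model is illustrative):

```python
import pymc as pm

with pm.Model():
    pm.Normal("a")
    idata_default = pm.sample(10, tune=10, chains=1)                             # init defaults to "auto"
    idata_explicit = pm.sample(10, tune=10, chains=1, init="jitter+adapt_diag")  # same initialization
```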
@@ -2333,7 +2350,8 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
initial_points = [approx_sample[i] for i in range(chains)]
std_apoint = approx.std.eval()
cov = std_apoint**2
mean = approx.mean.get_value()
@@ -2350,7 +2368,8 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "advi_map":
@@ -2364,7 +2383,8 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "map":
7 changes: 1 addition & 6 deletions pymc/tests/test_data_container.py
@@ -50,7 +50,7 @@ def test_sample(self):
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)

prior_trace0 = pm.sample_prior_predictive(1000)
idata = pm.sample(1000, init=None, tune=1000, chains=1)
idata = pm.sample(1000, tune=1000, chains=1)
pp_trace0 = pm.sample_posterior_predictive(idata)

x_shared.set_value(x_pred)
@@ -103,7 +103,6 @@ def test_sample_after_set_data(self):
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
pm.sample(
1000,
init=None,
tune=1000,
chains=1,
compute_convergence_checks=False,
@@ -115,7 +114,6 @@ def test_shared_data_as_index(self):
pm.set_data(new_data={"x": new_x, "y": new_y})
new_idata = pm.sample(
1000,
init=None,
tune=1000,
chains=1,
compute_convergence_checks=False,
@@ -141,7 +139,6 @@ def test_shared_data_as_index(self):
prior_trace = pm.sample_prior_predictive(1000)
idata = pm.sample(
1000,
init=None,
tune=1000,
chains=1,
compute_convergence_checks=False,
@@ -236,7 +233,6 @@ def test_set_data_to_non_data_container_variables(self):
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
pm.sample(
1000,
init=None,
tune=1000,
chains=1,
compute_convergence_checks=False,
@@ -255,7 +251,6 @@ def test_model_to_graphviz_for_model_with_data_container(self):
pm.Normal("obs", beta * x, obs_sigma, observed=y)
pm.sample(
1000,
init=None,
tune=1000,
chains=1,
compute_convergence_checks=False,
2 changes: 1 addition & 1 deletion pymc/tests/test_quadpotential.py
@@ -150,7 +150,7 @@ def energy(self, x, velocity=None):
pot = Potential(floatX([1]))
with model:
step = pymc.NUTS(potential=pot)
pymc.sample(10, init=None, step=step, chains=1)
pymc.sample(10, step=step, chains=1)
assert called


61 changes: 44 additions & 17 deletions pymc/tests/test_sampling.py
@@ -123,17 +123,13 @@ def test_sample_init(self):
def test_sample_args(self):
with self.model:
with pytest.raises(ValueError) as excinfo:
pm.sample(50, tune=0, init=None, foo=1)
pm.sample(50, tune=0, foo=1)
assert "'foo'" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
pm.sample(50, tune=0, init=None, foo={})
pm.sample(50, tune=0, foo={})
assert "foo" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
pm.sample(10, tune=0, init=None, target_accept=0.9)
assert "target_accept" in str(excinfo.value)

def test_iter_sample(self):
with self.model:
samps = pm.sampling.iter_sample(
@@ -918,23 +914,12 @@ def check_exec_nuts_init(method):
assert model.b.tag.value_var.name in start[0]


@pytest.mark.xfail(reason="ADVI not refactored for v4")
@pytest.mark.parametrize(
"method",
[
"advi",
"ADVI+adapt_diag",
"advi+adapt_diag_grad",
"advi_map",
],
)
def test_exec_nuts_advi_init(method):
check_exec_nuts_init(method)


@pytest.mark.parametrize(
"method",
[
"jitter+adapt_diag",
"adapt_diag",
"map",
@@ -1302,3 +1287,45 @@ def test_draw_different_samples(self):
x_draws_1 = pm.draw(x, 100)
x_draws_2 = pm.draw(x, 100)
assert not np.all(np.isclose(x_draws_1, x_draws_2))


class test_step_args(SeededTest):
with pm.Model() as model:
a = pm.Normal("a")
idata0 = pm.sample(target_accept=0.5)
idata1 = pm.sample(nuts={"target_accept": 0.5})

npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)

with pm.Model() as model:
a = pm.Normal("a")
b = pm.Poisson("b", 1)
idata0 = pm.sample(target_accept=0.5)
idata1 = pm.sample(nuts={"target_accept": 0.5}, metropolis={"scaling": 0})

npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
npt.assert_allclose(idata1.sample_stats.scaling, 0)


def test_init_nuts(caplog):
with pm.Model() as model:
a = pm.Normal("a")
pm.sample(10, tune=10)
assert "Initializing NUTS" in caplog.text


def test_no_init_nuts_step(caplog):
with pm.Model() as model:
a = pm.Normal("a")
pm.sample(10, tune=10, step=pm.NUTS([a]))
assert "Initializing NUTS" not in caplog.text


def test_no_init_nuts_compound(caplog):
with pm.Model() as model:
a = pm.Normal("a")
b = pm.Poisson("b", 1)
pm.sample(10, tune=10)
assert "Initializing NUTS" not in caplog.text
2 changes: 1 addition & 1 deletion pymc/tests/test_shared.py
@@ -44,7 +44,7 @@ def test_sample(self):
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)
prior_trace0 = pm.sample_prior_predictive(1000)

idata = pm.sample(1000, init=None, tune=1000, chains=1)
idata = pm.sample(1000, tune=1000, chains=1)
pp_trace0 = pm.sample_posterior_predictive(idata)

x_shared.set_value(x_pred)