
Use same API for defining internal and external Nuts kwargs #6757


Open
Tracked by #7053
ricardoV94 opened this issue Jun 7, 2023 · 2 comments

Comments

@ricardoV94 (Member) commented Jun 7, 2023

Description

A user on Discourse reported:

How can I set the maximum tree depth for the NUTS method from the numpyro library?
The way described in the test file test_mcmc_external.py doesn’t work:

import pymc as pm
import numpy as np

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(
        nuts_sampler="numpyro",
        target_accept=0.99,
        nuts={"max_treedepth": 1},
        random_seed=1410,
    )

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

And passing the options through a nuts_kwargs argument instead throws ValueError: Unused step method arguments: {'nuts_kwargs'}.

I don't know whether nuts should be converted to nuts_kwargs, but even if a user passed nuts_kwargs to sample, they would never reach the sample_numpyro_nuts function, because the arbitrary kwargs accepted here are dropped:

**kwargs,
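To make the failure mode concrete, here is a minimal, dependency-free model of the dispatch described above. The names sample, _sample_external_nuts and fake_numpyro_backend are hypothetical stand-ins, not PyMC's actual code; the only behavior borrowed from this issue is that the external-sampler path accepts **kwargs but never forwards them.

```python
from typing import Any, Dict, Optional

# Records what the hypothetical backend actually received.
received: Dict[str, Any] = {}


def fake_numpyro_backend(**backend_kwargs: Any) -> None:
    """Hypothetical stand-in for the external numpyro sampling function."""
    received.clear()
    received.update(backend_kwargs)


def _sample_external_nuts(
    target_accept: float,
    nuts_sampler_kwargs: Optional[Dict[str, Any]],
    **kwargs: Any,
) -> None:
    # The behavior being illustrated: **kwargs is accepted but never
    # forwarded, so anything passed outside nuts_sampler_kwargs is lost.
    fake_numpyro_backend(target_accept=target_accept, **(nuts_sampler_kwargs or {}))


def sample(
    nuts_sampler: str = "pymc",
    target_accept: float = 0.8,
    nuts_sampler_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Hypothetical dispatcher; only the external-sampler branch is modeled."""
    if nuts_sampler == "pymc":
        raise NotImplementedError("the built-in PyMC path is not modeled here")
    _sample_external_nuts(target_accept, nuts_sampler_kwargs, **kwargs)


sample(nuts_sampler="numpyro", target_accept=0.99, nuts={"max_treedepth": 1})
print(received)  # {'target_accept': 0.99}
```

No error is raised on this path, so the nuts dict is ignored silently, which matches the tree depth of 4 the user observed instead of 1.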

@fonnesbeck (Member) commented Jun 9, 2023

I could have sworn the NUTS arguments dict was called nuts_sampler_kwargs.

Yes, it is:

    idata_kwargs: Optional[Dict],
    nuts_sampler_kwargs: Optional[Dict],
    **kwargs,
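Given that signature, options meant for numpyro's own NUTS kernel have to be nested one level deeper. The sketch below assumes, from our reading of PyMC's JAX sampling module rather than anything stated in this thread, that sample_numpyro_nuts takes a nuts_kwargs dict and unpacks it into numpyro.infer.NUTS, whose tree-depth option is spelled max_tree_depth. All three functions are hypothetical, dependency-free stand-ins that only trace where each dict ends up.

```python
from typing import Any, Dict, Optional


def fake_numpyro_nuts_kernel(
    target_accept_prob: float = 0.8, max_tree_depth: int = 10
) -> Dict[str, Any]:
    """Hypothetical stand-in for numpyro.infer.NUTS: just reports its settings."""
    return {"target_accept_prob": target_accept_prob, "max_tree_depth": max_tree_depth}


def fake_sample_numpyro_nuts(
    target_accept: float = 0.8, nuts_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    # Second level: nuts_kwargs is unpacked into the kernel constructor.
    return fake_numpyro_nuts_kernel(target_accept_prob=target_accept, **(nuts_kwargs or {}))


def fake_sample(
    target_accept: float = 0.8, nuts_sampler_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    # First level: nuts_sampler_kwargs is unpacked into the backend function.
    return fake_sample_numpyro_nuts(target_accept=target_accept, **(nuts_sampler_kwargs or {}))


kernel_settings = fake_sample(
    target_accept=0.99,
    nuts_sampler_kwargs={"nuts_kwargs": {"max_tree_depth": 1}},
)
print(kernel_settings)  # {'target_accept_prob': 0.99, 'max_tree_depth': 1}
```

Under those assumptions, the corresponding PyMC call would presumably be pm.sample(nuts_sampler="numpyro", target_accept=0.99, nuts_sampler_kwargs={"nuts_kwargs": {"max_tree_depth": 1}}). Both the double nesting and the key spelling differ from the nuts={"max_treedepth": 1} form used in the report above.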

@ricardoV94 (Member, Author)

Shouldn't we use the same API for passing kwargs to the external samplers that we use for PyMC's own NUTS?

pm.sample(..., {"nuts": ...})
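One possible shape for a shared API, sketched under loudly stated assumptions: sample would accept the same nuts dict for every backend and translate it before dispatch. The helper resolve_nuts_kwargs and the KEY_RENAMES table are hypothetical; the single rename shown (max_treedepth to numpyro's max_tree_depth) is only an example, and a real table would need to be checked against each backend's documentation.

```python
from typing import Any, Dict, Optional

# Hypothetical per-backend renames from PyMC-style NUTS keys to backend keys.
# Keys not listed for a backend are passed through unchanged.
KEY_RENAMES: Dict[str, Dict[str, str]] = {
    "pymc": {},
    "numpyro": {"max_treedepth": "max_tree_depth"},
}


def resolve_nuts_kwargs(
    nuts_sampler: str, nuts: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Translate a backend-agnostic nuts dict into backend-specific kwargs (hypothetical)."""
    if nuts_sampler not in KEY_RENAMES:
        raise ValueError(f"Unknown nuts_sampler: {nuts_sampler!r}")
    renames = KEY_RENAMES[nuts_sampler]
    # Rename known keys and keep every other key, so nothing is silently dropped.
    return {renames.get(key, key): value for key, value in (nuts or {}).items()}


print(resolve_nuts_kwargs("numpyro", {"max_treedepth": 1}))  # {'max_tree_depth': 1}
print(resolve_nuts_kwargs("pymc", {"max_treedepth": 1}))  # {'max_treedepth': 1}
```

Passing unrecognized keys through rather than filtering them means a typo would surface as an error from the backend (assuming it rejects unknown keyword arguments) instead of being ignored, which is the failure mode reported in this issue.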

ricardoV94 changed the title from "Numpyro sampler nuts_kwargs can't be passed from sample" to "Use same API for defining internal and external Nuts kwargs" on Feb 6, 2024.