We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7191e61 commit bbe58ecCopy full SHA for bbe58ec
pymc/sampling_jax.py
@@ -141,6 +141,7 @@ def sample_numpyro_nuts(
141
var_names=None,
142
progress_bar=True,
143
keep_untransformed=False,
144
+ chain_method="parallel",
145
):
146
from numpyro.infer import MCMC, NUTS
147
@@ -188,7 +189,7 @@ def sample_numpyro_nuts(
188
189
num_samples=draws,
190
num_chains=chains,
191
postprocess_fn=None,
- chain_method="parallel",
192
+ chain_method=chain_method,
193
progress_bar=progress_bar,
194
)
195
0 commit comments