Skip to content

Commit bbe58ec

Browse files
martiningramricardoV94
authored andcommitted
Add chain_method kwarg to sample_numpyro_nuts
Default behaviour is unchanged, but setting `chain_method='vectorized'` can be more efficient when run on a single GPU.
1 parent 7191e61 commit bbe58ec

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc/sampling_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def sample_numpyro_nuts(
141141
var_names=None,
142142
progress_bar=True,
143143
keep_untransformed=False,
144+
chain_method="parallel",
144145
):
145146
from numpyro.infer import MCMC, NUTS
146147

@@ -188,7 +189,7 @@ def sample_numpyro_nuts(
188189
num_samples=draws,
189190
num_chains=chains,
190191
postprocess_fn=None,
191-
chain_method="parallel",
192+
chain_method=chain_method,
192193
progress_bar=progress_bar,
193194
)
194195

0 commit comments

Comments
 (0)