Skip to content

Commit 71b18e6

Browse files
fonnesbeckricardoV94
authored andcommitted
Switched clone to True in FunctionGraph calls for sample_jax
1 parent e0592ec commit 71b18e6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc/sampling_jax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get_jaxified_logp(model: Model) -> Callable:
7474

7575
logpt = replace_shared_variables([model.logpt()])[0]
7676

77-
logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
77+
logpt_fgraph = FunctionGraph(outputs=[logpt], clone=True)
7878
optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
7979

8080
# We now jaxify the optimized fgraph
@@ -123,7 +123,7 @@ def _get_log_likelihood(model, samples):
123123
data = {}
124124
for v in model.observed_RVs:
125125
logp_v = replace_shared_variables([model.logpt(v, sum=False)[0]])
126-
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
126+
fgraph = FunctionGraph(model.value_vars, logp_v, clone=True)
127127
optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
128128
jax_fn = jax_funcify(fgraph)
129129
result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
@@ -229,7 +229,7 @@ def sample_numpyro_nuts(
229229
print("Transforming variables...", file=sys.stdout)
230230
mcmc_samples = {}
231231
for v in vars_to_sample:
232-
fgraph = FunctionGraph(model.value_vars, [v], clone=False)
232+
fgraph = FunctionGraph(model.value_vars, [v], clone=True)
233233
optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
234234
jax_fn = jax_funcify(fgraph)
235235
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]

0 commit comments

Comments
 (0)