Skip to content

Commit 25c6772

Browse files
committed
Set default updates for all graph RandomVariables in compile_pymc
1 parent 4fedb60 commit 25c6772

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

pymc/aesaraf.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -961,17 +961,14 @@ def compile_pymc(inputs, outputs, mode=None, **kwargs):
961961
this function is called within a model context and the model `check_bounds` flag
962962
is set to False.
963963
"""
964-
965-
# Avoid circular dependency
966-
from pymc.distributions import NoDistribution
967-
968-
# Set the default update of a NoDistribution RNG so that it is automatically
964+
# Set the default update of RandomVariable's RNG so that it is automatically
969965
# updated after every function call
966+
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
970967
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
971968
for rv in (
972969
node
973970
for node in walk_model(output_to_list, walk_past_rvs=True)
974-
if node.owner and isinstance(node.owner.op, NoDistribution)
971+
if node.owner and isinstance(node.owner.op, RandomVariable)
975972
):
976973
rng = rv.owner.inputs[0]
977974
if not hasattr(rng, "default_update"):

pymc/tests/test_aesaraf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,11 @@ def test_check_bounds_flag():
574574
m.check_bounds = True
575575
with m:
576576
assert np.all(compile_pymc([], bound)() == -np.inf)
577+
578+
579+
def test_compile_pymc_sets_default_updates():
580+
rng = aesara.shared(np.random.default_rng(0))
581+
x = pm.Normal.dist(rng=rng)
582+
assert x.owner.inputs[0] is rng
583+
f = compile_pymc([], x)
584+
assert not np.isclose(f(), f())

0 commit comments

Comments
 (0)