From 3c9f3ef0a81ae5f71ee0bc961fc7eef6264b6482 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Mon, 15 Nov 2021 15:45:05 +0100 Subject: [PATCH] Fix `rvs_to_value_vars` inplace update bug --- pymc/aesaraf.py | 12 +++++++++++- pymc/tests/test_aesaraf.py | 23 ++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 478d59070e..3a3203aedb 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -335,7 +335,7 @@ def rvs_to_value_vars( initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, **kwargs, ) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]: - """Replace random variables in graphs with their value variables. + """Clone and replace random variables in graphs with their value variables. This will *not* recompute test values in the resulting graphs. @@ -383,6 +383,16 @@ def transform_replacements(var, replacements): # Walk the transformed variable and make replacements return [trans_rv_value] + # Clone original graphs + inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] + equiv = clone_get_equiv(inputs, graphs, False, False, {}) + graphs = [equiv[n] for n in graphs] + + if initial_replacements: + initial_replacements = { + equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items() + } + return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs) diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index 1cf845555b..a41625ac66 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -23,7 +23,7 @@ import pytest import scipy.sparse as sps -from aesara.graph.basic import Constant, Variable, ancestors +from aesara.graph.basic import Constant, Variable, ancestors, equal_computations from aesara.tensor.random.basic import normal, uniform from aesara.tensor.random.op import RandomVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -529,3 +529,24 @@ def test_rvs_to_value_vars(): assert a_value_var in res_ancestors assert b_value_var in res_ancestors assert c_value_var in res_ancestors + + +def test_rvs_to_value_vars_nested(): + # Test that calling rvs_to_value_vars in models with nested transformations + # does not change the original rvs in place. See issue #5172 + with pm.Model() as m: + one = pm.LogNormal("one", mu=0) + two = pm.LogNormal("two", mu=at.log(one)) + + # We add potentials or deterministics that are not in topological order + pm.Potential("two_pot", two) + pm.Potential("one_pot", one) + + before = aesara.clone_replace(m.free_RVs) + + # This call would change the model free_RVs in place in #5172 + res, _ = rvs_to_value_vars(m.potentials, apply_transforms=True) + + after = aesara.clone_replace(m.free_RVs) + + assert equal_computations(before, after)