From 1262a442cf63be37f1be44253c6cdec122244a52 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 2 Jun 2023 13:09:58 +0200
Subject: [PATCH 1/4] Make fgraph Deterministic conversion logic more robust

---
 pymc_experimental/utils/model_fgraph.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py
index 9b7993ca..c2ec84ca 100644
--- a/pymc_experimental/utils/model_fgraph.py
+++ b/pymc_experimental/utils/model_fgraph.py
@@ -237,6 +237,14 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
 
     See: fgraph_from_model
     """
+
+    def first_non_model_var(var):
+        if var.owner and isinstance(var.owner.op, ModelVar):
+            new_var = var.owner.inputs[0]
+            return first_non_model_var(new_var)
+        else:
+            return var
+
     model = Model()
     if model.parent is not None:
         raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
@@ -244,7 +252,6 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
     model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
 
     # Replace dummy `ModelVar` Ops by the underlying variables,
-    # Except for Deterministics which could reintroduce the old graphs
     fgraph = fgraph.clone()
     model_dummy_vars = [
         model_node.outputs[0]
@@ -252,15 +259,14 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
         if isinstance(model_node.op, ModelVar)
     ]
     model_dummy_vars_to_vars = {
-        dummy_var: dummy_var.owner.inputs[0]
+        # Deterministics could refer to other model variables directly,
+        # We make sure to replace them by the first non-model variable
+        dummy_var: first_non_model_var(dummy_var.owner.inputs[0])
         for dummy_var in model_dummy_vars
-        # Don't include Deterministics!
-        if not isinstance(dummy_var.owner.op, ModelDeterministic)
     }
     toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items()))
 
     # Populate new PyMC model mappings
-    non_det_model_vars = set(model_dummy_vars_to_vars.values())
     for model_var in model_dummy_vars:
         if isinstance(model_var.owner.op, ModelFreeRV):
             var, value, *dims = model_var.owner.inputs
@@ -279,10 +285,8 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
             model.potentials.append(var)
         elif isinstance(model_var.owner.op, ModelDeterministic):
             var, *dims = model_var.owner.inputs
-            # Register the original var (not the copy) as the Deterministic
-            # So it shows in the expected place in graphviz.
-            # unless it's another model var, in which case we need a copy!
-            if var in non_det_model_vars:
+            # If a Deterministic is a direct view on an RV, copy it
+            if var in model.basic_RVs:
                 var = var.copy()
             model.deterministics.append(var)
         elif isinstance(model_var.owner.op, ModelNamed):

From 5304406cff224bbe9cdd0365aa6323b1ea3fb9b6 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 2 Jun 2023 13:57:03 +0200
Subject: [PATCH 2/4] Allow inlining of Deterministics and Data in fgraph IR

---
 .../tests/utils/test_model_fgraph.py    | 45 +++++++----
 pymc_experimental/utils/model_fgraph.py | 74 +++++++++++++------
 2 files changed, 83 insertions(+), 36 deletions(-)

diff --git a/pymc_experimental/tests/utils/test_model_fgraph.py b/pymc_experimental/tests/utils/test_model_fgraph.py
index 284a9bfd..dcac53fa 100644
--- a/pymc_experimental/tests/utils/test_model_fgraph.py
+++ b/pymc_experimental/tests/utils/test_model_fgraph.py
@@ -76,7 +76,8 @@ def test_basic():
     )
 
 
-def test_data():
+@pytest.mark.parametrize("inline_views", (False, True))
+def test_data(inline_views):
     """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
 
     Everything should be preserved across new and old models, except for shared RNGs
     """
     with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
         x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
         y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
-        b0 = pm.ConstantData("b0", 0.0)
+        b0 = pm.ConstantData("b0", np.zeros(3))
         b1 = pm.Normal("b1")
         mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
         obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
 
-    m_fgraph, memo = fgraph_from_model(m_old)
+    m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
     assert isinstance(memo[x].owner.op, ModelNamed)
     assert isinstance(memo[y].owner.op, ModelNamed)
     assert isinstance(memo[b0].owner.op, ModelNamed)
+    mu_inp = memo[mu].owner.inputs[0]
+    obs = memo[obs]
+    if not inline_views:
+        # Add(b0, Mul(FreeRV(b1), x)) not Add(Named(b0), Mul(FreeRV(b1), Named(x)))
+        assert mu_inp.owner.inputs[0] is memo[b0].owner.inputs[0]
+        assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
+        # ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
+        assert obs.owner.inputs[1] is memo[y].owner.inputs[0]
+    else:
+        assert mu_inp.owner.inputs[0] is memo[b0]
+        assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x]
+        assert obs.owner.inputs[1] is memo[y]
 
     m_new = model_from_fgraph(m_fgraph)
 
     # ConstantData is preserved
-    assert m_new["b0"].data == m_old["b0"].data
+    assert np.all(m_new["b0"].data == m_old["b0"].data)
 
     # Shared non-rng shared variables are preserved
     assert m_new["x"].container is x.container
@@ -114,7 +127,8 @@ def test_data():
     np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])
 
 
-def test_deterministics():
+@pytest.mark.parametrize("inline_views", (False, True))
+def test_deterministics(inline_views):
     """Test handling of deterministics.
 
     We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome
@@ -140,22 +154,27 @@ def test_deterministics():
     assert m["y"].owner.inputs[3] is m["mu"]
     assert m["y"].owner.inputs[4] is not m["sigma"]
 
-    fg, _ = fgraph_from_model(m)
+    fg, _ = fgraph_from_model(m, inlined_views=inline_views)
 
     # Check that no Deterministics are in graph of x to y and y to z
     x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
     # [Det(mu), Det(sigma)]
     mu = det_mu.owner.inputs[0]
     sigma = det_sigma.owner.inputs[0]
-    # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
-    assert y.owner.inputs[0].owner.inputs[3] is mu
     assert y.owner.inputs[0].owner.inputs[4] is sigma
-    # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
-    assert z.owner.inputs[0].owner.inputs[3] is y
-    # [Det(y), Det(y)], not [Det(y), Det(Det(y))]
-    assert det_y_.owner.inputs[0] is y
-    assert det_y__.owner.inputs[0] is y
     assert det_y_ is not det_y__
+    assert det_y_.owner.inputs[0] is y
+    if not inline_views:
+        # FreeRV(y(mu, sigma)) not FreeRV(y(Det(mu), Det(sigma)))
+        assert y.owner.inputs[0].owner.inputs[3] is mu
+        # FreeRV(z(y)) not FreeRV(z(Det(Det(y))))
+        assert z.owner.inputs[0].owner.inputs[3] is y
+        # Det(y), not Det(Det(y))
+        assert det_y__.owner.inputs[0] is y
+    else:
+        assert y.owner.inputs[0].owner.inputs[3] is det_mu
+        assert z.owner.inputs[0].owner.inputs[3] is det_y__
+        assert det_y__.owner.inputs[0] is det_y_
 
     # Both mu and sigma deterministics are now in the graph of x to y
     m = model_from_fgraph(fg)
diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py
index c2ec84ca..455ea633 100644
--- a/pymc_experimental/utils/model_fgraph.py
+++ b/pymc_experimental/utils/model_fgraph.py
@@ -90,14 +90,16 @@ def model_free_rv(rv, value, transform, *dims):
 
 
 def toposort_replace(
-    fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]]
+    fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
 ) -> None:
     """Replace multiple variables in topological order."""
     toposort = fgraph.toposort()
     sorted_replacements = sorted(
-        replacements, key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1
+        replacements,
+        key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
+        reverse=reverse,
     )
-    fgraph.replace_all(tuple(sorted_replacements), import_missing=True)
+    fgraph.replace_all(sorted_replacements, import_missing=True)
 
 
 @node_rewriter([Elemwise])
@@ -109,11 +111,20 @@ def local_remove_identity(fgraph, node):
 remove_identity_rewrite = out2in(local_remove_identity)
 
 
-def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
+def fgraph_from_model(
+    model: Model, inlined_views=False
+) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
     """Convert Model to FunctionGraph.
 
     See: model_from_fgraph
 
+    Parameters
+    ----------
+    model: PyMC model
+    inlined_views: bool, default False
+        Whether "view" variables (Deterministics and Data) should be inlined among the RVs
+        in the fgraph, or kept as separate branches.
+
     Returns
     -------
     fgraph: FunctionGraph
@@ -138,19 +149,36 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
     free_rvs = model.free_RVs
     observed_rvs = model.observed_RVs
     potentials = model.potentials
+    named_vars = model.named_vars.values()
 
     # We copy Deterministics (Identity Op) so that they don't show in between "main" variables
     # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
     old_deterministics = model.deterministics
-    deterministics = [det.copy(det.name) for det in old_deterministics]
-    # Other variables that are in model.named_vars but are not any of the categories above
+    deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
+    # Value variables (we also have to decide whether to inline named ones)
+    old_value_vars = list(rvs_to_values.values())
+    unnamed_value_vars = [val for val in old_value_vars if val not in named_vars]
+    named_value_vars = [
+        val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars
+    ]
+    value_vars = old_value_vars.copy()
+    if inlined_views:
+        # In this case we want to use the named_value_vars as the value_vars in RVs
+        for named_val in named_value_vars:
+            idx = value_vars.index(named_val)
+            value_vars[idx] = named_val
+    # Other variables that are in named_vars but are not any of the categories above
     # E.g., MutableData, ConstantData, _dim_lengths
     # We use the same trick as deterministics!
-    accounted_for = free_rvs + observed_rvs + potentials + old_deterministics
-    old_other_named_vars = [var for var in model.named_vars.values() if var not in accounted_for]
-    other_named_vars = [var.copy(var.name) for var in old_other_named_vars]
-    value_vars = [val for val in rvs_to_values.values() if val not in old_other_named_vars]
+    accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars)
+    other_named_vars = [
+        var if inlined_views else var.copy(var.name)
+        for var in named_vars
+        if var not in accounted_for
+    ]
 
-    model_vars = rvs + potentials + deterministics + other_named_vars + value_vars
+    model_vars = (
+        rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars
+    )
 
     memo = {}
 
@@ -176,13 +204,13 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
     # Introduce dummy `ModelVar` Ops
     free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()}
-    free_rvs_to_values = {memo[k]: memo[v] for k, v in rvs_to_values.items() if k in free_rvs}
+    free_rvs_to_values = {memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in free_rvs}
     observed_rvs_to_values = {
-        memo[k]: memo[v] for k, v in rvs_to_values.items() if k in observed_rvs
+        memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in observed_rvs
     }
     potentials = [memo[k] for k in potentials]
     deterministics = [memo[k] for k in deterministics]
-    other_named_vars = [memo[k] for k in other_named_vars]
+    named_vars = [memo[k] for k in other_named_vars + named_value_vars]
 
     vars = fgraph.outputs
     new_vars = []
@@ -198,31 +226,31 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
             new_var = model_potential(var, *dims)
         elif var in deterministics:
             new_var = model_deterministic(var, *dims)
-        elif var in other_named_vars:
+        elif var in named_vars:
             new_var = model_named(var, *dims)
         else:
-            # Value variables
+            # Unnamed value variables
             new_var = var
         new_vars.append(new_var)
 
     replacements = tuple(zip(vars, new_vars))
-    toposort_replace(fgraph, replacements)
+    toposort_replace(fgraph, replacements, reverse=True)
 
     # Reference model vars in memo
     inverse_memo = {v: k for k, v in memo.items()}
     for var, model_var in replacements:
-        if isinstance(
-            model_var.owner is not None and model_var.owner.op, (ModelDeterministic, ModelNamed)
+        if not inlined_views and (
+            model_var.owner and isinstance(model_var.owner.op, (ModelDeterministic, ModelNamed))
         ):
             # Ignore extra identity that will be removed at the end
            var = var.owner.inputs[0]
         original_var = inverse_memo[var]
         memo[original_var] = model_var
 
-    # Remove value variable as outputs, now that they are graph inputs
-    first_value_idx = len(fgraph.outputs) - len(value_vars)
-    for _ in value_vars:
-        fgraph.remove_output(first_value_idx)
+    # Remove the last outputs corresponding to unnamed value variables, now that they are graph inputs
+    first_idx_to_remove = len(fgraph.outputs) - len(unnamed_value_vars)
+    for _ in unnamed_value_vars:
+        fgraph.remove_output(first_idx_to_remove)
 
     # Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
     remove_identity_rewrite.apply(fgraph)

From 52706e835fb2321eed8eefd9da357a11674eff9f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 17 May 2023 15:41:15 +0200
Subject: [PATCH 3/4] Implement observe and do model transformations

---
 docs/api_reference.rst                        |  11 +
 pymc_experimental/model_transform/__init__.py |   0
 .../model_transform/conditioning.py           | 199 ++++++++++++++++++
 .../tests/model_transform/__init__.py         |   0
 .../model_transform/test_conditioning.py      | 186 ++++++++++++++++
 pymc_experimental/utils/model_fgraph.py       |  11 +
 pymc_experimental/utils/pytensorf.py          |  27 +++
 7 files changed, 434 insertions(+)
 create mode 100644 pymc_experimental/model_transform/__init__.py
 create mode 100644 pymc_experimental/model_transform/conditioning.py
 create mode 100644 pymc_experimental/tests/model_transform/__init__.py
 create mode 100644 pymc_experimental/tests/model_transform/test_conditioning.py

diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index d73cc3ba..0d5409f4 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -35,6 +35,17 @@ Distributions
     histogram_approximation
 
 
+Model Transformations
+=====================
+
+.. currentmodule:: pymc_experimental.model_transform
+.. autosummary::
+   :toctree: generated/
+
+   conditioning.do
+   conditioning.observe
+
+
 Utils
 =====
 
diff --git a/pymc_experimental/model_transform/__init__.py b/pymc_experimental/model_transform/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pymc_experimental/model_transform/conditioning.py b/pymc_experimental/model_transform/conditioning.py
new file mode 100644
index 00000000..80434e46
--- /dev/null
+++ b/pymc_experimental/model_transform/conditioning.py
@@ -0,0 +1,199 @@
+from typing import Any, Dict, List, Sequence, Union
+
+from pymc import Model
+from pymc.pytensorf import _replace_vars_in_graphs
+from pytensor.tensor import TensorVariable
+
+from pymc_experimental.utils.model_fgraph import (
+    ModelDeterministic,
+    ModelFreeRV,
+    extract_dims,
+    fgraph_from_model,
+    model_deterministic,
+    model_from_fgraph,
+    model_named,
+    model_observed_rv,
+    toposort_replace,
+)
+from pymc_experimental.utils.pytensorf import rvs_in_graph
+
+
+def observe(model: Model, vars_to_observations: Dict[Union[str, TensorVariable], Any]) -> Model:
+    """Convert free RVs or Deterministics to observed RVs.
+
+    Parameters
+    ----------
+    model: PyMC Model
+    vars_to_observations: Dict of variable or name to TensorLike
+        Dictionary that maps model variables (or names) to observed values.
+        Observed values must have a shape and data type that are compatible
+        with the original model variable.
+
+    Returns
+    -------
+    new_model: PyMC model
+        A distinct PyMC model with the relevant variables observed.
+        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc_experimental.model_transform.conditioning import observe
+
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Normal("y", x)
+            z = pm.Normal("z", y)
+
+        m_new = observe(m, {y: 0.5})
+
+    Deterministic variables can also be observed.
+    This relies on PyMC's ability to infer the logp of the underlying expression.
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc_experimental.model_transform.conditioning import observe
+
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Normal.dist(x, shape=(5,))
+            y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
+
+        new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]})
+
+
+    """
+    vars_to_observations = {
+        model[var] if isinstance(var, str) else var: obs
+        for var, obs in vars_to_observations.items()
+    }
+
+    valid_model_vars = set(model.free_RVs + model.deterministics)
+    if any(var not in valid_model_vars for var in vars_to_observations):
+        raise ValueError("At least one var is not a free variable or deterministic in the model")
+
+    fgraph, memo = fgraph_from_model(model)
+
+    replacements = {}
+    for var, obs in vars_to_observations.items():
+        model_var = memo[var]
+
+        # Just a sanity check
+        assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
+        assert model_var in fgraph.variables
+
+        var = model_var.owner.inputs[0]
+        var.name = model_var.name
+        dims = extract_dims(model_var)
+        model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims)
+        replacements[model_var] = model_obs_rv
+
+    toposort_replace(fgraph, tuple(replacements.items()))
+
+    return model_from_fgraph(fgraph)
+
+
+def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
+    def replacement_fn(var, inner_replacements):
+        if var in replacements:
+            inner_replacements[var] = replacements[var]
+
+        # Handle root inputs as those will never be passed to the replacement_fn
+        for inp in var.owner.inputs:
+            if inp.owner is None and inp in replacements:
+                inner_replacements[inp] = replacements[inp]
+
+        return [var]
+
+    replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
+    return replaced_graphs
+
+
+def do(model: Model, vars_to_interventions: Dict[Union[str, TensorVariable], Any]) -> Model:
+    """Replace model variables with intervention variables.
+
+    Intervention variables will either show up as `Data` or `Deterministics` in the new model,
+    depending on whether they depend on other RandomVariables or not.
+
+    Parameters
+    ----------
+    model: PyMC Model
+    vars_to_interventions: Dict of variable or name to TensorLike
+        Dictionary that maps model variables (or names) to intervention expressions.
+        Intervention expressions must have a shape and data type that are compatible
+        with the original model variable.
+
+    Returns
+    -------
+    new_model: PyMC model
+        A distinct PyMC model with the relevant variables replaced by the intervention expressions.
+        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import arviz as az
+        import pymc as pm
+        from pymc_experimental.model_transform.conditioning import do
+
+        with pm.Model() as m:
+            x = pm.Normal("x", 0, 1)
+            y = pm.Normal("y", x, 1)
+            z = pm.Normal("z", y + x, 1)
+
+        # Dummy posterior, same as calling `pm.sample`
+        idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})
+
+        # Replace `y` by a constant `100.0`
+        m_do = do(m, {y: 100.0})
+        with m_do:
+            idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
+
+    """
+    do_mapping = {}
+    for var, obs in vars_to_interventions.items():
+        if isinstance(var, str):
+            var = model[var]
+        try:
+            do_mapping[var] = var.type.filter_variable(obs)
+        except TypeError as err:
+            raise TypeError(
+                "Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables"
+            ) from err
+
+    if any(var not in model.named_vars.values() for var in do_mapping):
+        raise ValueError("At least one var is not a named variable in the model")
+
+    fgraph, memo = fgraph_from_model(model, inlined_views=True)
+
+    # We need the interventions defined in terms of the IR fgraph representation,
+    # in case they reference other variables in the model
+    ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)
+
+    replacements = {}
+    for var, intervention in zip(do_mapping, ir_interventions):
+        model_var = memo[var]
+
+        # Just a sanity check
+        assert model_var in fgraph.variables
+
+        intervention.name = model_var.name
+        dims = extract_dims(model_var)
+        # If there are any RVs in the graph we introduce the intervention as a deterministic
+        if rvs_in_graph([intervention]):
+            new_var = model_deterministic(intervention.copy(name=intervention.name), *dims)
+        # Otherwise as a named variable (Constant or Shared data)
+        else:
+            new_var = model_named(intervention, *dims)
+
+        replacements[model_var] = new_var
+
+    # Replace variables by interventions
+    toposort_replace(fgraph, tuple(replacements.items()))
+
+    return model_from_fgraph(fgraph)
diff --git a/pymc_experimental/tests/model_transform/__init__.py b/pymc_experimental/tests/model_transform/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pymc_experimental/tests/model_transform/test_conditioning.py b/pymc_experimental/tests/model_transform/test_conditioning.py
new file mode 100644
index 00000000..fbc66d6a
--- /dev/null
+++ b/pymc_experimental/tests/model_transform/test_conditioning.py
@@ -0,0 +1,186 @@
+import arviz as az
+import numpy as np
+import pymc as pm
+import pytest
+from pymc.variational.minibatch_rv import create_minibatch_rv
+from pytensor import config
+
+from pymc_experimental.model_transform.conditioning import do, observe
+
+
+def test_observe():
+    with pm.Model() as m_old:
+        x = pm.Normal("x")
+        y = pm.Normal("y", x)
+        z = pm.Normal("z", y)
+
+    m_new = observe(m_old, {y: 0.5})
+
+    assert len(m_new.free_RVs) == 2
+    assert len(m_new.observed_RVs) == 1
+    assert m_new["x"] in m_new.free_RVs
+    assert m_new["y"] in m_new.observed_RVs
+    assert m_new["z"] in m_new.free_RVs
+
+    np.testing.assert_allclose(
+        m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
+        m_new.compile_logp()({"x": 0.9, "z": 1.4}),
+    )
+
+    # Test two substitutions
+    m_new = observe(m_old, {y: 0.5, z: 1.4})
+
+    assert len(m_new.free_RVs) == 1
+    assert len(m_new.observed_RVs) == 2
+    assert m_new["x"] in m_new.free_RVs
+    assert m_new["y"] in m_new.observed_RVs
+    assert m_new["z"] in m_new.observed_RVs
+
+    np.testing.assert_allclose(
+        m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
+        m_new.compile_logp()({"x": 0.9}),
+    )
+
+
+def test_observe_minibatch():
+    data = np.zeros((100,), dtype=config.floatX)
+    batch_size = 10
+    with pm.Model() as m_old:
+        x = pm.Normal("x")
+        y = pm.Normal("y", x)
+        # Minibatch RVs are usually created with `total_size` kwarg
+        z_raw = pm.Normal.dist(y, shape=batch_size)
+        mb_z = create_minibatch_rv(z_raw, total_size=data.shape)
+        m_old.register_rv(mb_z, name="mb_z")
+
+    mb_data = pm.Minibatch(data, batch_size=batch_size)
+    m_new = observe(m_old, {mb_z: mb_data})
+
+    assert len(m_new.free_RVs) == 2
+    assert len(m_new.observed_RVs) == 1
+    assert m_new["x"] in m_new.free_RVs
+    assert m_new["y"] in m_new.free_RVs
+    assert m_new["mb_z"] in m_new.observed_RVs
+
+    np.testing.assert_allclose(
+        m_old.compile_logp()({"x": 0.9, "y": 0.5, "mb_z": np.zeros(10)}),
+        m_new.compile_logp()({"x": 0.9, "y": 0.5}),
+    )
+
+
+def test_observe_deterministic():
+    y_censored_obs = np.array([0.9, 0.5, 0.3, 1, 1], dtype=config.floatX)
+
+    with pm.Model() as m_old:
+        x = pm.Normal("x")
+        y = pm.Normal.dist(x, shape=(5,))
+        y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
+
+    m_new = observe(m_old, {y_censored: y_censored_obs})
+
+    with pm.Model() as m_ref:
+        x = pm.Normal("x")
+        pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs)
+
+    # The observed Deterministic should yield the same logp as the reference Censored model
+    np.testing.assert_allclose(
+        m_new.compile_logp()({"x": 0.9}),
+        m_ref.compile_logp()({"x": 0.9}),
+    )
+
+
+def test_observe_dims():
+    with pm.Model(coords={"test_dim": range(5)}) as m_old:
+        x = pm.Normal("x", dims="test_dim")
+
+    m_new = observe(m_old, {x: np.arange(5, dtype=config.floatX)})
+    assert m_new.named_vars_to_dims["x"] == ["test_dim"]
+
+
+def test_do():
+    with pm.Model() as m_old:
+        x = pm.Normal("x", 0, 1e-3)
+        y = pm.Normal("y", x, 1e-3)
+        z = pm.Normal("z", y + x, 1e-3)
+
+    assert -5 < pm.draw(z) < 5
+
+    m_new = do(m_old, {y: x + 100})
+
+    assert len(m_new.free_RVs) == 2
+    assert m_new["x"] in m_new.free_RVs
+    assert m_new["y"] in m_new.deterministics
+    assert m_new["z"] in m_new.free_RVs
+
+    assert 95 < pm.draw(m_new["z"]) < 105
+
+    # Test two substitutions
+    with m_old:
+        switch = pm.MutableData("switch", 1)
+    m_new = do(m_old, {y: 100 * switch, x: 100 * switch})
+
+    assert len(m_new.free_RVs) == 1
+    assert m_new["y"] not in m_new.deterministics
+    assert m_new["x"] not in m_new.deterministics
+    assert m_new["z"] in m_new.free_RVs
+
+    assert 195 < pm.draw(m_new["z"]) < 205
+    with m_new:
+        pm.set_data({"switch": 0})
+    assert -5 < pm.draw(m_new["z"]) < 5
+
+
+def test_do_posterior_predictive():
+    with pm.Model() as m:
+        x = pm.Normal("x", 0, 1)
+        y = pm.Normal("y", x, 1)
+        z = pm.Normal("z", y + x, 1e-3)
+
+    # Dummy posterior
+    idata_m = az.from_dict(
+        {
+            "x": np.full((2, 500), 25),
+            "y": np.full((2, 500), np.nan),
+            "z": np.full((2, 500), np.nan),
+        }
+    )
+
+    # Replace `y` by a constant `100.0`
+    m_do = do(m, {y: 100.0})
+    with m_do:
+        idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
+
+    assert 120 < idata_do.posterior_predictive["z"].mean() < 130
+
+
+@pytest.mark.parametrize("mutable", (False, True))
+def test_do_constant(mutable):
+    with pm.Model() as m:
+        x = pm.Data("x", 0, mutable=mutable)
+        y = pm.Normal("y", x, 1e-3)
+
+    do_m = do(m, {x: 105})
+    assert pm.draw(do_m["y"]) > 100
+
+
+def test_do_deterministic():
+    with pm.Model() as m:
+        x = pm.Normal("x", 0, 1e-3)
+        y = pm.Deterministic("y", x + 105)
+        z = pm.Normal("z", y, 1e-3)
+
+    do_m = do(m, {"z": x - 105})
+    assert pm.draw(do_m["z"]) < 100
+
+
+def test_do_dims():
+    coords = {"test_dim": range(10)}
+    with pm.Model(coords=coords) as m:
+        x = pm.Normal("x", dims="test_dim")
+        y = pm.Deterministic("y", x + 5, dims="test_dim")
+
+    do_m = do(
+        m,
+        {"x": np.zeros(10, dtype=config.floatX)},
+    )
+    assert do_m.named_vars_to_dims["x"] == ["test_dim"]
+
+    do_m = do(
+        m,
+        {"y": np.zeros(10, dtype=config.floatX)},
+    )
+    assert do_m.named_vars_to_dims["y"] == ["test_dim"]
diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py
index 455ea633..b51f44d6 100644
--- a/pymc_experimental/utils/model_fgraph.py
+++ b/pymc_experimental/utils/model_fgraph.py
@@ -358,3 +358,14 @@ def clone_model(model: Model) -> Model:
     """
 
     return model_from_fgraph(fgraph_from_model(model)[0])
+
+
+def extract_dims(var) -> Tuple:
+    dims = ()
+    node = var.owner
+    if node and isinstance(node.op, ModelVar):
+        if isinstance(node.op, ModelValuedVar):
+            dims = node.inputs[2:]
+        else:
+            dims = node.inputs[1:]
+    return dims
diff --git a/pymc_experimental/utils/pytensorf.py b/pymc_experimental/utils/pytensorf.py
index 76358c27..a953b5c1 100644
--- a/pymc_experimental/utils/pytensorf.py
+++ b/pymc_experimental/utils/pytensorf.py
@@ -1,5 +1,12 @@
+from typing import Sequence
+
 import pytensor
+from pymc import SymbolicRandomVariable
+from pytensor import Variable
 from pytensor.graph import Constant, Type
+from pytensor.graph.basic import walk
+from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor.random.op import RandomVariable
 
 
 class StringType(Type[str]):
@@ -31,3 +38,23 @@ class StringConstant(Constant):
 
 def as_symbolic_string(x, **kwargs):
     return StringConstant(stringtype, x)
+
+
+def rvs_in_graph(vars: Sequence[Variable]) -> bool:
+    """Check if there are any rvs in the graph of vars"""
+
+    def expand(r):
+        owner = r.owner
+        if owner:
+            inputs = list(reversed(owner.inputs))
+
+            if isinstance(owner.op, HasInnerGraph):
+                inputs += owner.op.inner_outputs
+
+            return inputs
+
+    return any(
+        node
+        for node in walk(vars, expand, False)
+        if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
+    )

From 4deb91abbf50e2792b590e7c4689155c38f10bd0 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 2 Jun 2023 14:39:26 +0200
Subject: [PATCH 4/4] Add option to prune variables after do intervention

---
 pymc_experimental/model_transform/basic.py    | 35 +++++++++++++++++++
 .../model_transform/conditioning.py           | 13 +++++--
 .../tests/model_transform/test_basic.py       | 19 ++++++++++
 .../model_transform/test_conditioning.py      | 27 ++++++++++++++
 4 files changed, 92 insertions(+), 2 deletions(-)
 create mode 100644 pymc_experimental/model_transform/basic.py
 create mode 100644 pymc_experimental/tests/model_transform/test_basic.py

diff --git a/pymc_experimental/model_transform/basic.py b/pymc_experimental/model_transform/basic.py
new file mode 100644
index 00000000..23e10ceb
--- /dev/null
+++ b/pymc_experimental/model_transform/basic.py
@@ -0,0 +1,35 @@
+from pymc import Model
+from pytensor.graph import ancestors
+
+from pymc_experimental.utils.model_fgraph import (
+    ModelObservedRV,
+    ModelVar,
+    fgraph_from_model,
+    model_from_fgraph,
+)
+
+
+def prune_vars_detached_from_observed(model: Model) -> Model:
+    """Prune model variables that are not related to any observed variable in the Model."""
+
+    # Potentials are ambiguous as to whether they correspond to likelihood or prior terms,
+    # we simply raise for now
+    if model.potentials:
+        raise NotImplementedError("Pruning not implemented for models with Potentials")
+
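+    # Build the IR with inlined views so that the ancestor search below follows
+    # dependencies through Deterministic and Data variables as well as plain RVs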
+    fgraph, _ = fgraph_from_model(model, inlined_views=True)
+    observed_vars = (
+        out
+        for node in fgraph.apply_nodes
+        if isinstance(node.op, ModelObservedRV)
+        for out in node.outputs
+    )
+    ancestor_nodes = {var.owner for var in ancestors(observed_vars)}
+    nodes_to_remove = {
+        node
+        for node in fgraph.apply_nodes
+        if isinstance(node.op, ModelVar) and node not in ancestor_nodes
+    }
+    for node_to_remove in nodes_to_remove:
+        fgraph.remove_node(node_to_remove)
+    return model_from_fgraph(fgraph)
diff --git a/pymc_experimental/model_transform/conditioning.py b/pymc_experimental/model_transform/conditioning.py
index 80434e46..fb4468c8 100644
--- a/pymc_experimental/model_transform/conditioning.py
+++ b/pymc_experimental/model_transform/conditioning.py
@@ -4,6 +4,7 @@
 from pymc.pytensorf import _replace_vars_in_graphs
 from pytensor.tensor import TensorVariable
 
+from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
 from pymc_experimental.utils.model_fgraph import (
     ModelDeterministic,
     ModelFreeRV,
@@ -113,7 +114,9 @@ def replacement_fn(var, inner_replacements):
     return replaced_graphs
 
 
-def do(model: Model, vars_to_interventions: Dict[Union[str, TensorVariable], Any]) -> Model:
+def do(
+    model: Model, vars_to_interventions: Dict[Union[str, TensorVariable], Any], prune_vars=False
+) -> Model:
     """Replace model variables with intervention variables.
 
     Intervention variables will either show up as `Data` or `Deterministics` in the new model,
@@ -126,6 +129,9 @@ def do(model: Model, vars_to_interventions: Dict[Union[str, TensorVariable], Any
         Dictionary that maps model variables (or names) to intervention expressions.
         Intervention expressions must have a shape and data type that are compatible
         with the original model variable.
+    prune_vars: bool, defaults to False
+        Whether to prune model variables that are not connected to any observed variables
+        after the interventions.
 
     Returns
     -------
     new_model: PyMC model
@@ -196,4 +202,7 @@ def do(model: Model, vars_to_interventions: Dict[Union[str, TensorVariable], Any
     # Replace variables by interventions
     toposort_replace(fgraph, tuple(replacements.items()))
 
-    return model_from_fgraph(fgraph)
+    model = model_from_fgraph(fgraph)
+    if prune_vars:
+        return prune_vars_detached_from_observed(model)
+    return model
diff --git a/pymc_experimental/tests/model_transform/test_basic.py b/pymc_experimental/tests/model_transform/test_basic.py
new file mode 100644
index 00000000..a2771d01
--- /dev/null
+++ b/pymc_experimental/tests/model_transform/test_basic.py
@@ -0,0 +1,19 @@
+import pymc as pm
+
+from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
+
+
+def test_prune_vars_detached_from_observed():
+    with pm.Model() as m:
+        obs_data = pm.MutableData("obs_data", 0)
+        a0 = pm.ConstantData("a0", 0)
+        a1 = pm.Normal("a1", a0)
+        a2 = pm.Normal("a2", a1)
+        pm.Normal("obs", a2, observed=obs_data)
+
+        d0 = pm.ConstantData("d0", 0)
+        d1 = pm.Normal("d1", d0)
+
+    assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"}
+    pruned_m = prune_vars_detached_from_observed(m)
+    assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"}
diff --git a/pymc_experimental/tests/model_transform/test_conditioning.py b/pymc_experimental/tests/model_transform/test_conditioning.py
index fbc66d6a..9c5ff9b6 100644
--- a/pymc_experimental/tests/model_transform/test_conditioning.py
+++ b/pymc_experimental/tests/model_transform/test_conditioning.py
@@ -184,3 +184,30 @@ def test_do_dims():
         {"y": np.zeros(10, dtype=config.floatX)},
     )
     assert do_m.named_vars_to_dims["y"] == ["test_dim"]
+
+
+@pytest.mark.parametrize("prune", (False, True))
+def test_do_prune(prune):
+    with pm.Model() as m:
+        x0 = pm.ConstantData("x0", 0)
+        x1 = pm.ConstantData("x1", 0)
+        y = pm.Normal("y")
+        y_det = pm.Deterministic("y_det", y + x0)
+        z = pm.Normal("z", y_det)
+        llike = pm.Normal("llike", z + x1, observed=0)
+
+    orig_named_vars = {"x0", "x1", "y", "y_det", "z", "llike"}
+    assert set(m.named_vars) == orig_named_vars
+
+    do_m = do(m, {y_det: x0 + 5}, prune_vars=prune)
+    if prune:
+        assert set(do_m.named_vars) == {"x0", "x1", "y_det", "z", "llike"}
+    else:
+        assert set(do_m.named_vars) == orig_named_vars
+
+    do_m = do(m, {z: 0.5}, prune_vars=prune)
+    if prune:
+        assert set(do_m.named_vars) == {"x1", "z", "llike"}
+    else:
+        assert set(do_m.named_vars) == orig_named_vars
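+
+
+# Illustrative sketch, not part of the original patch: since `do` and `observe`
+# each return a distinct, fully cloned model, the two transformations should
+# compose. Variable names and bounds below are assumptions mirroring the tests
+# above.
+def test_do_then_observe():
+    with pm.Model() as m:
+        x = pm.Normal("x", 0, 1e-3)
+        y = pm.Normal("y", x, 1e-3)
+        z = pm.Normal("z", y, 1e-3)
+
+    # Intervene on `x`, then observe `z` in the intervened model
+    m_do = do(m, {"x": 100.0})
+    m_obs = observe(m_do, {"z": 100.0})
+
+    assert m_obs["z"] in m_obs.observed_RVs
+    # With the intervention in place, draws of `y` concentrate around 100
+    assert 95 < pm.draw(m_obs["y"]) < 105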