Skip to content

Commit cec2963

Browse files
committed
Implement observe and do model transformations
1 parent 3bdbe4c commit cec2963

File tree

6 files changed

+308
-0
lines changed

6 files changed

+308
-0
lines changed

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ Distributions
3535
histogram_approximation
3636

3737

38+
Model Transformations
39+
=====================
40+
41+
.. currentmodule:: pymc_experimental.model_transform
42+
.. autosummary::
43+
:toctree: generated/
44+
45+
conditioning.do
46+
conditioning.observe
47+
48+
3849
Utils
3950
=====
4051

pymc_experimental/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from typing import Any, Dict, List, Sequence, Union
2+
3+
from pymc import Model
4+
from pymc.pytensorf import _replace_vars_in_graphs
5+
from pytensor.tensor import TensorVariable
6+
7+
from pymc_experimental.utils.model_fgraph import (
8+
ModelFreeRV,
9+
extract_dims,
10+
fgraph_from_model,
11+
model_from_fgraph,
12+
model_named,
13+
model_observed_rv,
14+
toposort_replace,
15+
)
16+
17+
18+
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
19+
"""Convert free RVs to observed RVs.
20+
21+
Parameters
22+
----------
23+
model: PyMC Model
24+
vars_to_observations: Dict of variable or name to TensorLike
25+
Dictionary that maps model variables (or names) to observed values.
26+
Observed values must have a shape and data type that is compatible
27+
with the original model variable.
28+
29+
Returns
30+
-------
31+
new_model: PyMC model
32+
A distinct PyMC model with the relevant variables observed.
33+
All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
34+
35+
Examples
36+
--------
37+
38+
.. code-block:: python
39+
40+
import pymc as pm
41+
from pymc_experimental.model_transform.conditioning import observe
42+
43+
with pm.Model() as m:
44+
x = pm.Normal("x")
45+
y = pm.Normal("y", x)
46+
z = pm.Normal("z", y)
47+
48+
m_new = observe(m, {y: 0.5})
49+
50+
"""
51+
vars_to_observations = {
52+
model[var] if isinstance(var, str) else var: obs
53+
for var, obs in vars_to_observations.items()
54+
}
55+
56+
# Note: Since PyMC can infer logprob expressions we could also allow observing Deterministics
57+
if any(var not in model.free_RVs for var in vars_to_observations):
58+
raise ValueError(f"At least one var is not a free variable in the model")
59+
60+
fgraph, memo = fgraph_from_model(model)
61+
62+
replacements = {}
63+
for var, obs in vars_to_observations.items():
64+
model_free_rv = memo[var]
65+
66+
# Just a sanity check
67+
assert isinstance(model_free_rv.owner.op, ModelFreeRV)
68+
assert model_free_rv in fgraph.variables
69+
70+
rv, vv, *dims = model_free_rv.owner.inputs
71+
model_obs_rv = model_observed_rv(rv, rv.type.filter_variable(obs), *dims)
72+
replacements[model_free_rv] = model_obs_rv
73+
74+
toposort_replace(fgraph, tuple(replacements.items()))
75+
76+
return model_from_fgraph(fgraph)
77+
78+
79+
def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
80+
def replacement_fn(var, inner_replacements):
81+
if var in replacements:
82+
inner_replacements[var] = replacements[var]
83+
84+
# Handle root inputs as those will never be passed to the replacement_fn
85+
for inp in var.owner.inputs:
86+
if inp.owner is None and inp in replacements:
87+
inner_replacements[inp] = replacements[inp]
88+
89+
return [var]
90+
91+
replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
92+
return replaced_graphs
93+
94+
95+
def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any]) -> Model:
96+
"""Replace model variables by intervention variables.
97+
98+
Parameters
99+
----------
100+
model: PyMC Model
101+
vars_to_interventions: Dict of variable or name to TensorLike
102+
Dictionary that maps model variables (or names) to intervention expressions.
103+
Intervention expressions must have a shape and data type that is compatible
104+
with the original model variable.
105+
106+
Returns
107+
-------
108+
new_model: PyMC model
109+
A distinct PyMC model with the relevant variables replaced by the intervention expressions.
110+
All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
111+
112+
Examples
113+
--------
114+
115+
.. code-block:: python
116+
117+
import pymc as pm
118+
from pymc_experimental.model_transform.conditioning import do
119+
120+
with pm.Model() as m:
121+
x = pm.Normal("x", 0, 1)
122+
y = pm.Normal("y", x, 1)
123+
z = pm.Normal("z", y + x, 1)
124+
125+
# Dummy posterior, same as calling `pm.sample`
126+
idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})
127+
128+
# Replace `y` by a constant `100.0`
129+
m_do = do(m, {y: 100.0})
130+
with m_do:
131+
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
132+
133+
"""
134+
do_mapping = {}
135+
for var, obs in vars_to_interventions.items():
136+
if isinstance(var, str):
137+
var = model[var]
138+
do_mapping[var] = var.type.filter_variable(obs)
139+
140+
if any(var not in (model.basic_RVs + model.deterministics) for var in do_mapping):
141+
raise ValueError(f"At least one var is not a variable or deterministic in the model")
142+
143+
fgraph, memo = fgraph_from_model(model)
144+
145+
# We need the interventions defined in terms of the IR fgraph representation,
146+
# In case they reference other variables in the model
147+
ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)
148+
149+
replacements = {}
150+
for var, intervention in zip(do_mapping, ir_interventions):
151+
model_var = memo[var]
152+
153+
# Just a sanity check
154+
assert model_var in fgraph.variables
155+
156+
intervention.name = model_var.name
157+
dims = extract_dims(model_var)
158+
new_var = model_named(intervention, *dims)
159+
160+
replacements[model_var] = new_var
161+
162+
# Replace variables by interventions
163+
toposort_replace(fgraph, tuple(replacements.items()))
164+
165+
return model_from_fgraph(fgraph)

pymc_experimental/tests/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import arviz as az
2+
import numpy as np
3+
import pymc as pm
4+
from pymc.variational.minibatch_rv import create_minibatch_rv
5+
6+
from pymc_experimental.model_transform.conditioning import do, observe
7+
8+
9+
def test_observe():
10+
with pm.Model() as m_old:
11+
x = pm.Normal("x")
12+
y = pm.Normal("y", x)
13+
z = pm.Normal("z", y)
14+
15+
m_new = observe(m_old, {y: 0.5})
16+
17+
assert len(m_new.free_RVs) == 2
18+
assert len(m_new.observed_RVs) == 1
19+
assert m_new["x"] in m_new.free_RVs
20+
assert m_new["y"] in m_new.observed_RVs
21+
assert m_new["z"] in m_new.free_RVs
22+
23+
np.testing.assert_allclose(
24+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
25+
m_new.compile_logp()({"x": 0.9, "z": 1.4}),
26+
)
27+
28+
# Test two substitutions
29+
m_new = observe(m_old, {y: 0.5, z: 1.4})
30+
31+
assert len(m_new.free_RVs) == 1
32+
assert len(m_new.observed_RVs) == 2
33+
assert m_new["x"] in m_new.free_RVs
34+
assert m_new["y"] in m_new.observed_RVs
35+
assert m_new["z"] in m_new.observed_RVs
36+
37+
np.testing.assert_allclose(
38+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
39+
m_new.compile_logp()({"x": 0.9}),
40+
)
41+
42+
43+
def test_observe_minibatch():
44+
data = np.zeros((100,))
45+
batch_size = (10,)
46+
with pm.Model() as m_old:
47+
x = pm.Normal("x")
48+
y = pm.Normal("y", x)
49+
# Minibatch RVs are usually created with `total_size` kwarg
50+
z_raw = pm.Normal.dist(y, shape=batch_size)
51+
mb_z = create_minibatch_rv(z_raw, total_size=data.shape)
52+
m_old.register_rv(mb_z, name="mb_z")
53+
54+
mb_data = pm.Minibatch(data, batch_size=batch_size)
55+
m_new = observe(m_old, {mb_z: mb_data})
56+
57+
assert len(m_new.free_RVs) == 2
58+
assert len(m_new.observed_RVs) == 1
59+
assert m_new["x"] in m_new.free_RVs
60+
assert m_new["y"] in m_new.free_RVs
61+
assert m_new["mb_z"] in m_new.observed_RVs
62+
63+
np.testing.assert_allclose(
64+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "mb_z": np.zeros(10)}),
65+
m_new.compile_logp()({"x": 0.9, "y": 0.5}),
66+
)
67+
68+
69+
def test_do():
70+
with pm.Model() as m_old:
71+
x = pm.Normal("x", 0, 1e-3)
72+
y = pm.Normal("y", x, 1e-3)
73+
z = pm.Normal("z", y + x, 1e-3)
74+
75+
assert -5 < pm.draw(z) < 5
76+
77+
m_new = do(m_old, {y: x + 100})
78+
79+
assert len(m_new.free_RVs) == 2
80+
assert m_new["x"] in m_new.free_RVs
81+
assert m_new["y"] in m_new.named_vars.values()
82+
assert m_new["z"] in m_new.free_RVs
83+
84+
assert 95 < pm.draw(m_new["z"]) < 105
85+
86+
# Test two substitutions
87+
with m_old:
88+
switch = pm.MutableData("switch", 1)
89+
m_new = do(m_old, {y: 100 * switch, x: 100 * switch})
90+
91+
assert len(m_new.free_RVs) == 1
92+
assert m_new["x"] in m_new.named_vars.values()
93+
assert m_new["y"] in m_new.named_vars.values()
94+
assert m_new["z"] in m_new.free_RVs
95+
96+
assert 195 < pm.draw(m_new["z"]) < 205
97+
with m_new:
98+
pm.set_data({"switch": 0})
99+
assert -5 < pm.draw(m_new["z"]) < 5
100+
101+
102+
def test_do_posterior_predictive():
103+
with pm.Model() as m:
104+
x = pm.Normal("x", 0, 1)
105+
y = pm.Normal("y", x, 1)
106+
z = pm.Normal("z", y + x, 1e-3)
107+
108+
# Dummy posterior
109+
idata_m = az.from_dict(
110+
{
111+
"x": np.full((2, 500), 25),
112+
"y": np.full((2, 500), np.nan),
113+
"z": np.full((2, 500), np.nan),
114+
}
115+
)
116+
117+
# Replace `y` by a constant `100.0`
118+
m_do = do(m, {y: 100.0})
119+
with m_do:
120+
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
121+
122+
np.testing.assert_allclose(idata_do.posterior_predictive["z"].mean(), 125)

pymc_experimental/utils/model_fgraph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,13 @@ def clone_model(model: Model) -> Tuple[Model]:
326326
327327
"""
328328
return model_from_fgraph(fgraph_from_model(model)[0])
329+
330+
331+
def extract_dims(var) -> Tuple:
332+
dims = ()
333+
if isinstance(var, ModelVar):
334+
if isinstance(var, ModelValuedVar):
335+
dims = var.inputs[2:]
336+
else:
337+
dims = var.inputs[1:]
338+
return dims

0 commit comments

Comments
 (0)