Skip to content

Commit 5e1c8b8

Browse files
committed
Implement utility to convert PyMC modelt to and from FunctionGraph
1 parent 6f67dec commit 5e1c8b8

File tree

4 files changed

+353
-0
lines changed

4 files changed

+353
-0
lines changed

pymc_experimental/tests/utils/__init__.py

Whitespace-only changes.
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import numpy as np
2+
import pymc as pm
3+
import pytensor.tensor as pt
4+
import pytest
5+
from pytensor.graph import FunctionGraph, node_rewriter
6+
from pytensor.graph.rewriting.basic import in2out
7+
from pytensor.tensor.exceptions import NotScalarConstantError
8+
9+
from pymc_experimental.utils.model_fgraph import (
10+
FreeRV,
11+
deterministic,
12+
fgraph_from_model,
13+
free_rv,
14+
model_from_fgraph,
15+
)
16+
17+
18+
def test_model_fgraph_conversion():
19+
"""Test we can convert from a PyMC Model to a FunctionGraph and back"""
20+
with pm.Model(coords={"test_dim": range(3)}) as m_old:
21+
x = pm.Normal("x")
22+
y = pm.Deterministic("y", x + 1)
23+
w = pm.Normal("w", y)
24+
z = pm.Normal("z", y, observed=[0, 1, 2], dims=("test_dim",))
25+
pm.Potential("pot", x * 2)
26+
27+
m_fgraph = fgraph_from_model(m_old)
28+
assert isinstance(m_fgraph, FunctionGraph)
29+
30+
m_new = model_from_fgraph(m_fgraph)
31+
assert isinstance(m_new, pm.Model)
32+
33+
assert m_new.coords == {"test_dim": tuple(range(3))}
34+
assert m_new.named_vars_to_dims == {"z": ["test_dim"]}
35+
36+
named_vars = {"x", "y", "w", "z", "pot"}
37+
assert set(m_new.named_vars) == named_vars
38+
for named_var in named_vars:
39+
assert m_new[named_var] is not m_old[named_var]
40+
assert m_new["x"] in m_new.free_RVs
41+
assert m_new["w"] in m_new.free_RVs
42+
assert m_new["y"] in m_new.deterministics
43+
assert m_new["z"] in m_new.observed_RVs
44+
assert m_new["pot"] in m_new.potentials
45+
46+
new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1)
47+
old_y_draw, old_z_draw = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1)
48+
np.testing.assert_array_equal(new_y_draw, old_y_draw)
49+
np.testing.assert_array_equal(new_z_draw, old_z_draw)
50+
51+
ip = m_new.initial_point()
52+
np.testing.assert_equal(
53+
m_new.compile_logp()(ip),
54+
m_old.compile_logp()(ip),
55+
)
56+
57+
58+
@pytest.fixture()
59+
def non_centered_rewrite():
60+
@node_rewriter(tracks=[FreeRV])
61+
def non_centered_param(fgraph: FunctionGraph, node):
62+
"""Rewrite that replaces centered normal by non-centered parametrization."""
63+
64+
rv, _, *dims = node.inputs
65+
if not isinstance(rv.owner.op, pm.Normal):
66+
return
67+
rng, size, dtype, loc, scale = rv.owner.inputs
68+
69+
# Only apply rewrite if size information is explicit
70+
if size.ndim == 0:
71+
return None
72+
73+
try:
74+
is_unit = (
75+
pt.get_scalar_constant_value(loc) == 0 and pt.get_scalar_constant_value(scale) == 1
76+
)
77+
except NotScalarConstantError:
78+
is_unit = False
79+
80+
# Nothing to do here
81+
if is_unit:
82+
return
83+
84+
raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng)
85+
raw_norm.name = f"{rv.name}_raw_"
86+
raw_norm_value = raw_norm.clone()
87+
fgraph.add_input(raw_norm_value)
88+
raw_norm = free_rv(raw_norm, raw_norm_value, dims=dims)
89+
90+
new_norm = loc + raw_norm * scale
91+
new_norm.name = rv.name
92+
new_norm = deterministic(new_norm, dims=dims)
93+
94+
return [new_norm]
95+
96+
return in2out(non_centered_param)
97+
98+
99+
def test_fgraph_rewrite(non_centered_rewrite):
100+
"""Test we can apply a simple rewrite to a PyMC Model."""
101+
102+
with pm.Model(coords={"subject": range(10)}) as m_old:
103+
group_mean = pm.Normal("group_mean")
104+
# FIXME: value transforms are not yet maintained across conversion
105+
group_std = pm.HalfNormal("group_std", transform=None)
106+
subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
107+
obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))
108+
109+
fg = fgraph_from_model(m_old)
110+
non_centered_rewrite.apply(fg)
111+
112+
m_new = model_from_fgraph(fg)
113+
assert m_new.named_vars_to_dims == {
114+
"subject_mean": ["subject"],
115+
"subject_mean_raw_": ["subject"],
116+
"obs": ["subject"],
117+
}
118+
assert set(m_new.named_vars) == {
119+
"group_mean",
120+
"group_std",
121+
"subject_mean_raw_",
122+
"subject_mean",
123+
"obs",
124+
}
125+
assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"}
126+
assert {rv.name for rv in m_new.observed_RVs} == {"obs"}
127+
assert {rv.name for rv in m_new.deterministics} == {"subject_mean"}
128+
129+
with pm.Model() as m_ref:
130+
group_mean = pm.Normal("group_mean")
131+
# FIXME: value transforms are not yet maintained across conversion
132+
group_std = pm.HalfNormal("group_std", transform=None)
133+
subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,))
134+
subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std)
135+
obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10))
136+
137+
np.testing.assert_array_equal(
138+
pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1),
139+
pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1),
140+
)
141+
142+
ip = m_new.initial_point()
143+
np.testing.assert_equal(
144+
m_new.compile_logp()(ip),
145+
m_ref.compile_logp()(ip),
146+
)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from typing import Optional, Sequence
2+
3+
import pytensor
4+
from pymc.model import Model
5+
from pytensor.graph import Apply, FunctionGraph, Op
6+
from pytensor.tensor import TensorVariable
7+
8+
from pymc_experimental.utils.pytensorf import StringType
9+
10+
11+
class ModelVar(Op):
12+
"""A dummy Op that describes the purpose of a Model variable and contains
13+
meta-information as additional inputs (value and dims).
14+
"""
15+
16+
def make_node(self, rv, value=None, dims: Optional[Sequence[str]] = None):
17+
assert isinstance(rv, TensorVariable)
18+
19+
if dims is not None:
20+
dims = [pytensor.as_symbolic(dim) for dim in dims]
21+
assert all(isinstance(dim.type, StringType) for dim in dims)
22+
assert len(dims) == rv.type.ndim
23+
else:
24+
dims = ()
25+
26+
if value is not None:
27+
assert isinstance(value, TensorVariable)
28+
assert rv.type.in_same_class(value.type)
29+
return Apply(self, [rv, value, *dims], [rv.type()])
30+
else:
31+
return Apply(self, [rv, *dims], [rv.type()])
32+
33+
def infer_shape(self, fgraph, node, inputs_shape):
34+
return inputs_shape[0]
35+
36+
def do_constant_folding(self, fgraph, node):
37+
return False
38+
39+
def perform(self, *args, **kwargs):
40+
raise RuntimeError("ValuedRVs should never be evaluated!")
41+
42+
43+
class FreeRV(ModelVar):
44+
pass
45+
46+
47+
class ObservedRV(ModelVar):
48+
pass
49+
50+
51+
class Potential(ModelVar):
52+
pass
53+
54+
55+
class Deterministic(ModelVar):
56+
pass
57+
58+
59+
free_rv = FreeRV()
60+
observed_rv = ObservedRV()
61+
potential = Potential()
62+
deterministic = Deterministic()
63+
64+
65+
def toposort_replace(fgraph: FunctionGraph, replacements) -> None:
66+
"""Replace multiple variables in topological order."""
67+
toposort = fgraph.toposort()
68+
sorted_replacements = sorted(replacements, key=lambda pair: toposort.index(pair[0].owner))
69+
fgraph.replace_all(tuple(sorted_replacements), import_missing=True)
70+
71+
72+
def fgraph_from_model(model: Model) -> FunctionGraph:
73+
74+
# Collect PyTensor variables
75+
rvs_to_values = model.rvs_to_values
76+
rvs = list(rvs_to_values.keys())
77+
values = list(rvs_to_values.values())
78+
free_rvs = model.free_RVs
79+
deterministics = model.deterministics
80+
potentials = model.potentials
81+
82+
# Collect PyMC meta-info
83+
vars_to_dims = model.named_vars_to_dims
84+
coords = model.coords
85+
86+
# TODO: Do something with these
87+
dim_lengths = model.dim_lengths
88+
rvs_to_transforms = model.rvs_to_transforms
89+
90+
# Not supported yet
91+
if any(v is not None for v in model.rvs_to_total_sizes.values()):
92+
raise NotImplementedError("Cannot convert models with total_sizes")
93+
if any(v is not None for v in model.rvs_to_initial_values.values()):
94+
raise NotImplementedError("Cannot convert models with non-default initial_values")
95+
96+
# We start the `dict` with mappings from the value variables to themselves,
97+
# to prevent them from being cloned.
98+
memo = {v: v for v in values}
99+
100+
fgraph = FunctionGraph(
101+
outputs=rvs + potentials + deterministics,
102+
clone=True,
103+
memo=memo,
104+
copy_orphans=False,
105+
copy_inputs=False,
106+
)
107+
fgraph.coords = coords
108+
109+
# Introduce dummy `ModelVar` Ops
110+
free_rvs_to_values = {memo[k]: v for k, v in rvs_to_values.items() if k in free_rvs}
111+
observed_rvs_to_values = {memo[k]: v for k, v in rvs_to_values.items() if k not in free_rvs}
112+
potentials = [memo[k] for k in potentials]
113+
deterministics = [memo[k] for k in deterministics]
114+
115+
vars = fgraph.outputs
116+
new_vars = []
117+
for var in vars:
118+
dims = vars_to_dims.get(var.name, None)
119+
if var in free_rvs_to_values:
120+
new_var = free_rv(var, free_rvs_to_values[var], dims)
121+
elif var in observed_rvs_to_values:
122+
new_var = observed_rv(var, observed_rvs_to_values[var], dims)
123+
elif var in potentials:
124+
new_var = potential(var, dims)
125+
elif var in deterministics:
126+
new_var = deterministic(var, dims)
127+
else:
128+
raise RuntimeError(f"Variable is not RV, Potential nor Deterministic: {new_var}")
129+
new_vars.append(new_var)
130+
131+
toposort_replace(fgraph, tuple(zip(vars, new_vars)))
132+
return fgraph
133+
134+
135+
def model_from_fgraph(fgraph: FunctionGraph) -> Model:
136+
137+
model = Model(coords=getattr(fgraph, "coords", None))
138+
139+
# Get rid of dummy `ModelVar` Ops
140+
fgraph = fgraph.clone()
141+
model_vars_to_vars = {
142+
model_node.outputs[0]: model_node.inputs[0]
143+
for model_node in fgraph.apply_nodes
144+
if isinstance(model_node.op, ModelVar)
145+
}
146+
toposort_replace(fgraph, tuple(model_vars_to_vars.items()))
147+
148+
# Populate new PyMC model mappings
149+
for model_var in model_vars_to_vars.keys():
150+
if isinstance(model_var.owner.op, FreeRV):
151+
var, value, *dims = model_var.owner.inputs
152+
model.free_RVs.append(var)
153+
model.create_value_var(var, transform=None, value_var=value)
154+
model.set_initval(var, initval=None)
155+
elif isinstance(model_var.owner.op, ObservedRV):
156+
var, value, *dims = model_var.owner.inputs
157+
model.observed_RVs.append(var)
158+
model.create_value_var(var, transform=None, value_var=value)
159+
elif isinstance(model_var.owner.op, Potential):
160+
var, *dims = model_var.owner.inputs
161+
model.potentials.append(var)
162+
elif isinstance(model_var.owner.op, Deterministic):
163+
var, *dims = model_var.owner.inputs
164+
model.deterministics.append(var)
165+
else:
166+
continue # Raise?
167+
168+
if not dims:
169+
dims = None
170+
else:
171+
dims = [dim.data for dim in dims]
172+
model.add_named_variable(var, dims=dims)
173+
174+
return model

pymc_experimental/utils/pytensorf.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytensor
2+
from pytensor.graph import Constant, Type
3+
4+
5+
class StringType(Type[str]):
6+
def clone(self, **kwargs):
7+
return type(self)()
8+
9+
def filter(self, x, strict=False, allow_downcast=None):
10+
if isinstance(x, str):
11+
return x
12+
else:
13+
raise TypeError("Expected a string!")
14+
15+
def __str__(self):
16+
return "string"
17+
18+
@staticmethod
19+
def may_share_memory(a, b):
20+
return isinstance(a, str) and a is b
21+
22+
23+
stringtype = StringType()
24+
25+
26+
class StringConstant(Constant):
27+
pass
28+
29+
30+
@pytensor._as_symbolic.register(str)
31+
def as_symbolic_string(x, **kwargs):
32+
33+
return StringConstant(stringtype, x)

0 commit comments

Comments
 (0)