Skip to content

Commit d0af6b1

Browse files
New var_names kwarg for pm.model_to_graphviz (#5634)
Enables positive selection of model variables to be included in the model graph. Co-authored-by: Michael Osthege <[email protected]>
1 parent 66fba38 commit d0af6b1

File tree

2 files changed

+159
-90
lines changed

2 files changed

+159
-90
lines changed

pymc/model_graph.py

+79-89
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414
import warnings
1515

16-
from collections import defaultdict, deque
17-
from typing import Dict, Iterator, NewType, Optional, Set
16+
from collections import defaultdict
17+
from typing import Dict, Iterable, List, NewType, Optional, Set
1818

1919
from aesara import function
2020
from aesara.compile.sharedvalue import SharedVariable
21-
from aesara.graph.basic import walk
21+
from aesara.graph import Apply
22+
from aesara.graph.basic import ancestors, walk
2223
from aesara.tensor.random.op import RandomVariable
2324
from aesara.tensor.var import TensorConstant, TensorVariable
2425

@@ -32,85 +33,64 @@
3233
class ModelGraph:
3334
def __init__(self, model):
3435
self.model = model
35-
self.var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
36+
self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
3637
self.var_list = self.model.named_vars.values()
37-
self.transform_map = {
38-
v.transformed: v.name for v in self.var_list if hasattr(v, "transformed")
39-
}
40-
self._deterministics = None
41-
42-
def get_deterministics(self, var):
43-
"""Compute the deterministic nodes of the graph, **not** including var itself."""
44-
deterministics = []
45-
attrs = ("transformed", "logpt")
46-
for v in self.var_list:
47-
if v != var and all(not hasattr(v, attr) for attr in attrs):
48-
deterministics.append(v)
49-
return deterministics
50-
51-
def _get_ancestors(self, var: TensorVariable, func) -> Set[TensorVariable]:
52-
"""Get all ancestors of a function, doing some accounting for deterministics."""
53-
54-
# this contains all of the variables in the model EXCEPT var...
55-
vars = set(self.var_list)
56-
vars.remove(var)
57-
58-
blockers = set() # type: Set[TensorVariable]
59-
retval = set() # type: Set[TensorVariable]
60-
61-
def _expand(node) -> Optional[Iterator[TensorVariable]]:
62-
if node in blockers:
63-
return None
64-
elif node in vars:
65-
blockers.add(node)
66-
retval.add(node)
67-
return None
68-
elif node.owner:
69-
blockers.add(node)
70-
return reversed(node.owner.inputs)
71-
else:
72-
return None
73-
74-
list(walk(deque([func]), _expand, bfs=True))
75-
return retval
76-
77-
def _filter_parents(self, var, parents) -> Set[VarName]:
78-
"""Get direct parents of a var, as strings"""
79-
keep = set() # type: Set[VarName]
80-
for p in parents:
81-
if p == var:
82-
continue
83-
elif p.name in self.var_names:
84-
keep.add(p.name)
85-
elif p in self.transform_map:
86-
if self.transform_map[p] != var.name:
87-
keep.add(self.transform_map[p])
88-
else:
89-
raise AssertionError(f"Do not know what to do with {get_var_name(p)}")
90-
return keep
91-
92-
def get_parents(self, var: TensorVariable) -> Set[VarName]:
93-
"""Get the named nodes that are direct inputs to the var"""
94-
# TODO: Update these lines, variables no longer have a `logpt` attribute
95-
if hasattr(var, "transformed"):
96-
func = var.transformed.logpt
97-
elif hasattr(var, "logpt"):
98-
func = var.logpt
99-
else:
100-
func = var
10138

102-
parents = self._get_ancestors(var, func)
103-
return self._filter_parents(var, parents)
39+
def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
40+
if var.owner is None or var.owner.inputs is None:
41+
return set()
42+
43+
def _expand(x):
44+
if x.name:
45+
return [x]
46+
if isinstance(x.owner, Apply):
47+
return reversed(x.owner.inputs)
48+
return []
49+
50+
parents = {get_var_name(x) for x in walk(nodes=var.owner.inputs, expand=_expand) if x.name}
51+
52+
return parents
53+
54+
def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[VarName]:
55+
if var_names is None:
56+
return self._all_var_names
57+
58+
selected_names = set(var_names)
59+
60+
# .copy() because sets cannot change in size during iteration
61+
for var_name in selected_names.copy():
62+
if var_name not in self._all_var_names:
63+
raise ValueError(f"{var_name} is not in this model.")
64+
65+
for model_var in self.var_list:
66+
if hasattr(model_var.tag, "observations"):
67+
if model_var.tag.observations == self.model[var_name]:
68+
selected_names.add(model_var.name)
10469

105-
def make_compute_graph(self) -> Dict[str, Set[VarName]]:
70+
selected_ancestors = set(
71+
filter(
72+
lambda rv: rv.name in self._all_var_names,
73+
list(ancestors([self.model[var_name] for var_name in selected_names])),
74+
)
75+
)
76+
77+
for var in selected_ancestors.copy():
78+
if hasattr(var.tag, "observations"):
79+
selected_ancestors.add(var.tag.observations)
80+
81+
# ordering of self._all_var_names is important
82+
return [var.name for var in selected_ancestors]
83+
84+
def make_compute_graph(
85+
self, var_names: Optional[Iterable[VarName]] = None
86+
) -> Dict[VarName, Set[VarName]]:
10687
"""Get map of var_name -> set(input var names) for the model"""
107-
input_map = defaultdict(set) # type: Dict[str, Set[VarName]]
88+
input_map: Dict[VarName, Set[VarName]] = defaultdict(set)
10889

109-
for var_name in self.var_names:
90+
for var_name in self.vars_to_plot(var_names):
11091
var = self.model[var_name]
111-
key = var_name
112-
val = self.get_parents(var)
113-
input_map[key] = input_map[key].union(val)
92+
parent_name = self.get_parent_names(var)
93+
input_map[var_name] = input_map[var_name].union(parent_name)
11494

11595
if hasattr(var.tag, "observations"):
11696
try:
@@ -120,6 +100,7 @@ def make_compute_graph(self) -> Dict[str, Set[VarName]]:
120100
input_map[obs_name] = input_map[obs_name].union({var_name})
121101
except AttributeError:
122102
pass
103+
123104
return input_map
124105

125106
def _make_node(self, var_name, graph, *, formatting: str = "plain"):
@@ -168,18 +149,20 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
168149
def _eval(self, var):
169150
return function([], var, mode="FAST_COMPILE")()
170151

171-
def get_plates(self):
152+
def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
172153
"""Rough but surprisingly accurate plate detection.
173154
174155
Just groups by the shape of the underlying distribution. Will be wrong
175156
if there are two plates with the same shape.
176157
177158
Returns
178159
-------
179-
dict: str -> set[str]
160+
dict
161+
Maps plate labels to the set of ``VarName``s inside the plate.
180162
"""
181163
plates = defaultdict(set)
182-
for var_name in self.var_names:
164+
165+
for var_name in self.vars_to_plot(var_names):
183166
v = self.model[var_name]
184167
if var_name in self.model.RV_dims:
185168
plate_label = " x ".join(
@@ -189,9 +172,10 @@ def get_plates(self):
189172
else:
190173
plate_label = " x ".join(map(str, self._eval(v.shape)))
191174
plates[plate_label].add(var_name)
175+
192176
return plates
193177

194-
def make_graph(self, formatting: str = "plain"):
178+
def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
195179
"""Make graphviz Digraph of PyMC model
196180
197181
Returns
@@ -207,25 +191,29 @@ def make_graph(self, formatting: str = "plain"):
207191
"\tconda install -c conda-forge python-graphviz"
208192
)
209193
graph = graphviz.Digraph(self.model.name)
210-
for plate_label, var_names in self.get_plates().items():
194+
for plate_label, all_var_names in self.get_plates(var_names).items():
211195
if plate_label:
212196
# must be preceded by 'cluster' to get a box around it
213197
with graph.subgraph(name="cluster" + plate_label) as sub:
214-
for var_name in var_names:
198+
for var_name in all_var_names:
215199
self._make_node(var_name, sub, formatting=formatting)
216200
# plate label goes bottom right
217201
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
218202
else:
219-
for var_name in var_names:
203+
for var_name in all_var_names:
220204
self._make_node(var_name, graph, formatting=formatting)
221205

222-
for key, values in self.make_compute_graph().items():
223-
for value in values:
224-
graph.edge(value.replace(":", "&"), key.replace(":", "&"))
206+
for child, parents in self.make_compute_graph(var_names=var_names).items():
207+
# parents is a set of rv names that precede child rv nodes
208+
for parent in parents:
209+
graph.edge(parent.replace(":", "&"), child.replace(":", "&"))
210+
225211
return graph
226212

227213

228-
def model_to_graphviz(model=None, *, formatting: str = "plain"):
214+
def model_to_graphviz(
215+
model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
216+
):
229217
"""Produce a graphviz Digraph from a PyMC model.
230218
231219
Requires graphviz, which may be installed most easily with
@@ -240,7 +228,9 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
240228
----------
241229
model : pm.Model
242230
The model to plot. Not required when called from inside a modelcontext.
243-
formatting : str
231+
var_names : iterable of variable names, optional
232+
Subset of variables to be plotted. The selected variables, together with their ancestors, identify a subgraph of the entire model graph.
233+
formatting : str, optional
244234
one of { "plain" }
245235
246236
Examples
@@ -275,4 +265,4 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
275265
"Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2
276266
)
277267
model = pm.modelcontext(model)
278-
return ModelGraph(model).make_graph(formatting=formatting)
268+
return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)

pymc/tests/test_model_graph.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def setup_class(cls):
143143
def test_inputs(self):
144144
for child, parents_in_plot in self.compute_graph.items():
145145
var = self.model[child]
146-
parents_in_graph = self.model_graph.get_parents(var)
146+
parents_in_graph = self.model_graph.get_parent_names(var)
147147
if isinstance(var, SharedVariable):
148148
# observed data also doesn't have parents in the compute graph!
149149
# But for the visualization we like them to become descendants of the
@@ -183,6 +183,85 @@ def test_checks_formatting(self):
183183
model_to_graphviz(self.model, formatting="plain_with_params")
184184

185185

186+
def model_with_different_descendants():
187+
"""
188+
Model proposed by Michael to test variable selection functionality
189+
From here: https://github.com/pymc-devs/pymc/pull/5634#pullrequestreview-916297509
190+
"""
191+
with pm.Model() as pmodel2:
192+
a = pm.Normal("a")
193+
b = pm.Normal("b")
194+
pm.Normal("c", a * b)
195+
intermediate = pm.Deterministic("intermediate", a + b)
196+
pred = pm.Deterministic("pred", intermediate * 3)
197+
198+
obs = pm.ConstantData("obs", 1.75)
199+
200+
L = pm.Normal("L", mu=1 + 0.5 * pred, observed=obs)
201+
202+
return pmodel2
203+
204+
205+
class TestParents:
206+
@pytest.mark.parametrize(
207+
"var_name, parent_names",
208+
[
209+
("L", {"pred"}),
210+
("pred", {"intermediate"}),
211+
("intermediate", {"a", "b"}),
212+
("c", {"a", "b"}),
213+
("a", set()),
214+
("b", set()),
215+
],
216+
)
217+
def test_get_parent_names(self, var_name, parent_names):
218+
mg = ModelGraph(model_with_different_descendants())
219+
mg.get_parent_names(mg.model[var_name]) == parent_names
220+
221+
222+
class TestVariableSelection:
223+
@pytest.mark.parametrize(
224+
"var_names, vars_to_plot, compute_graph",
225+
[
226+
(["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
227+
(
228+
["L"],
229+
["pred", "obs", "L", "intermediate", "a", "b"],
230+
{
231+
"pred": {"intermediate"},
232+
"obs": {"L"},
233+
"L": {"pred"},
234+
"intermediate": {"a", "b"},
235+
"a": set(),
236+
"b": set(),
237+
},
238+
),
239+
(
240+
["obs"],
241+
["pred", "obs", "L", "intermediate", "a", "b"],
242+
{
243+
"pred": {"intermediate"},
244+
"obs": {"L"},
245+
"L": {"pred"},
246+
"intermediate": {"a", "b"},
247+
"a": set(),
248+
"b": set(),
249+
},
250+
),
251+
# selecting ["c", "L"] is akin to selecting the entire graph
252+
(
253+
["c", "L"],
254+
ModelGraph(model_with_different_descendants()).vars_to_plot(),
255+
ModelGraph(model_with_different_descendants()).make_compute_graph(),
256+
),
257+
],
258+
)
259+
def test_subgraph(self, var_names, vars_to_plot, compute_graph):
260+
mg = ModelGraph(model_with_different_descendants())
261+
assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
262+
assert mg.make_compute_graph(var_names=var_names) == compute_graph
263+
264+
186265
class TestImputationModel(BaseModelGraphTest):
187266
model_func = model_with_imputations
188267

0 commit comments

Comments
 (0)