13
13
# limitations under the License.
14
14
import warnings
15
15
16
- from collections import defaultdict , deque
17
- from typing import Dict , Iterator , NewType , Optional , Set
16
+ from collections import defaultdict
17
+ from typing import Dict , Iterable , List , NewType , Optional , Set
18
18
19
19
from aesara import function
20
20
from aesara .compile .sharedvalue import SharedVariable
21
- from aesara .graph .basic import walk
21
+ from aesara .graph import Apply
22
+ from aesara .graph .basic import ancestors , walk
22
23
from aesara .tensor .random .op import RandomVariable
23
24
from aesara .tensor .var import TensorConstant , TensorVariable
24
25
32
33
class ModelGraph :
33
34
def __init__(self, model):
    """Attach the graph helper to ``model`` and cache views of its variables.

    Parameters
    ----------
    model
        Object exposing a ``named_vars`` mapping; it is stored as
        ``self.model``.

    Notes
    -----
    ``self._all_var_names`` holds whatever ``get_default_varnames`` returns
    for ``model.named_vars`` with ``include_transformed=False``, and
    ``self.var_list`` is the ``values()`` view of that same mapping.
    """
    named_vars = model.named_vars
    self.model = model
    self._all_var_names = get_default_varnames(named_vars, include_transformed=False)
    self.var_list = named_vars.values()
37
- self .transform_map = {
38
- v .transformed : v .name for v in self .var_list if hasattr (v , "transformed" )
39
- }
40
- self ._deterministics = None
41
-
42
- def get_deterministics (self , var ):
43
- """Compute the deterministic nodes of the graph, **not** including var itself."""
44
- deterministics = []
45
- attrs = ("transformed" , "logpt" )
46
- for v in self .var_list :
47
- if v != var and all (not hasattr (v , attr ) for attr in attrs ):
48
- deterministics .append (v )
49
- return deterministics
50
-
51
- def _get_ancestors (self , var : TensorVariable , func ) -> Set [TensorVariable ]:
52
- """Get all ancestors of a function, doing some accounting for deterministics."""
53
-
54
- # this contains all of the variables in the model EXCEPT var...
55
- vars = set (self .var_list )
56
- vars .remove (var )
57
-
58
- blockers = set () # type: Set[TensorVariable]
59
- retval = set () # type: Set[TensorVariable]
60
-
61
- def _expand (node ) -> Optional [Iterator [TensorVariable ]]:
62
- if node in blockers :
63
- return None
64
- elif node in vars :
65
- blockers .add (node )
66
- retval .add (node )
67
- return None
68
- elif node .owner :
69
- blockers .add (node )
70
- return reversed (node .owner .inputs )
71
- else :
72
- return None
73
-
74
- list (walk (deque ([func ]), _expand , bfs = True ))
75
- return retval
76
-
77
- def _filter_parents (self , var , parents ) -> Set [VarName ]:
78
- """Get direct parents of a var, as strings"""
79
- keep = set () # type: Set[VarName]
80
- for p in parents :
81
- if p == var :
82
- continue
83
- elif p .name in self .var_names :
84
- keep .add (p .name )
85
- elif p in self .transform_map :
86
- if self .transform_map [p ] != var .name :
87
- keep .add (self .transform_map [p ])
88
- else :
89
- raise AssertionError (f"Do not know what to do with { get_var_name (p )} " )
90
- return keep
91
-
92
- def get_parents (self , var : TensorVariable ) -> Set [VarName ]:
93
- """Get the named nodes that are direct inputs to the var"""
94
- # TODO: Update these lines, variables no longer have a `logpt` attribute
95
- if hasattr (var , "transformed" ):
96
- func = var .transformed .logpt
97
- elif hasattr (var , "logpt" ):
98
- func = var .logpt
99
- else :
100
- func = var
101
38
102
- parents = self ._get_ancestors (var , func )
103
- return self ._filter_parents (var , parents )
39
def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
    """Return the names of the named variables found upstream of ``var``.

    The search starts from the inputs of ``var.owner`` and is driven by
    ``walk``. Every visited node that has a truthy ``name`` contributes
    ``get_var_name(node)`` to the result.

    Parameters
    ----------
    var
        Variable whose parents are requested.

    Returns
    -------
    set
        Names of the named nodes reached by the walk. Empty when ``var`` has
        no owner or the owner has no inputs.
    """
    owner = var.owner
    if owner is None or owner.inputs is None:
        return set()

    def _expand(node):
        # For a named node only the node itself is handed back, not the
        # inputs of its owner, so the search does not continue above it.
        if node.name:
            return [node]
        # Unnamed intermediate result: continue through the inputs of the
        # Apply node that produced it.
        if isinstance(node.owner, Apply):
            return reversed(node.owner.inputs)
        return []

    parents = set()
    for node in walk(nodes=owner.inputs, expand=_expand):
        if node.name:
            parents.add(get_var_name(node))

    return parents
53
+
54
def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[VarName]:
    """Return the names of the variables that should appear in the plot.

    Parameters
    ----------
    var_names
        Names selecting a subgraph. When ``None`` the full
        ``self._all_var_names`` list is returned unchanged.

    Returns
    -------
    list
        Names of the selected variables, of every ancestor of them whose
        name is in ``self._all_var_names``, and of the ``tag.observations``
        objects attached to any of those. Names known to the model are
        ordered as in ``self._all_var_names``; any other names follow.

    Raises
    ------
    ValueError
        If a requested name is not in ``self._all_var_names``.
    """
    if var_names is None:
        return self._all_var_names

    known_names = set(self._all_var_names)
    requested = set(var_names)

    for var_name in requested:
        if var_name not in known_names:
            raise ValueError(f"{var_name} is not in this model.")

    # A requested variable may itself be the ``observations`` of another
    # model variable; in that case the observing variable is selected too.
    selected_names = set(requested)
    requested_vars = [self.model[var_name] for var_name in requested]
    for model_var in self.var_list:
        if hasattr(model_var.tag, "observations"):
            observations = model_var.tag.observations
            if any(observations == requested_var for requested_var in requested_vars):
                selected_names.add(model_var.name)

    selected_ancestors = {
        rv
        for rv in ancestors([self.model[var_name] for var_name in selected_names])
        if rv.name in known_names
    }

    # .copy() because sets cannot change in size during iteration
    for var in selected_ancestors.copy():
        if hasattr(var.tag, "observations"):
            selected_ancestors.add(var.tag.observations)

    # Ordering of self._all_var_names is important, and iterating a set of
    # variables gives an arbitrary order. Sort by position in the model's
    # name list; names that are not in that list are placed last (sorted()
    # is stable, so they keep their relative order).
    position = {name: index for index, name in enumerate(self._all_var_names)}
    return sorted(
        (var.name for var in selected_ancestors),
        key=lambda name: position.get(name, len(position)),
    )
83
+
84
+ def make_compute_graph (
85
+ self , var_names : Optional [Iterable [VarName ]] = None
86
+ ) -> Dict [VarName , Set [VarName ]]:
106
87
"""Get map of var_name -> set(input var names) for the model"""
107
- input_map = defaultdict ( set ) # type : Dict[str , Set[VarName]]
88
+ input_map : Dict [VarName , Set [VarName ]] = defaultdict ( set )
108
89
109
- for var_name in self .var_names :
90
+ for var_name in self .vars_to_plot ( var_names ) :
110
91
var = self .model [var_name ]
111
- key = var_name
112
- val = self .get_parents (var )
113
- input_map [key ] = input_map [key ].union (val )
92
+ parent_name = self .get_parent_names (var )
93
+ input_map [var_name ] = input_map [var_name ].union (parent_name )
114
94
115
95
if hasattr (var .tag , "observations" ):
116
96
try :
@@ -120,6 +100,7 @@ def make_compute_graph(self) -> Dict[str, Set[VarName]]:
120
100
input_map [obs_name ] = input_map [obs_name ].union ({var_name })
121
101
except AttributeError :
122
102
pass
103
+
123
104
return input_map
124
105
125
106
def _make_node (self , var_name , graph , * , formatting : str = "plain" ):
@@ -168,18 +149,20 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
168
149
def _eval(self, var):
    """Build an input-free ``function`` for ``var`` and return the result of calling it.

    The function is created with ``mode="FAST_COMPILE"``.
    """
    compiled = function([], var, mode="FAST_COMPILE")
    return compiled()
170
151
171
- def get_plates (self ) :
152
+ def get_plates (self , var_names : Optional [ Iterable [ VarName ]] = None ) -> Dict [ str , Set [ VarName ]] :
172
153
"""Rough but surprisingly accurate plate detection.
173
154
174
155
Just groups by the shape of the underlying distribution. Will be wrong
175
156
if there are two plates with the same shape.
176
157
177
158
Returns
178
159
-------
179
- dict: str -> set[str]
160
+ dict
161
+ Maps plate labels to the set of ``VarName``s inside the plate.
180
162
"""
181
163
plates = defaultdict (set )
182
- for var_name in self .var_names :
164
+
165
+ for var_name in self .vars_to_plot (var_names ):
183
166
v = self .model [var_name ]
184
167
if var_name in self .model .RV_dims :
185
168
plate_label = " x " .join (
@@ -189,9 +172,10 @@ def get_plates(self):
189
172
else :
190
173
plate_label = " x " .join (map (str , self ._eval (v .shape )))
191
174
plates [plate_label ].add (var_name )
175
+
192
176
return plates
193
177
194
- def make_graph (self , formatting : str = "plain" ):
178
+ def make_graph (self , var_names : Optional [ Iterable [ VarName ]] = None , formatting : str = "plain" ):
195
179
"""Make graphviz Digraph of PyMC model
196
180
197
181
Returns
@@ -207,25 +191,29 @@ def make_graph(self, formatting: str = "plain"):
207
191
"\t conda install -c conda-forge python-graphviz"
208
192
)
209
193
graph = graphviz .Digraph (self .model .name )
210
- for plate_label , var_names in self .get_plates ().items ():
194
+ for plate_label , all_var_names in self .get_plates (var_names ).items ():
211
195
if plate_label :
212
196
# must be preceded by 'cluster' to get a box around it
213
197
with graph .subgraph (name = "cluster" + plate_label ) as sub :
214
- for var_name in var_names :
198
+ for var_name in all_var_names :
215
199
self ._make_node (var_name , sub , formatting = formatting )
216
200
# plate label goes bottom right
217
201
sub .attr (label = plate_label , labeljust = "r" , labelloc = "b" , style = "rounded" )
218
202
else :
219
- for var_name in var_names :
203
+ for var_name in all_var_names :
220
204
self ._make_node (var_name , graph , formatting = formatting )
221
205
222
- for key , values in self .make_compute_graph ().items ():
223
- for value in values :
224
- graph .edge (value .replace (":" , "&" ), key .replace (":" , "&" ))
206
+ for child , parents in self .make_compute_graph (var_names = var_names ).items ():
207
+ # parents is a set of rv names that precede child rv nodes
208
+ for parent in parents :
209
+ graph .edge (parent .replace (":" , "&" ), child .replace (":" , "&" ))
210
+
225
211
return graph
226
212
227
213
228
- def model_to_graphviz (model = None , * , formatting : str = "plain" ):
214
+ def model_to_graphviz (
215
+ model = None , * , var_names : Optional [Iterable [VarName ]] = None , formatting : str = "plain"
216
+ ):
229
217
"""Produce a graphviz Digraph from a PyMC model.
230
218
231
219
Requires graphviz, which may be installed most easily with
@@ -240,7 +228,9 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
240
228
----------
241
229
model : pm.Model
242
230
The model to plot. Not required when called from inside a modelcontext.
243
- formatting : str
231
+ var_names : iterable of variable names, optional
232
+ Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
233
+ formatting : str, optional
244
234
one of { "plain" }
245
235
246
236
Examples
@@ -275,4 +265,4 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
275
265
"Formattings other than 'plain' are currently not supported." , UserWarning , stacklevel = 2
276
266
)
277
267
model = pm .modelcontext (model )
278
- return ModelGraph (model ).make_graph (formatting = formatting )
268
+ return ModelGraph (model ).make_graph (var_names = var_names , formatting = formatting )
0 commit comments