7
7
from pytensor .tensor .exceptions import NotScalarConstantError
8
8
9
9
from pymc_experimental .utils .model_fgraph import (
10
+ ModelDeterministic ,
10
11
ModelFreeRV ,
12
+ ModelNamed ,
13
+ ModelObservedRV ,
14
+ ModelPotential ,
11
15
ModelVar ,
12
16
fgraph_from_model ,
13
17
model_deterministic ,
@@ -23,11 +27,17 @@ def test_basic():
23
27
y = pm .Deterministic ("y" , x + 1 )
24
28
w = pm .HalfNormal ("w" , pm .math .exp (y ))
25
29
z = pm .Normal ("z" , y , w , observed = [0 , 1 , 2 ], dims = ("test_dim" ,))
26
- pm .Potential ("pot" , x * 2 )
30
+ pot = pm .Potential ("pot" , x * 2 )
27
31
28
- m_fgraph = fgraph_from_model (m_old )
32
+ m_fgraph , memo = fgraph_from_model (m_old )
29
33
assert isinstance (m_fgraph , FunctionGraph )
30
34
35
+ assert isinstance (memo [x ].owner .op , ModelFreeRV )
36
+ assert isinstance (memo [y ].owner .op , ModelDeterministic )
37
+ assert isinstance (memo [w ].owner .op , ModelFreeRV )
38
+ assert isinstance (memo [z ].owner .op , ModelObservedRV )
39
+ assert isinstance (memo [pot ].owner .op , ModelPotential )
40
+
31
41
m_new = model_from_fgraph (m_fgraph )
32
42
assert isinstance (m_new , pm .Model )
33
43
@@ -79,7 +89,12 @@ def test_data():
79
89
mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
80
90
obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
81
91
82
- m_new = model_from_fgraph (fgraph_from_model (m_old ))
92
+ m_fgraph , memo = fgraph_from_model (m_old )
93
+ assert isinstance (memo [x ].owner .op , ModelNamed )
94
+ assert isinstance (memo [y ].owner .op , ModelNamed )
95
+ assert isinstance (memo [b0 ].owner .op , ModelNamed )
96
+
97
+ m_new = model_from_fgraph (m_fgraph )
83
98
84
99
# ConstantData is preserved
85
100
assert m_new ["b0" ].data == m_old ["b0" ].data
@@ -125,7 +140,7 @@ def test_deterministics():
125
140
assert m ["y" ].owner .inputs [3 ] is m ["mu" ]
126
141
assert m ["y" ].owner .inputs [4 ] is not m ["sigma" ]
127
142
128
- fg = fgraph_from_model (m )
143
+ fg , _ = fgraph_from_model (m )
129
144
130
145
# Check that no Deterministics are in graph of x to y and y to z
131
146
x , y , z , det_mu , det_sigma , det_y_ , det_y__ = fg .outputs
@@ -173,7 +188,7 @@ def test_sub_model_error():
173
188
with pm .Model () as sub_m :
174
189
y = pm .Normal ("y" , x )
175
190
176
- nodes = [v for v in fgraph_from_model (m ).toposort () if not isinstance (v .op , ModelVar )]
191
+ nodes = [v for v in fgraph_from_model (m )[ 0 ] .toposort () if not isinstance (v .op , ModelVar )]
177
192
assert len (nodes ) == 2
178
193
assert isinstance (nodes [0 ].op , pm .Beta )
179
194
assert isinstance (nodes [1 ].op , pm .Normal )
@@ -234,7 +249,7 @@ def test_fgraph_rewrite(non_centered_rewrite):
234
249
subject_mean = pm .Normal ("subject_mean" , group_mean , group_std , dims = ("subject" ,))
235
250
obs = pm .Normal ("obs" , subject_mean , 1 , observed = np .zeros (10 ), dims = ("subject" ,))
236
251
237
- fg = fgraph_from_model (m_old )
252
+ fg , _ = fgraph_from_model (m_old )
238
253
non_centered_rewrite .apply (fg )
239
254
240
255
m_new = model_from_fgraph (fg )
0 commit comments