43
43
import pymc as pm
44
44
45
45
from pymc import logp
46
- from pymc .logprob import conditional_logp
47
46
from pymc .testing import assert_no_rvs
48
47
49
48
@@ -58,55 +57,112 @@ def test_argmax():
58
57
x_max_logprob = logp (x_max , x_max_value )
59
58
60
59
61
- def test_max_non_iid_fails ():
62
- """Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected"""
60
+ @pytest .mark .parametrize (
61
+ "if_max" ,
62
+ [
63
+ True ,
64
+ False ,
65
+ ],
66
+ )
67
+ def test_non_iid_fails (if_max ):
68
+ """Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
63
69
x = pm .Normal .dist ([0 , 1 , 2 , 3 , 4 ], 1 , shape = (5 ,))
64
70
x .name = "x"
65
- x_max = pt .max (x , axis = - 1 )
66
- x_max_value = pt .vector ("x_max_value" )
71
+ if if_max == True :
72
+ x_m = pt .max (x , axis = - 1 )
73
+ x_m_value = pt .vector ("x_max_value" )
74
+ else :
75
+ x_min = pt .min (x , axis = - 1 )
76
+ x_m = x_min .owner .inputs [0 ]
77
+ x_m_value = pt .vector ("x_min_value" )
67
78
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
68
- x_max_logprob = logp (x_max , x_max_value )
79
+ x_max_logprob = logp (x_m , x_m_value )
69
80
70
81
71
- def test_max_non_rv_fails ():
82
+ @pytest .mark .parametrize (
83
+ "if_max" ,
84
+ [True , False ],
85
+ )
86
+ def test_non_rv_fails (if_max ):
72
87
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
73
88
x = pt .exp (pt .random .beta (0 , 1 , size = (3 ,)))
74
89
x .name = "x"
75
- x_max = pt .max (x , axis = - 1 )
76
- x_max_value = pt .vector ("x_max_value" )
90
+ if if_max == True :
91
+ x_m = pt .max (x , axis = - 1 )
92
+ x_m_value = pt .vector ("x_max_value" )
93
+ else :
94
+ x_min = pt .min (x , axis = - 1 )
95
+ x_m = x_min .owner .inputs [0 ]
96
+ x_m_value = pt .vector ("x_min_value" )
77
97
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
78
- x_max_logprob = logp (x_max , x_max_value )
98
+ x_max_logprob = logp (x_m , x_m_value )
79
99
80
100
81
- def test_max_multivariate_rv_fails ():
101
+ @pytest .mark .parametrize (
102
+ "if_max" ,
103
+ [
104
+ True ,
105
+ False ,
106
+ ],
107
+ )
108
+ def test_multivariate_rv_fails (if_max ):
82
109
_alpha = pt .scalar ()
83
110
_k = pt .iscalar ()
84
111
x = pm .StickBreakingWeights .dist (_alpha , _k )
85
112
x .name = "x"
86
- x_max = pt .max (x , axis = - 1 )
87
- x_max_value = pt .vector ("x_max_value" )
113
+ if if_max == True :
114
+ x_m = pt .max (x , axis = - 1 )
115
+ x_m_value = pt .vector ("x_max_value" )
116
+ else :
117
+ x_min = pt .min (x , axis = - 1 )
118
+ x_m = x_min .owner .inputs [0 ]
119
+ x_m_value = pt .vector ("x_min_value" )
88
120
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
89
- x_max_logprob = logp (x_max , x_max_value )
121
+ x_max_logprob = logp (x_m , x_m_value )
90
122
91
123
92
- def test_max_categorical ():
124
+ @pytest .mark .parametrize (
125
+ "if_max" ,
126
+ [
127
+ True ,
128
+ False ,
129
+ ],
130
+ )
131
+ def test_categorical (if_max ):
93
132
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
94
133
x = pm .Categorical .dist ([1 , 1 , 1 , 1 ], shape = (5 ,))
95
134
x .name = "x"
96
- x_max = pt .max (x , axis = - 1 )
97
- x_max_value = pt .vector ("x_max_value" )
135
+ if if_max == True :
136
+ x_m = pt .max (x , axis = - 1 )
137
+ x_m_value = pt .vector ("x_max_value" )
138
+ else :
139
+ x_min = pt .min (x , axis = - 1 )
140
+ x_m = x_min .owner .inputs [0 ]
141
+ x_m_value = pt .vector ("x_min_value" )
98
142
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
99
- x_max_logprob = logp (x_max , x_max_value )
143
+ x_max_logprob = logp (x_m , x_m_value )
100
144
101
145
102
- def test_non_supp_axis_max ():
146
+ @pytest .mark .parametrize (
147
+ "if_max" ,
148
+ [
149
+ True ,
150
+ False ,
151
+ ],
152
+ )
153
+ def test_non_supp_axis (if_max ):
103
154
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
104
155
x = pt .random .normal (0 , 1 , size = (3 , 3 ))
105
156
x .name = "x"
106
- x_max = pt .max (x , axis = - 1 )
107
- x_max_value = pt .vector ("x_max_value" )
157
+ if if_max == True :
158
+ x_m = pt .max (x , axis = - 1 )
159
+ x_m_value = pt .vector ("x_max_value" )
160
+ else :
161
+ x_min = pt .min (x , axis = - 1 )
162
+ x_m = x_min .owner .inputs [0 ]
163
+ x_m_value = pt .vector ("x_min_value" )
108
164
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
109
- x_max_logprob = logp (x_max , x_max_value )
165
+ x_max_logprob = logp (x_m , x_m_value )
110
166
111
167
112
168
@pytest .mark .parametrize (
@@ -147,3 +203,54 @@ def test_max_logprob(shape, value, axis):
147
203
(x_max_logprob .eval ({x_max_value : test_value })),
148
204
rtol = 1e-06 ,
149
205
)
206
+
207
+
208
+ @pytest .mark .parametrize (
209
+ "shape, value, axis" ,
210
+ [
211
+ (3 , 0.85 , - 1 ),
212
+ (3 , 0.01 , 0 ),
213
+ (2 , 0.2 , None ),
214
+ (4 , 0.5 , 0 ),
215
+ ((3 , 4 ), 0.9 , None ),
216
+ ((3 , 4 ), 0.75 , (1 , 0 )),
217
+ ],
218
+ )
219
+ def test_min_logprob (shape , value , axis ):
220
+ """Test whether the logprob for ```pt.mix``` produces the corrected
221
+ The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
222
+ U_1, \\ dots, U_n \\ stackrel{\t ext{i.i.d.}}{\\ sim} \t ext{Uniform}(0, 1) \\ Rightarrow U_{(k)} \\ sim \t ext{Beta}(k, n + 1- k)
223
+ for all 1<=k<=n
224
+ """
225
+ x = pt .random .uniform (0 , 1 , size = shape )
226
+ x .name = "x"
227
+ x_min = pt .min (x , axis = axis )
228
+ x_min_rv = x_min .owner .inputs [0 ]
229
+ x_min_value = pt .scalar ("x_min_value" )
230
+ x_min_logprob = logp (x_min_rv , x_min_value )
231
+
232
+ assert_no_rvs (x_min_logprob )
233
+
234
+ test_value = value
235
+
236
+ n = np .prod (shape )
237
+ beta_rv = pt .random .beta (1 , n , name = "beta" )
238
+ beta_vv = beta_rv .clone ()
239
+ beta_rv_logprob = logp (beta_rv , beta_vv )
240
+
241
+ np .testing .assert_allclose (
242
+ beta_rv_logprob .eval ({beta_vv : test_value }),
243
+ (x_min_logprob .eval ({x_min_value : test_value })),
244
+ rtol = 1e-06 ,
245
+ )
246
+
247
+
248
+ def test_min_non_mul_elemwise_fails ():
249
+ """Test whether the logprob for ```pt.min``` for non-mul elemwise RVs is rejected correctly"""
250
+ x = pt .log (pt .random .beta (0 , 1 , size = (3 ,)))
251
+ x .name = "x"
252
+ x_min = pt .min (x , axis = - 1 )
253
+ x_min_rv = x_min .owner .inputs [0 ]
254
+ x_min_value = pt .vector ("x_min_value" )
255
+ with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
256
+ x_min_logprob = logp (x_min_rv , x_min_value )
0 commit comments