43
43
import pymc as pm
44
44
45
45
from pymc import logp
46
- from pymc .logprob import conditional_logp
47
46
from pymc .testing import assert_no_rvs
48
47
49
48
@@ -58,55 +57,90 @@ def test_argmax():
58
57
x_max_logprob = logp (x_max , x_max_value )
59
58
60
59
61
- def test_max_non_iid_fails ():
62
- """Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected"""
60
+ @pytest .mark .parametrize (
61
+ "pt_op" ,
62
+ [
63
+ pt .max ,
64
+ pt .min ,
65
+ ],
66
+ )
67
+ def test_non_iid_fails (pt_op ):
68
+ """Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
63
69
x = pm .Normal .dist ([0 , 1 , 2 , 3 , 4 ], 1 , shape = (5 ,))
64
70
x .name = "x"
65
- x_max = pt . max (x , axis = - 1 )
66
- x_max_value = pt .vector ("x_max_value " )
71
+ x_m = pt_op (x , axis = - 1 )
72
+ x_m_value = pt .vector ("x_value " )
67
73
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
68
- x_max_logprob = logp (x_max , x_max_value )
74
+ x_max_logprob = logp (x_m , x_m_value )
69
75
70
76
71
- def test_max_non_rv_fails ():
77
+ @pytest .mark .parametrize (
78
+ "pt_op" ,
79
+ [
80
+ pt .max ,
81
+ pt .min ,
82
+ ],
83
+ )
84
+ def test_non_rv_fails (pt_op ):
72
85
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
73
86
x = pt .exp (pt .random .beta (0 , 1 , size = (3 ,)))
74
87
x .name = "x"
75
- x_max = pt . max (x , axis = - 1 )
76
- x_max_value = pt .vector ("x_max_value " )
88
+ x_m = pt_op (x , axis = - 1 )
89
+ x_m_value = pt .vector ("x_value " )
77
90
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
78
- x_max_logprob = logp (x_max , x_max_value )
91
+ x_max_logprob = logp (x_m , x_m_value )
79
92
80
93
81
- def test_max_multivariate_rv_fails ():
94
+ @pytest .mark .parametrize (
95
+ "pt_op" ,
96
+ [
97
+ pt .max ,
98
+ pt .min ,
99
+ ],
100
+ )
101
+ def test_multivariate_rv_fails (pt_op ):
82
102
_alpha = pt .scalar ()
83
103
_k = pt .iscalar ()
84
104
x = pm .StickBreakingWeights .dist (_alpha , _k )
85
105
x .name = "x"
86
- x_max = pt . max (x , axis = - 1 )
87
- x_max_value = pt .vector ("x_max_value " )
106
+ x_m = pt_op (x , axis = - 1 )
107
+ x_m_value = pt .vector ("x_value " )
88
108
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
89
- x_max_logprob = logp (x_max , x_max_value )
109
+ x_max_logprob = logp (x_m , x_m_value )
90
110
91
111
92
- def test_max_categorical ():
112
+ @pytest .mark .parametrize (
113
+ "pt_op" ,
114
+ [
115
+ pt .max ,
116
+ pt .min ,
117
+ ],
118
+ )
119
+ def test_categorical (pt_op ):
93
120
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
94
121
x = pm .Categorical .dist ([1 , 1 , 1 , 1 ], shape = (5 ,))
95
122
x .name = "x"
96
- x_max = pt . max (x , axis = - 1 )
97
- x_max_value = pt .vector ("x_max_value " )
123
+ x_m = pt_op (x , axis = - 1 )
124
+ x_m_value = pt .vector ("x_value " )
98
125
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
99
- x_max_logprob = logp (x_max , x_max_value )
126
+ x_max_logprob = logp (x_m , x_m_value )
100
127
101
128
102
- def test_non_supp_axis_max ():
129
+ @pytest .mark .parametrize (
130
+ "pt_op" ,
131
+ [
132
+ pt .max ,
133
+ pt .min ,
134
+ ],
135
+ )
136
+ def test_non_supp_axis (pt_op ):
103
137
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
104
138
x = pt .random .normal (0 , 1 , size = (3 , 3 ))
105
139
x .name = "x"
106
- x_max = pt . max (x , axis = - 1 )
107
- x_max_value = pt .vector ("x_max_value " )
140
+ x_m = pt_op (x , axis = - 1 )
141
+ x_m_value = pt .vector ("x_value " )
108
142
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
109
- x_max_logprob = logp (x_max , x_max_value )
143
+ x_max_logprob = logp (x_m , x_m_value )
110
144
111
145
112
146
@pytest .mark .parametrize (
@@ -147,3 +181,52 @@ def test_max_logprob(shape, value, axis):
147
181
(x_max_logprob .eval ({x_max_value : test_value })),
148
182
rtol = 1e-06 ,
149
183
)
184
+
185
+
186
+ @pytest .mark .parametrize (
187
+ "shape, value, axis" ,
188
+ [
189
+ (3 , 0.85 , - 1 ),
190
+ (3 , 0.01 , 0 ),
191
+ (2 , 0.2 , None ),
192
+ (4 , 0.5 , 0 ),
193
+ ((3 , 4 ), 0.9 , None ),
194
+ ((3 , 4 ), 0.75 , (1 , 0 )),
195
+ ],
196
+ )
197
+ def test_min_logprob (shape , value , axis ):
198
+ """Test whether the logprob for ```pt.mix``` produces the corrected
199
+ The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
200
+ U_1, \\ dots, U_n \\ stackrel{\t ext{i.i.d.}}{\\ sim} \t ext{Uniform}(0, 1) \\ Rightarrow U_{(k)} \\ sim \t ext{Beta}(k, n + 1- k)
201
+ for all 1<=k<=n
202
+ """
203
+ x = pt .random .uniform (0 , 1 , size = shape )
204
+ x .name = "x"
205
+ x_min = pt .min (x , axis = axis )
206
+ x_min_value = pt .scalar ("x_min_value" )
207
+ x_min_logprob = logp (x_min , x_min_value )
208
+
209
+ assert_no_rvs (x_min_logprob )
210
+
211
+ test_value = value
212
+
213
+ n = np .prod (shape )
214
+ beta_rv = pt .random .beta (1 , n , name = "beta" )
215
+ beta_vv = beta_rv .clone ()
216
+ beta_rv_logprob = logp (beta_rv , beta_vv )
217
+
218
+ np .testing .assert_allclose (
219
+ beta_rv_logprob .eval ({beta_vv : test_value }),
220
+ (x_min_logprob .eval ({x_min_value : test_value })),
221
+ rtol = 1e-06 ,
222
+ )
223
+
224
+
225
+ def test_min_non_mul_elemwise_fails ():
226
+ """Test whether the logprob for ```pt.min``` for non-mul elemwise RVs is rejected correctly"""
227
+ x = pt .log (pt .random .beta (0 , 1 , size = (3 ,)))
228
+ x .name = "x"
229
+ x_min = pt .min (x , axis = - 1 )
230
+ x_min_value = pt .vector ("x_min_value" )
231
+ with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
232
+ x_min_logprob = logp (x_min , x_min_value )
0 commit comments