@@ -58,104 +58,87 @@ def test_argmax():
58
58
59
59
60
60
@pytest .mark .parametrize (
61
- "if_max " ,
61
+ "pt_op " ,
62
62
[
63
- True ,
64
- False ,
63
+ pt . max ,
64
+ pt . min ,
65
65
],
66
66
)
67
- def test_non_iid_fails (if_max ):
67
+ def test_non_iid_fails (pt_op ):
68
68
"""Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
69
69
x = pm .Normal .dist ([0 , 1 , 2 , 3 , 4 ], 1 , shape = (5 ,))
70
70
x .name = "x"
71
- if if_max == True :
72
- x_m = pt .max (x , axis = - 1 )
73
- x_m_value = pt .vector ("x_max_value" )
74
- else :
75
- x_m = pt .min (x , axis = - 1 )
76
- x_m_value = pt .vector ("x_min_value" )
71
+ x_m = pt_op (x , axis = - 1 )
72
+ x_m_value = pt .vector ("x_value" )
77
73
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
78
74
x_max_logprob = logp (x_m , x_m_value )
79
75
80
76
81
77
@pytest .mark .parametrize (
82
- "if_max" ,
83
- [True , False ],
78
+ "pt_op" ,
79
+ [
80
+ pt .max ,
81
+ pt .min ,
82
+ ],
84
83
)
85
- def test_non_rv_fails (if_max ):
84
+ def test_non_rv_fails (pt_op ):
86
85
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
87
86
x = pt .exp (pt .random .beta (0 , 1 , size = (3 ,)))
88
87
x .name = "x"
89
- if if_max == True :
90
- x_m = pt .max (x , axis = - 1 )
91
- x_m_value = pt .vector ("x_max_value" )
92
- else :
93
- x_m = pt .min (x , axis = - 1 )
94
- x_m_value = pt .vector ("x_min_value" )
88
+ x_m = pt_op (x , axis = - 1 )
89
+ x_m_value = pt .vector ("x_value" )
95
90
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
96
91
x_max_logprob = logp (x_m , x_m_value )
97
92
98
93
99
94
@pytest .mark .parametrize (
100
- "if_max " ,
95
+ "pt_op " ,
101
96
[
102
- True ,
103
- False ,
97
+ pt . max ,
98
+ pt . min ,
104
99
],
105
100
)
106
- def test_multivariate_rv_fails (if_max ):
101
+ def test_multivariate_rv_fails (pt_op ):
107
102
_alpha = pt .scalar ()
108
103
_k = pt .iscalar ()
109
104
x = pm .StickBreakingWeights .dist (_alpha , _k )
110
105
x .name = "x"
111
- if if_max == True :
112
- x_m = pt .max (x , axis = - 1 )
113
- x_m_value = pt .vector ("x_max_value" )
114
- else :
115
- x_m = pt .min (x , axis = - 1 )
116
- x_m_value = pt .vector ("x_min_value" )
106
+ x_m = pt_op (x , axis = - 1 )
107
+ x_m_value = pt .vector ("x_value" )
117
108
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
118
109
x_max_logprob = logp (x_m , x_m_value )
119
110
120
111
121
112
@pytest .mark .parametrize (
122
- "if_max " ,
113
+ "pt_op " ,
123
114
[
124
- True ,
125
- False ,
115
+ pt . max ,
116
+ pt . min ,
126
117
],
127
118
)
128
- def test_categorical (if_max ):
119
+ def test_categorical (pt_op ):
129
120
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
130
121
x = pm .Categorical .dist ([1 , 1 , 1 , 1 ], shape = (5 ,))
131
122
x .name = "x"
132
- if if_max == True :
133
- x_m = pt .max (x , axis = - 1 )
134
- x_m_value = pt .vector ("x_max_value" )
135
- else :
136
- x_m = pt .min (x , axis = - 1 )
137
- x_m_value = pt .vector ("x_min_value" )
123
+ x_m = pt_op (x , axis = - 1 )
124
+ x_m_value = pt .vector ("x_value" )
138
125
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
139
126
x_max_logprob = logp (x_m , x_m_value )
140
127
141
128
142
129
@pytest .mark .parametrize (
143
- "if_max " ,
130
+ "pt_op " ,
144
131
[
145
- True ,
146
- False ,
132
+ pt . max ,
133
+ pt . min ,
147
134
],
148
135
)
149
- def test_non_supp_axis (if_max ):
136
+ def test_non_supp_axis (pt_op ):
150
137
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
151
138
x = pt .random .normal (0 , 1 , size = (3 , 3 ))
152
139
x .name = "x"
153
- if if_max == True :
154
- x_m = pt .max (x , axis = - 1 )
155
- x_m_value = pt .vector ("x_max_value" )
156
- else :
157
- x_m = pt .min (x , axis = - 1 )
158
- x_m_value = pt .vector ("x_min_value" )
140
+ x_m = pt_op (x , axis = - 1 )
141
+ x_m_value = pt .vector ("x_value" )
159
142
with pytest .raises (RuntimeError , match = re .escape ("Logprob method not implemented" )):
160
143
x_max_logprob = logp (x_m , x_m_value )
161
144
0 commit comments