11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import platform
14
15
15
16
import numpy as np
16
17
import pymc as pm
17
18
18
19
# general imports
19
20
import pytensor
21
+
22
+ pytensor .config .floatX = "float32"
20
23
import pytest
21
24
import scipy .stats .distributions as sp
22
25
@@ -50,6 +53,12 @@ class TestGenExtremeClass:
50
53
reason = "PyMC underflows earlier than scipy on float32" ,
51
54
)
52
55
def test_logp (self ):
56
+ def ref_logp (value , mu , sigma , xi ):
57
+ if 1 + xi * (value - mu ) / sigma > 0 :
58
+ return sp .genextreme .logpdf (value , c = - xi , loc = mu , scale = sigma )
59
+ else :
60
+ return - np .inf
61
+
53
62
check_logp (
54
63
GenExtreme ,
55
64
R ,
@@ -58,15 +67,24 @@ def test_logp(self):
58
67
"sigma" : Rplusbig ,
59
68
"xi" : Domain ([- 1 , - 0.99 , - 0.5 , 0 , 0.5 , 0.99 , 1 ]),
60
69
},
61
- lambda value , mu , sigma , xi : sp .genextreme .logpdf (value , c = - xi , loc = mu , scale = sigma )
62
- if 1 + xi * (value - mu ) / sigma > 0
63
- else - np .inf ,
70
+ ref_logp ,
71
+ n_samples = - 1 ,
64
72
)
65
73
66
74
if pytensor .config .floatX == "float32" :
67
75
raise Exception ("Flaky test: It passed this time, but XPASS is not allowed." )
68
76
77
+ @pytest .mark .skipif (
78
+ (pytensor .config .floatX == "float32" and platform .system () == "Windows" ),
79
+ reason = "Scipy gives different results on Windows and does not match with desired accuracy" ,
80
+ )
69
81
def test_logcdf (self ):
82
+ def ref_logcdf (value , mu , sigma , xi ):
83
+ if 1 + xi * (value - mu ) / sigma > 0 :
84
+ return sp .genextreme .logcdf (value , c = - xi , loc = mu , scale = sigma )
85
+ else :
86
+ return - np .inf
87
+
70
88
check_logcdf (
71
89
GenExtreme ,
72
90
R ,
@@ -75,9 +93,7 @@ def test_logcdf(self):
75
93
"sigma" : Rplusbig ,
76
94
"xi" : Domain ([- 1 , - 0.99 , - 0.5 , 0 , 0.5 , 0.99 , 1 ]),
77
95
},
78
- lambda value , mu , sigma , xi : sp .genextreme .logcdf (value , c = - xi , loc = mu , scale = sigma )
79
- if 1 + xi * (value - mu ) / sigma > 0
80
- else - np .inf ,
96
+ ref_logcdf ,
81
97
decimal = select_by_precision (float64 = 6 , float32 = 2 ),
82
98
)
83
99
0 commit comments