Skip to content

Commit 07c61a7

Browse files
committed
Skip GenExtreme logcdf test on float32 and Windows
1 parent fa90019 commit 07c61a7

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

pymc_experimental/tests/distributions/test_continuous.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import platform
1415

1516
import numpy as np
1617
import pymc as pm
1718

1819
# general imports
1920
import pytensor
21+
22+
pytensor.config.floatX = "float32"
2023
import pytest
2124
import scipy.stats.distributions as sp
2225

@@ -50,6 +53,12 @@ class TestGenExtremeClass:
5053
reason="PyMC underflows earlier than scipy on float32",
5154
)
5255
def test_logp(self):
56+
def ref_logp(value, mu, sigma, xi):
57+
if 1 + xi * (value - mu) / sigma > 0:
58+
return sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
59+
else:
60+
return -np.inf
61+
5362
check_logp(
5463
GenExtreme,
5564
R,
@@ -58,15 +67,24 @@ def test_logp(self):
5867
"sigma": Rplusbig,
5968
"xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
6069
},
61-
lambda value, mu, sigma, xi: sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
62-
if 1 + xi * (value - mu) / sigma > 0
63-
else -np.inf,
70+
ref_logp,
71+
n_samples=-1,
6472
)
6573

6674
if pytensor.config.floatX == "float32":
6775
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
6876

77+
@pytest.mark.skipif(
78+
(pytensor.config.floatX == "float32" and platform.system() == "Windows"),
79+
reason="Scipy gives different results on Windows and does not match with desired accuracy",
80+
)
6981
def test_logcdf(self):
82+
def ref_logcdf(value, mu, sigma, xi):
83+
if 1 + xi * (value - mu) / sigma > 0:
84+
return sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
85+
else:
86+
return -np.inf
87+
7088
check_logcdf(
7189
GenExtreme,
7290
R,
@@ -75,9 +93,7 @@ def test_logcdf(self):
7593
"sigma": Rplusbig,
7694
"xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
7795
},
78-
lambda value, mu, sigma, xi: sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
79-
if 1 + xi * (value - mu) / sigma > 0
80-
else -np.inf,
96+
ref_logcdf,
8197
decimal=select_by_precision(float64=6, float32=2),
8298
)
8399

0 commit comments

Comments
 (0)