Skip to content

Commit c6f2f31

Browse files
Update tests following distributions refactoring
The distributions refactoring moves the random variable sampling to aesara. This relies on numpy and scipy random variables implementation. So, now the only thing we care about testing is that the parametrization on the PyMC side is sendible given the one on the Aesara side (effectively the numpy/scipy one) More details can be found on issue pymc-devs#4554 pymc-devs#4554
1 parent ab41e0d commit c6f2f31

File tree

1 file changed

+73
-65
lines changed

1 file changed

+73
-65
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
import pytest
2424
import scipy.stats as st
2525

26+
from numpy.testing import assert_almost_equal
2627
from scipy import linalg
2728
from scipy.special import expit
2829

2930
import pymc3 as pm
3031

31-
from pymc3.aesaraf import change_rv_size, floatX, intX
32-
from pymc3.distributions.dist_math import clipped_beta_rvs
32+
from pymc3.aesaraf import floatX, intX
33+
from pymc3.distributions import change_rv_size
3334
from pymc3.distributions.shape_utils import to_tuple
3435
from pymc3.exceptions import ShapeError
3536
from pymc3.tests.helpers import SeededTest
@@ -540,6 +541,76 @@ def test_dirichlet_random_shape(self, shape, size):
540541
assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
541542

542543

544+
class TestCorrectParametrizationMappingPymcToScipy(SeededTest):
545+
@staticmethod
546+
def get_inputs_from_apply_node_outputs(outputs):
547+
parents = outputs.get_parents()
548+
if not parents:
549+
raise Exception("Parent Apply node missing for output")
550+
# I am assuming there will always only be 1 Apply parent node in this context
551+
return parents[0].inputs
552+
553+
def test_pymc_params_match_rv_ones(
554+
self, pymc_params, expected_aesara_params, pymc_dist, decimal=6
555+
):
556+
pymc_dist_output = pymc_dist.dist(**dict(pymc_params))
557+
aesera_dist_inputs = self.get_inputs_from_apply_node_outputs(pymc_dist_output)[3:]
558+
assert len(expected_aesara_params) == len(aesera_dist_inputs)
559+
for (expected_name, expected_value), actual_variable in zip(
560+
expected_aesara_params, aesera_dist_inputs
561+
):
562+
assert_almost_equal(expected_value, actual_variable.eval(), decimal=decimal)
563+
564+
def test_normal(self):
565+
params = [("mu", 5.0), ("sigma", 10.0)]
566+
self.test_pymc_params_match_rv_ones(params, params, pm.Normal)
567+
568+
def test_uniform(self):
569+
params = [("lower", 0.5), ("upper", 1.5)]
570+
self.test_pymc_params_match_rv_ones(params, params, pm.Uniform)
571+
572+
def test_half_normal(self):
573+
params, expected_aesara_params = [("sigma", 10.0)], [("mean", 0), ("sigma", 10.0)]
574+
self.test_pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
575+
576+
def test_beta_alpha_beta(self):
577+
params = [("alpha", 2.0), ("beta", 5.0)]
578+
self.test_pymc_params_match_rv_ones(params, params, pm.Beta)
579+
580+
def test_beta_mu_sigma(self):
581+
params = [("mu", 2.0), ("sigma", 5.0)]
582+
expected_alpha, expected_beta = pm.Beta.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
583+
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
584+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Beta)
585+
586+
@pytest.mark.skip(reason="Expected to fail due to bug")
587+
def test_exponential(self):
588+
params = [("lam", 10.0)]
589+
expected_params = [("lam", 1 / params[0][1])]
590+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
591+
592+
def test_cauchy(self):
593+
params = [("alpha", 2.0), ("beta", 5.0)]
594+
self.test_pymc_params_match_rv_ones(params, params, pm.Cauchy)
595+
596+
def test_half_cauchy(self):
597+
params = [("alpha", 2.0), ("beta", 5.0)]
598+
self.test_pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
599+
600+
@pytest.mark.skip(reason="Expected to fail due to bug")
601+
def test_gamma_alpha_beta(self):
602+
params = [("alpha", 2.0), ("beta", 5.0)]
603+
expected_params = [("alpha", params[0][1]), ("beta", 1 / params[1][1])]
604+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
605+
606+
@pytest.mark.skip(reason="Expected to fail due to bug")
607+
def test_gamma_mu_sigma(self):
608+
params = [("mu", 2.0), ("sigma", 5.0)]
609+
expected_alpha, expected_beta = pm.Gamma.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
610+
expected_params = [("alpha", expected_alpha), ("beta", 1 / expected_beta)]
611+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
612+
613+
543614
class TestScalarParameterSamples(SeededTest):
544615
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
545616
def test_bounded(self):
@@ -551,20 +622,6 @@ def ref_rand(size, tau):
551622

552623
pymc3_random(BoundedNormal, {"tau": Rplus}, ref_rand=ref_rand)
553624

554-
@pytest.mark.skip(reason="This test is covered by Aesara")
555-
def test_uniform(self):
556-
def ref_rand(size, lower, upper):
557-
return st.uniform.rvs(size=size, loc=lower, scale=upper - lower)
558-
559-
pymc3_random(pm.Uniform, {"lower": -Rplus, "upper": Rplus}, ref_rand=ref_rand)
560-
561-
@pytest.mark.skip(reason="This test is covered by Aesara")
562-
def test_normal(self):
563-
def ref_rand(size, mu, sigma):
564-
return st.norm.rvs(size=size, loc=mu, scale=sigma)
565-
566-
pymc3_random(pm.Normal, {"mu": R, "sigma": Rplus}, ref_rand=ref_rand)
567-
568625
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
569626
def test_truncated_normal(self):
570627
def ref_rand(size, mu, sigma, lower, upper):
@@ -603,13 +660,6 @@ def ref_rand(size, alpha, mu, sigma):
603660

604661
pymc3_random(pm.SkewNormal, {"mu": R, "sigma": Rplus, "alpha": R}, ref_rand=ref_rand)
605662

606-
@pytest.mark.skip(reason="This test is covered by Aesara")
607-
def test_half_normal(self):
608-
def ref_rand(size, tau):
609-
return st.halfnorm.rvs(size=size, loc=0, scale=tau ** -0.5)
610-
611-
pymc3_random(pm.HalfNormal, {"tau": Rplus}, ref_rand=ref_rand)
612-
613663
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
614664
def test_wald(self):
615665
# Cannot do anything too exciting as scipy wald is a
@@ -623,20 +673,6 @@ def ref_rand(size, mu, lam, alpha):
623673
ref_rand=ref_rand,
624674
)
625675

626-
@pytest.mark.skip(reason="This test is covered by Aesara")
627-
def test_beta(self):
628-
def ref_rand(size, alpha, beta):
629-
return clipped_beta_rvs(a=alpha, b=beta, size=size)
630-
631-
pymc3_random(pm.Beta, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
632-
633-
@pytest.mark.skip(reason="This test is covered by Aesara")
634-
def test_exponential(self):
635-
def ref_rand(size, lam):
636-
return nr.exponential(scale=1.0 / lam, size=size)
637-
638-
pymc3_random(pm.Exponential, {"lam": Rplus}, ref_rand=ref_rand)
639-
640676
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
641677
def test_laplace(self):
642678
def ref_rand(size, mu, b):
@@ -670,34 +706,6 @@ def ref_rand(size, nu, mu, lam):
670706

671707
pymc3_random(pm.StudentT, {"nu": Rplus, "mu": R, "lam": Rplus}, ref_rand=ref_rand)
672708

673-
@pytest.mark.skip(reason="This test is covered by Aesara")
674-
def test_cauchy(self):
675-
def ref_rand(size, alpha, beta):
676-
return st.cauchy.rvs(alpha, beta, size=size)
677-
678-
pymc3_random(pm.Cauchy, {"alpha": R, "beta": Rplusbig}, ref_rand=ref_rand)
679-
680-
@pytest.mark.skip(reason="This test is covered by Aesara")
681-
def test_half_cauchy(self):
682-
def ref_rand(size, beta):
683-
return st.halfcauchy.rvs(scale=beta, size=size)
684-
685-
pymc3_random(pm.HalfCauchy, {"beta": Rplusbig}, ref_rand=ref_rand)
686-
687-
@pytest.mark.skip(reason="This test is covered by Aesara")
688-
def test_gamma_alpha_beta(self):
689-
def ref_rand(size, alpha, beta):
690-
return st.gamma.rvs(alpha, scale=1.0 / beta, size=size)
691-
692-
pymc3_random(pm.Gamma, {"alpha": Rplusbig, "beta": Rplusbig}, ref_rand=ref_rand)
693-
694-
@pytest.mark.skip(reason="This test is covered by Aesara")
695-
def test_gamma_mu_sigma(self):
696-
def ref_rand(size, mu, sigma):
697-
return st.gamma.rvs(mu ** 2 / sigma ** 2, scale=sigma ** 2 / mu, size=size)
698-
699-
pymc3_random(pm.Gamma, {"mu": Rplusbig, "sigma": Rplusbig}, ref_rand=ref_rand)
700-
701709
@pytest.mark.skip(reason="This test is covered by Aesara")
702710
def test_inverse_gamma(self):
703711
def ref_rand(size, alpha, beta):

0 commit comments

Comments
 (0)