Skip to content

Refactor Rice and Skew Normal distribution #4705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 51 additions & 109 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3102,6 +3102,21 @@ def logp(value, mu, kappa):
)


class SkewNormalRV(RandomVariable):
    # RandomVariable Op backing the refactored SkewNormal distribution
    # (assigned below as `SkewNormal.rv_op`).
    name = "skewnormal"
    ndim_supp = 0  # univariate: each draw is a scalar
    ndims_params = [0, 0, 0]  # mu, sigma, alpha are all scalar parameters
    dtype = "floatX"
    _print_name = ("SkewNormal", "\\operatorname{SkewNormal}")

    @classmethod
    def rng_fn(cls, rng, mu, sigma, alpha, size=None):
        """Draw skew-normal samples via scipy.stats.skewnorm.

        Parameters map onto scipy's parametrization as a=alpha (skewness),
        loc=mu, scale=sigma; `rng` is forwarded as scipy's random_state.
        """
        return stats.skewnorm.rvs(a=alpha, loc=mu, scale=sigma, size=size, random_state=rng)


# Singleton Op instance used as the distribution's `rv_op`.
skewnormal = SkewNormalRV()


class SkewNormal(Continuous):
r"""
Univariate skew-normal log-likelihood.
Expand Down Expand Up @@ -3160,51 +3175,25 @@ class SkewNormal(Continuous):
approaching plus/minus infinite we get a half-normal distribution.

"""
rv_op = skewnormal

def __init__(self, mu=0.0, sigma=None, tau=None, alpha=1, sd=None, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def dist(cls, alpha=1, mu=0.0, sigma=None, tau=None, sd=None, *args, **kwargs):
if sd is not None:
sigma = sd

tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.mu = mu = at.as_tensor_variable(floatX(mu))
self.tau = at.as_tensor_variable(tau)
self.sigma = self.sd = at.as_tensor_variable(sigma)

self.alpha = alpha = at.as_tensor_variable(floatX(alpha))

self.mean = mu + self.sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha ** 2) ** 0.5
self.variance = self.sigma ** 2 * (1 - (2 * alpha ** 2) / ((1 + alpha ** 2) * np.pi))
alpha = at.as_tensor_variable(floatX(alpha))
mu = at.as_tensor_variable(floatX(mu))
tau = at.as_tensor_variable(tau)
sigma = at.as_tensor_variable(sigma)

assert_negative_support(tau, "tau", "SkewNormal")
assert_negative_support(sigma, "sigma", "SkewNormal")

def random(self, point=None, size=None):
"""
Draw random values from SkewNormal distribution.

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).

Returns
-------
array
"""
# mu, tau, _, alpha = draw_values(
# [self.mu, self.tau, self.sigma, self.alpha], point=point, size=size
# )
# return generate_samples(
# stats.skewnorm.rvs, a=alpha, loc=mu, scale=tau ** -0.5, dist_shape=self.shape, size=size
# )
return super().dist([mu, sigma, alpha], *args, **kwargs)

def logp(self, value):
def logp(value, mu, sigma, alpha):
"""
Calculate log-probability of SkewNormal distribution at specified value.

Expand All @@ -3218,20 +3207,14 @@ def logp(self, value):
-------
TensorVariable
"""
tau = self.tau
sigma = self.sigma
mu = self.mu
alpha = self.alpha
tau, sigma = get_tau_sigma(sigma=sigma)
return bound(
at.log(1 + at.erf(((value - mu) * at.sqrt(tau) * alpha) / at.sqrt(2)))
+ (-tau * (value - mu) ** 2 + at.log(tau / np.pi / 2.0)) / 2.0,
tau > 0,
sigma > 0,
)

def _distr_parameters_for_repr(self):
return ["mu", "sigma", "alpha"]


class Triangular(BoundedContinuous):
r"""
Expand Down Expand Up @@ -3474,6 +3457,21 @@ def logcdf(
)


class RiceRV(RandomVariable):
    # RandomVariable Op backing the refactored Rice distribution
    # (assigned below as `Rice.rv_op`).
    name = "rice"
    ndim_supp = 0  # univariate: each draw is a scalar
    ndims_params = [0, 0]  # b (shape) and sigma (scale) are scalar parameters
    dtype = "floatX"
    _print_name = ("Rice", "\\operatorname{Rice}")

    @classmethod
    def rng_fn(cls, rng, b, sigma, size=None):
        """Draw Rice samples via scipy.stats.rice (b=shape, scale=sigma)."""
        return stats.rice.rvs(b=b, scale=sigma, size=size, random_state=rng)


# Singleton Op instance used as the distribution's `rv_op`.
rice = RiceRV()


class Rice(PositiveContinuous):
r"""
Rice distribution.
Expand Down Expand Up @@ -3533,42 +3531,21 @@ class Rice(PositiveContinuous):
b = \dfrac{\nu}{\sigma}

"""
rv_op = rice

def __init__(self, nu=None, sigma=None, b=None, sd=None, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def dist(cls, nu=None, sigma=None, b=None, sd=None, *args, **kwargs):
if sd is not None:
sigma = sd

nu, b, sigma = self.get_nu_b(nu, b, sigma)
self.nu = nu = at.as_tensor_variable(floatX(nu))
self.sigma = self.sd = sigma = at.as_tensor_variable(floatX(sigma))
self.b = b = at.as_tensor_variable(floatX(b))

nu_sigma_ratio = -(nu ** 2) / (2 * sigma ** 2)
self.mean = (
sigma
* np.sqrt(np.pi / 2)
* at.exp(nu_sigma_ratio / 2)
* (
(1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2)
- nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2)
)
)
self.variance = (
2 * sigma ** 2
+ nu ** 2
- (np.pi * sigma ** 2 / 2)
* (
at.exp(nu_sigma_ratio / 2)
* (
(1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2)
- nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2)
)
)
** 2
)
nu, b, sigma = cls.get_nu_b(nu, b, sigma)
b = at.as_tensor_variable(floatX(b))
sigma = at.as_tensor_variable(floatX(sigma))

def get_nu_b(self, nu, b, sigma):
return super().dist([b, sigma], *args, **kwargs)

@classmethod
def get_nu_b(cls, nu, b, sigma):
if sigma is None:
sigma = 1.0
if nu is None and b is not None:
Expand All @@ -3579,35 +3556,7 @@ def get_nu_b(self, nu, b, sigma):
return nu, b, sigma
raise ValueError("Rice distribution must specify either nu" " or b.")

def random(self, point=None, size=None):
"""
Draw random values from Rice distribution.

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).

Returns
-------
array
"""
# nu, sigma = draw_values([self.nu, self.sigma], point=point, size=size)
# return generate_samples(self._random, nu=nu, sigma=sigma, dist_shape=self.shape, size=size)

def _random(self, nu, sigma, size):
"""Wrapper around stats.rice.rvs that converts Rice's
parametrization to scipy.rice. All parameter arrays should have
been broadcasted properly by generate_samples at this point and size is
the scipy.rvs representation.
"""
return stats.rice.rvs(b=nu / sigma, scale=sigma, size=size)

def logp(self, value):
def logp(value, b, sigma):
"""
Calculate log-probability of Rice distribution at specified value.

Expand All @@ -3621,20 +3570,13 @@ def logp(self, value):
-------
TensorVariable
"""
nu = self.nu
sigma = self.sigma
b = self.b
x = value / sigma
return bound(
at.log(x * at.exp((-(x - b) * (x - b)) / 2) * i0e(x * b) / sigma),
sigma >= 0,
nu >= 0,
value > 0,
)

def _distr_parameters_for_repr(self):
return ["nu", "sigma"]


class Logistic(Continuous):
r"""
Expand Down
16 changes: 10 additions & 6 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,6 @@ def test_half_studentt(self):
lambda value, sigma: sp.halfcauchy.logpdf(value, 0, sigma),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_skew_normal(self):
self.check_logp(
SkewNormal,
Expand Down Expand Up @@ -2545,19 +2544,24 @@ def test_multidimensional_beta_construction(self):
with Model():
Beta("beta", alpha=1.0, beta=1.0, size=(10, 20))

@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Some combinations underflow to -inf in float32 in pymc version",
)
def test_rice(self):
    """Check Rice logp (b/sigma parametrization) against scipy.stats.rice.

    NOTE: the original span had the pre-refactor nu-parametrized arguments
    merged in alongside the b-parametrized ones, leaving `check_logp` with
    duplicate conflicting domain dicts and lambdas; only the b version is kept.
    """
    self.check_logp(
        Rice,
        Rplus,
        {"b": Rplus, "sigma": Rplusbig},
        lambda value, b, sigma: sp.rice.logpdf(value, b=b, loc=0, scale=sigma),
    )

def test_rice_nu(self):
    """Check Rice logp under the nu parametrization (b = nu / sigma).

    NOTE: the original span had the b-parametrized arguments merged in
    alongside the nu-parametrized ones, leaving `check_logp` with duplicate
    conflicting domain dicts and lambdas; only the nu version is kept.
    """
    self.check_logp(
        Rice,
        Rplus,
        {"nu": Rplus, "sigma": Rplusbig},
        lambda value, nu, sigma: sp.rice.logpdf(value, b=nu / sigma, loc=0, scale=sigma),
    )

def test_moyal_logp(self):
Expand Down
50 changes: 43 additions & 7 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,6 @@ class TestTruncatedNormalUpper(BaseTestCases.BaseTestCase):
params = {"mu": 0.0, "tau": 1.0, "upper": 0.5}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestSkewNormal(BaseTestCases.BaseTestCase):
distribution = pm.SkewNormal
params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestWald(BaseTestCases.BaseTestCase):
distribution = pm.Wald
Expand Down Expand Up @@ -514,6 +508,49 @@ def seeded_kumaraswamy_rng_fn(self):
]


class TestSkewNormal(BaseTestDistribution):
    """Config for BaseTestDistribution: pm.SkewNormal vs scipy.stats.skewnorm."""

    pymc_dist = pm.SkewNormal
    pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
    # sigma parametrization is forwarded to the rv_op unchanged.
    expected_rv_op_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
    # scipy parametrization: a=alpha, loc=mu, scale=sigma.
    reference_dist_params = {"loc": 0.0, "scale": 1.0, "a": 5.0}
    reference_dist = seeded_scipy_distribution_builder("skewnorm")
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]


class TestSkewNormalTau(BaseTestDistribution):
    """Check that a tau-parametrized SkewNormal reaches the rv_op as sigma."""

    pymc_dist = pm.SkewNormal
    # Convert tau to the equivalent sigma with the same helper the dist uses.
    tau, sigma = get_tau_sigma(tau=2.0)
    pymc_dist_params = {"mu": 0.0, "tau": tau, "alpha": 5.0}
    expected_rv_op_params = {"mu": 0.0, "sigma": sigma, "alpha": 5.0}
    tests_to_run = ["check_pymc_params_match_rv_op"]


class TestRice(BaseTestDistribution):
    """Config for BaseTestDistribution: pm.Rice vs scipy.stats.rice."""

    pymc_dist = pm.Rice
    b, sigma = 1, 2
    pymc_dist_params = {"b": b, "sigma": sigma}
    # b parametrization is forwarded to the rv_op unchanged.
    expected_rv_op_params = {"b": b, "sigma": sigma}
    # scipy parametrization: b=shape, scale=sigma.
    reference_dist_params = {"b": b, "scale": sigma}
    reference_dist = seeded_scipy_distribution_builder("rice")
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]


class TestRiceNu(BaseTestDistribution):
    """Check that a nu-parametrized Rice reaches the rv_op as b = nu / sigma."""

    pymc_dist = pm.Rice
    nu = sigma = 2
    pymc_dist_params = {"nu": nu, "sigma": sigma}
    expected_rv_op_params = {"b": nu / sigma, "sigma": sigma}
    tests_to_run = ["check_pymc_params_match_rv_op"]


class TestStudentTLam(BaseTestDistribution):
pymc_dist = pm.StudentT
lam, sigma = get_tau_sigma(tau=2.0)
Expand Down Expand Up @@ -1145,7 +1182,6 @@ def ref_rand(size, mu, sigma, upper):
pm.TruncatedNormal, {"mu": R, "sigma": Rplusbig, "upper": Rplusbig}, ref_rand=ref_rand
)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_skew_normal(self):
def ref_rand(size, alpha, mu, sigma):
return st.skewnorm.rvs(size=size, a=alpha, loc=mu, scale=sigma)
Expand Down