Skip to content

Refactor Flat and HalfFlat distributions #4723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 46 additions & 36 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,31 +308,36 @@ def logcdf(value, lower, upper):
)


class FlatRV(RandomVariable):
    """Random-variable op backing the ``Flat`` distribution.

    The op declares no distribution parameters (``ndims_params`` is empty)
    and cannot be sampled: ``rng_fn`` always raises ``NotImplementedError``.
    """

    name = "flat"
    ndim_supp = 0
    ndims_params = []
    dtype = "floatX"
    _print_name = ("Flat", "\\operatorname{Flat}")

    @classmethod
    def rng_fn(cls, rng, size):
        """Refuse to draw samples.

        Parameters
        ----------
        rng
            Unused.
        size
            Unused.

        Raises
        ------
        NotImplementedError
            Always.
        """
        message = "Cannot sample from flat variable"
        raise NotImplementedError(message)


# Module-level op instance; ``Flat`` assigns it to its ``rv_op`` attribute.
flat = FlatRV()


class Flat(Continuous):
"""
Uninformative log-likelihood that returns 0 regardless of
the passed value.
"""

def __init__(self, *args, **kwargs):
self._default = 0
super().__init__(defaults=("_default",), *args, **kwargs)

def random(self, point=None, size=None):
"""Raises ValueError as it is not possible to sample from Flat distribution

Parameters
----------
point: dict, optional
size: int, optional
rv_op = flat

Raises
-------
ValueError
"""
raise ValueError("Cannot sample from Flat distribution")
@classmethod
def dist(cls, *, size=None, testval=None, **kwargs):
if testval is None:
testval = np.full(size, floatX(0.0))
Comment on lines +336 to +337
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming these test values were necessary because, when aesara.config.compute_test_value is enabled, RandomVariable.perform is called in order to generate a test value automatically, and that call will fail here because rng_fn raises NotImplementedError.

In general, test value-related logic like this needs to be conditioned on aesara.config.compute_test_value != "off"; otherwise, we'll end up doing unnecessary work. This case might be different, though.

Copy link
Member Author

@ricardoV94 ricardoV94 May 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Flat testval is critical (in the codebase) in a couple of places (e.g., for getting the model.initial_point for sampling).

Copy link
Contributor

@brandonwillard brandonwillard May 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just keep in mind that v4 cannot actually require test values, and anywhere that currently does is a place we need to refactor.

Aside from that, it should always be possible to enable test values and have them work throughout v4; it's just not mandatory or automatically enabled at any point within PyMC.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely, I am adding a note about this in #4567.

return super().dist([], size=size, testval=testval, **kwargs)

def logp(self, value):
def logp(value):
"""
Calculate log-probability of Flat distribution at specified value.

Expand All @@ -348,7 +353,7 @@ def logp(self, value):
"""
return at.zeros_like(value)

def logcdf(self, value):
def logcdf(value):
"""
Compute the log of the cumulative distribution function for Flat distribution
at the specified value.
Expand All @@ -368,28 +373,33 @@ def logcdf(self, value):
)


class HalfFlat(PositiveContinuous):
"""Improper flat prior over the positive reals."""
class HalfFlatRV(RandomVariable):
name = "half_flat"
ndim_supp = 0
ndims_params = []
dtype = "floatX"
_print_name = ("HalfFlat", "\\operatorname{HalfFlat}")

def __init__(self, *args, **kwargs):
self._default = 1
super().__init__(defaults=("_default",), *args, **kwargs)
@classmethod
def rng_fn(cls, rng, size):
raise NotImplementedError("Cannot sample from half_flat variable")

def random(self, point=None, size=None):
"""Raises ValueError as it is not possible to sample from HalfFlat distribution

Parameters
----------
point: dict, optional
size: int, optional
halfflat = HalfFlatRV()

Raises
-------
ValueError
"""
raise ValueError("Cannot sample from HalfFlat distribution")

def logp(self, value):
class HalfFlat(PositiveContinuous):
"""Improper flat prior over the positive reals."""

rv_op = halfflat

@classmethod
def dist(cls, *, size=None, testval=None, **kwargs):
if testval is None:
testval = np.full(size, floatX(1.0))
return super().dist([], size=size, testval=testval, **kwargs)

def logp(value):
"""
Calculate log-probability of HalfFlat distribution at specified value.

Expand All @@ -405,7 +415,7 @@ def logp(self, value):
"""
return bound(at.zeros_like(value), value > 0)

def logcdf(self, value):
def logcdf(value):
"""
Compute the log of the cumulative distribution function for HalfFlat distribution
at the specified value.
Expand Down
37 changes: 16 additions & 21 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,18 +961,16 @@ def test_discrete_unif(self):
assert logpt(invalid_dist, 0.5).eval() == -np.inf
assert logcdf(invalid_dist, 2).eval() == -np.inf

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_flat(self):
self.check_logp(Flat, Runif, {}, lambda value: 0)
with Model():
x = Flat("a")
assert_allclose(x.tag.test_value, 0)
self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5))
# Check infinite cases individually.
assert 0.0 == logcdf(Flat.dist(), np.inf).tag.test_value
assert -np.inf == logcdf(Flat.dist(), -np.inf).tag.test_value
assert 0.0 == logcdf(Flat.dist(), np.inf).eval()
assert -np.inf == logcdf(Flat.dist(), -np.inf).eval()

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_half_flat(self):
self.check_logp(HalfFlat, Rplus, {}, lambda value: 0)
with Model():
Expand All @@ -981,8 +979,8 @@ def test_half_flat(self):
assert x.tag.test_value.shape == (2,)
self.check_logcdf(HalfFlat, Rplus, {}, lambda value: -np.inf)
# Check infinite cases individually.
assert 0.0 == logcdf(HalfFlat.dist(), np.inf).tag.test_value
assert -np.inf == logcdf(HalfFlat.dist(), -np.inf).tag.test_value
assert 0.0 == logcdf(HalfFlat.dist(), np.inf).eval()
assert -np.inf == logcdf(HalfFlat.dist(), -np.inf).eval()

def test_normal(self):
self.check_logp(
Expand Down Expand Up @@ -2499,17 +2497,19 @@ def test_vonmises(self):
lambda value, mu, kappa: floatX(sp.vonmises.logpdf(value, kappa, loc=mu)),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_gumbel(self):
def gumbel(value, mu, beta):
return floatX(sp.gumbel_r.logpdf(value, loc=mu, scale=beta))

self.check_logp(Gumbel, R, {"mu": R, "beta": Rplusbig}, gumbel)

def gumbellcdf(value, mu, beta):
return floatX(sp.gumbel_r.logcdf(value, loc=mu, scale=beta))

self.check_logcdf(Gumbel, R, {"mu": R, "beta": Rplusbig}, gumbellcdf)
self.check_logp(
Gumbel,
R,
{"mu": R, "beta": Rplusbig},
lambda value, mu, beta: sp.gumbel_r.logpdf(value, loc=mu, scale=beta),
)
self.check_logcdf(
Gumbel,
R,
{"mu": R, "beta": Rplusbig},
lambda value, mu, beta: sp.gumbel_r.logcdf(value, loc=mu, scale=beta),
)

def test_logistic(self):
self.check_logp(
Expand Down Expand Up @@ -2538,11 +2538,6 @@ def test_logitnormal(self):
decimal=select_by_precision(float64=6, float32=1),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_multidimensional_beta_construction(self):
with Model():
Beta("beta", alpha=1.0, beta=1.0, size=(10, 20))

@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Some combinations underflow to -inf in float32 in pymc version",
Expand Down
50 changes: 35 additions & 15 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,11 @@ def check_rv_size(self):
assert actual == expected, f"size={size}, expected={expected}, actual={actual}"

# test multi-parameters sampling for univariate distributions (with univariate inputs)
if self.pymc_dist.rv_op.ndim_supp == 0 and sum(self.pymc_dist.rv_op.ndims_params) == 0:
if (
self.pymc_dist.rv_op.ndim_supp == 0
and self.pymc_dist.rv_op.ndims_params
and sum(self.pymc_dist.rv_op.ndims_params) == 0
):
params = {
k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items()
}
Expand Down Expand Up @@ -394,6 +398,36 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
)


class TestFlat(BaseTestDistribution):
    """Checks for ``pm.Flat`` run through the ``BaseTestDistribution`` harness."""

    pymc_dist = pm.Flat
    # ``Flat`` is constructed without parameters, so both mappings are empty.
    pymc_dist_params = {}
    expected_rv_op_params = {}
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_rv_size",
        "check_not_implemented",
    ]

    def check_not_implemented(self):
        """Evaluating the random variable must raise ``NotImplementedError``."""
        with pytest.raises(NotImplementedError):
            evaluate = self.pymc_rv.eval
            evaluate()


class TestHalfFlat(BaseTestDistribution):
    """Checks for ``pm.HalfFlat`` run through the ``BaseTestDistribution`` harness."""

    pymc_dist = pm.HalfFlat
    # ``HalfFlat`` is constructed without parameters, so both mappings are empty.
    pymc_dist_params = {}
    expected_rv_op_params = {}
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_rv_size",
        "check_not_implemented",
    ]

    def check_not_implemented(self):
        """Evaluating the random variable must raise ``NotImplementedError``."""
        with pytest.raises(NotImplementedError):
            evaluate = self.pymc_rv.eval
            evaluate()


class TestDiscreteWeibull(BaseTestDistribution):
def discrete_weibul_rng_fn(self, size, q, beta, uniform_rng_fct):
return np.ceil(np.power(np.log(1 - uniform_rng_fct(size=size)) / np.log(q), 1.0 / beta)) - 1
Expand Down Expand Up @@ -1240,20 +1274,6 @@ def ref_rand(size, mu, sigma, nu):

pymc3_random(pm.ExGaussian, {"mu": R, "sigma": Rplus, "nu": Rplus}, ref_rand=ref_rand)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_flat(self):
with pm.Model():
f = pm.Flat("f")
with pytest.raises(ValueError):
f.random(1)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_half_flat(self):
with pm.Model():
f = pm.HalfFlat("f")
with pytest.raises(ValueError):
f.random(1)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_matrix_normal(self):
def ref_rand(size, mu, rowcov, colcov):
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_distributions_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pymc3.tests.helpers import select_by_precision

# pytestmark = pytest.mark.usefixtures("seeded_test")
pytestmark = pytest.mark.xfail(reason="This test relies on the deprecated Distribution interface")
pytestmark = pytest.mark.xfail(reason="Timeseries not refactored")


def test_AR():
Expand Down
3 changes: 1 addition & 2 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,6 @@ def test_sum_normal(self):
_, pval = stats.kstest(ppc["b"], stats.norm(scale=scale).cdf)
assert pval > 0.001

@pytest.mark.xfail(reason="HalfFlat not refactored for v4")
def test_model_not_drawable_prior(self):
data = np.random.poisson(lam=10, size=200)
model = pm.Model()
Expand All @@ -613,7 +612,7 @@ def test_model_not_drawable_prior(self):
trace = pm.sample(tune=1000)

with model:
with pytest.raises(ValueError) as excinfo:
with pytest.raises(NotImplementedError) as excinfo:
pm.sample_prior_predictive(50)
assert "Cannot sample" in str(excinfo.value)
samples = pm.sample_posterior_predictive(trace, 40)
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def test_step_categorical(self):
trace = sample(8000, tune=0, step=step, start=start, model=model, random_seed=1)
self.check_stat(check, trace, step.__class__.__name__)

@pytest.mark.xfail(reason="Flat not refactored for v4")
@pytest.mark.xfail(reason="EllipticalSlice not refactored for v4")
def test_step_elliptical_slice(self):
start, model, (K, L, mu, std, noise) = mv_prior_simple()
unc = noise ** 0.5
Expand Down