Skip to content

Adding check to prevent mixing of discrete and continuous distributions #5629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 4, 2022
33 changes: 15 additions & 18 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,16 @@
from pymc.distributions import transforms
from pymc.distributions.continuous import Normal, get_tau_sigma
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
Discrete,
Distribution,
SymbolicDistribution,
_moment,
moment,
)
from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
from pymc.distributions.logprob import logcdf, logp
from pymc.distributions.shape_utils import to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.util import check_dist_not_registered
from pymc.vartypes import discrete_types
from pymc.vartypes import continuous_types, discrete_types

__all__ = ["Mixture", "NormalMixture"]


def all_discrete(comp_dists):
    """
    Report whether the given component distribution(s) are all discrete.

    ``comp_dists`` is either a single ``Distribution`` instance or an
    iterable of components. A single distribution is checked directly;
    otherwise every element of the iterable must be a ``Discrete`` instance
    for the result to be ``True``.
    """
    if not isinstance(comp_dists, Distribution):
        # Iterable of components: each one has to be a Discrete instance.
        return all(isinstance(component, Discrete) for component in comp_dists)
    return isinstance(comp_dists, Discrete)


class MarginalMixtureRV(OpFromGraph):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

Expand Down Expand Up @@ -182,6 +166,19 @@ def dist(cls, w, comp_dists, **kwargs):
UserWarning,
)

if len(comp_dists) > 1:
    # Components may not mix value types: either every component dtype is
    # continuous or every component dtype is discrete.
    all_continuous = all(comp_dist.dtype in continuous_types for comp_dist in comp_dists)
    all_discrete = all(comp_dist.dtype in discrete_types for comp_dist in comp_dists)

    # Test the flags directly so comp_dists is scanned once per type family.
    if not (all_continuous or all_discrete):
        raise ValueError(
            "All distributions in comp_dists must be either discrete or continuous.\n"
            "See the following issue for more information: https://github.com/pymc-devs/pymc/issues/4511."
        )

# Check that components are not associated with a registered variable in the model
components_ndim_supp = set()
for dist in comp_dists:
Expand Down
15 changes: 15 additions & 0 deletions pymc/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,21 @@ def test_broadcast_components(self, comp_dists, expected_shape):
assert isinstance(comp_dist.owner.op, RandomVariable)
assert tuple(comp_dist.shape.eval()) == expected_shape

def test_preventing_mixing_cont_and_discrete(self):
    """Building a Mixture from a Categorical and a Normal component must raise ValueError."""
    expected_message = "All distributions in comp_dists must be either discrete or continuous."
    with pytest.raises(ValueError, match=expected_message):
        with Model():
            # Components are created inside the model context, as in a normal model definition.
            mixed_components = [
                Categorical.dist(np.tile(1 / 3, 3)),
                Normal.dist(np.ones(3), 3),
            ]
            Mixture("x", w=[0.5, 0.3, 0.2], comp_dists=mixed_components)


class TestNormalMixture(SeededTest):
def test_normal_mixture_sampling(self):
Expand Down