diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 16e3561608..0b93a0c579 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -29,32 +29,16 @@ from pymc.distributions import transforms from pymc.distributions.continuous import Normal, get_tau_sigma from pymc.distributions.dist_math import check_parameters -from pymc.distributions.distribution import ( - Discrete, - Distribution, - SymbolicDistribution, - _moment, - moment, -) +from pymc.distributions.distribution import SymbolicDistribution, _moment, moment from pymc.distributions.logprob import logcdf, logp from pymc.distributions.shape_utils import to_tuple from pymc.distributions.transforms import _default_transform from pymc.util import check_dist_not_registered -from pymc.vartypes import discrete_types +from pymc.vartypes import continuous_types, discrete_types __all__ = ["Mixture", "NormalMixture"] -def all_discrete(comp_dists): - """ - Determine if all distributions in comp_dists are discrete - """ - if isinstance(comp_dists, Distribution): - return isinstance(comp_dists, Discrete) - else: - return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists) - - class MarginalMixtureRV(OpFromGraph): """A placeholder used to specify a log-likelihood for a mixture sub-graph.""" @@ -182,6 +166,19 @@ def dist(cls, w, comp_dists, **kwargs): UserWarning, ) + if len(comp_dists) > 1: + all_continuous = all(comp_dist.dtype in continuous_types for comp_dist in comp_dists) + all_discrete = all(comp_dist.dtype in discrete_types for comp_dist in comp_dists) + + if not ( + all(comp_dist.dtype in continuous_types for comp_dist in comp_dists) + or all(comp_dist.dtype in discrete_types for comp_dist in comp_dists) + ): + raise ValueError( + "All distributions in comp_dists must be either discrete or continuous.\n" + "See the following issue for more information: https://github.com/pymc-devs/pymc/issues/4511." + ) + # Check that components are not associated with a registered variable in the model components_ndim_supp = set() for dist in comp_dists: diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py index 13898e6511..2cab2151ae 100644 --- a/pymc/tests/test_mixture.py +++ b/pymc/tests/test_mixture.py @@ -721,6 +721,21 @@ def test_broadcast_components(self, comp_dists, expected_shape): assert isinstance(comp_dist.owner.op, RandomVariable) assert tuple(comp_dist.shape.eval()) == expected_shape + def test_preventing_mixing_cont_and_discrete(self): + with pytest.raises( + ValueError, + match="All distributions in comp_dists must be either discrete or continuous.", + ): + with Model() as model: + mix = Mixture( + "x", + w=[0.5, 0.3, 0.2], + comp_dists=[ + Categorical.dist(np.tile(1 / 3, 3)), + Normal.dist(np.ones(3), 3), + ], + ) + class TestNormalMixture(SeededTest): def test_normal_mixture_sampling(self):