From bb752626e4ed0a6e3a89051bb449a1e086cf7a52 Mon Sep 17 00:00:00 2001 From: AustinRochford Date: Tue, 18 Oct 2016 08:23:54 -0400 Subject: [PATCH 1/2] Implement ZeroInflatedPoisson as a subclass of Mixture --- pymc3/distributions/mixture.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py index a8e6f2e317..904de42f77 100644 --- a/pymc3/distributions/mixture.py +++ b/pymc3/distributions/mixture.py @@ -6,6 +6,8 @@ from .distribution import Discrete, Distribution, draw_values, generate_samples from .continuous import get_tau_sd, Normal +__all__ = ['Mixture', 'NormalMixture'] + def all_discrete(comp_dists): """ @@ -167,3 +169,26 @@ def __init__(self, w, mu, *args, **kwargs): super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd), *args, **kwargs) + + +class Zero(pm.Discrete): + def __init__(self, *args, **kwargs): + super(Zero, self).__init__(*args, **kwargs) + + def logp(self, value): + return tt.switch(tt.eq(value, 0), 0., -np.inf) + + def random(self, point=None, size=None, repeat=None): + def _random(dtype=self.dtype, size=None): + return np.full(size, fill_value=0, dtype=dtype) + + return generate_samples(_random, dist_shape=self.shape, + size=size).astype(self.dtype) + + +class ZeroInflatedPoisson(pm.Mixture): + def __init__(self, theta, psi, *args, **kwargs): + w = tt.stack([psi, 1 - psi]) + comp_dists = [Zero.dist(), pm.Poisson.dist(theta)] + + super(ZeroInflatedPoisson, self).__init__(w, comp_dists, *args, **kwargs) From f2133b8aa19839e5199315eb61b16377babb5667 Mon Sep 17 00:00:00 2001 From: AustinRochford Date: Tue, 18 Oct 2016 08:42:54 -0400 Subject: [PATCH 2/2] Remove unnecessary pm.s --- pymc3/distributions/mixture.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py index 904de42f77..21b3a71aee 100644 --- a/pymc3/distributions/mixture.py +++ b/pymc3/distributions/mixture.py @@ -171,7 +171,7 @@ def __init__(self, w, mu, *args, **kwargs): *args, **kwargs) -class Zero(pm.Discrete): +class Zero(Discrete): def __init__(self, *args, **kwargs): super(Zero, self).__init__(*args, **kwargs) @@ -186,9 +186,9 @@ def _random(dtype=self.dtype, size=None): size=size).astype(self.dtype) -class ZeroInflatedPoisson(pm.Mixture): +class ZeroInflatedPoisson(Mixture): def __init__(self, theta, psi, *args, **kwargs): w = tt.stack([psi, 1 - psi]) - comp_dists = [Zero.dist(), pm.Poisson.dist(theta)] + comp_dists = [Zero.dist(), Poisson.dist(theta)] super(ZeroInflatedPoisson, self).__init__(w, comp_dists, *args, **kwargs)