WIP Mixture Models #1437

Merged · 17 commits · Oct 18, 2016

1 change: 1 addition & 0 deletions docs/source/examples.rst
@@ -42,6 +42,7 @@ Mixture Models

.. toctree::
notebooks/gaussian_mixture_model.ipynb
notebooks/marginalized_gaussian_mixture_model.ipynb
notebooks/gaussian-mixture-model-advi.ipynb
notebooks/dp_mix.ipynb

319 changes: 319 additions & 0 deletions docs/source/notebooks/marginalized_gaussian_mixture_model.ipynb

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pymc3/distributions/__init__.py
@@ -45,6 +45,9 @@
from .distribution import TensorType
from .distribution import draw_values

from .mixture import Mixture
from .mixture import NormalMixture

from .multivariate import MvNormal
from .multivariate import MvStudentT
from .multivariate import Dirichlet
@@ -112,5 +115,7 @@
'AR1',
'GaussianRandomWalk',
'GARCH11',
'SkewNormal'
'SkewNormal',
'Mixture',
'NormalMixture'
]
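
With Mixture and NormalMixture exported above, the intended top-level usage mirrors the tests later in this PR. A minimal sketch (the data array is a synthetic placeholder, and this assumes a pymc3 build that includes this branch):

import numpy as np
import pymc3 as pm

data = np.random.normal(0., 1., size=100)  # placeholder observations

with pm.Model():
    w = pm.Dirichlet('w', np.ones(2))           # mixture weights
    mu = pm.Normal('mu', 0., 10., shape=2)      # component means
    tau = pm.Gamma('tau', 1., 1., shape=2)      # component precisions
    x_obs = pm.NormalMixture('x_obs', w, mu, tau=tau, observed=data)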
169 changes: 169 additions & 0 deletions pymc3/distributions/mixture.py
@@ -0,0 +1,169 @@
import numpy as np
import theano.tensor as tt

from ..math import logsumexp
from .dist_math import bound
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .continuous import get_tau_sd, Normal


def all_discrete(comp_dists):
"""
Determine if all distributions in comp_dists are discrete
"""
if isinstance(comp_dists, Distribution):
return isinstance(comp_dists, Discrete)
else:
return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists)


class Mixture(Distribution):
R"""
Mixture log-likelihood

Often used to model subpopulation heterogeneity

.. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

======== ============================================
Support :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
======== ============================================

Parameters
----------
w : array of floats
w >= 0 and w <= 1
the mixture weights
comp_dists : multidimensional PyMC3 distribution or iterable of one-dimensional PyMC3 distributions
the component distributions :math:`f_1, \ldots, f_n`
"""
def __init__(self, w, comp_dists, *args, **kwargs):
shape = kwargs.pop('shape', ())

self.w = w
self.comp_dists = comp_dists

defaults = kwargs.pop('defaults', [])

if all_discrete(comp_dists):
dtype = kwargs.pop('dtype', 'int64')
else:
dtype = kwargs.pop('dtype', 'float64')

try:
self.mean = (w * self._comp_means()).sum(axis=-1)

if 'mean' not in defaults:
defaults.append('mean')
except AttributeError:
pass

try:
comp_modes = self._comp_modes()
comp_mode_logps = self.logp(comp_modes)
self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]

if 'mode' not in defaults:
defaults.append('mode')
except AttributeError:
pass

super(Mixture, self).__init__(shape, dtype, defaults=defaults,
*args, **kwargs)

def _comp_logp(self, value):
comp_dists = self.comp_dists

try:
value_ = value if value.ndim > 1 else tt.shape_padright(value)

return comp_dists.logp(value_)
except AttributeError:
return tt.stack([comp_dist.logp(value) for comp_dist in comp_dists],
axis=1)

def _comp_means(self):
try:
return self.comp_dists.mean
except AttributeError:
return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
axis=1)

def _comp_modes(self):
try:
return self.comp_dists.mode
except AttributeError:
return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
axis=1)

def _comp_samples(self, point=None, size=None, repeat=None):
try:
samples = self.comp_dists.random(point=point, size=size, repeat=repeat)
except AttributeError:
samples = np.column_stack([comp_dist.random(point=point, size=size, repeat=repeat)
for comp_dist in self.comp_dists])

return np.squeeze(samples)

def logp(self, value):
w = self.w

return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1).sum(),
w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1))

def random(self, point=None, size=None, repeat=None):
def random_choice(*args, **kwargs):
w = kwargs.pop('w')
w /= w.sum(axis=-1, keepdims=True)
k = w.shape[-1]

if w.ndim > 1:
return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
else:
return np.random.choice(k, p=w, *args, **kwargs)

w = draw_values([self.w], point=point)

w_samples = generate_samples(random_choice,
w=w,
broadcast_shape=w.shape[:-1] or (1,),
dist_shape=self.shape,
size=size).squeeze()
comp_samples = self._comp_samples(point=point, size=size, repeat=repeat)

if comp_samples.ndim > 1:
return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
else:
return np.squeeze(comp_samples[w_samples])


class NormalMixture(Mixture):
R"""
Normal mixture log-likelihood

.. math:: f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)

======== =======================================
Support :math:`x \in \mathbb{R}`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
Variance :math:`\sum_{i = 1}^n w_i^2 \sigma^2_i`
======== =======================================

Parameters
----------
w : array of floats
w >= 0 and w <= 1
the mixture weights
mu : array of floats
the component means
sd : array of floats
the component standard deviations
tau : array of floats
the component precisions
"""
def __init__(self, w, mu, *args, **kwargs):
_, sd = get_tau_sd(tau=kwargs.pop('tau', None),
sd=kwargs.pop('sd', None))

super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
[inline review comment — Member] The simplicity to create a Normal Mixture here is really validating, nicely done.

*args, **kwargs)
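
For intuition, Mixture.logp above marginalizes the component label in log space. A standalone NumPy/SciPy sketch of the same computation — scipy's logsumexp stands in for pymc3.math.logsumexp, and the weights, means, and observations are illustrative:

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

w = np.array([0.75, 0.25])      # mixture weights, sum to 1
mu = np.array([0., 5.])         # component means
sd = np.array([1., 1.])         # component standard deviations
x = np.array([-0.5, 1.0, 4.8])  # observations

# component log-densities, shape (n_obs, n_components)
comp_logp = norm.logpdf(x[:, None], mu, sd)

# log f(x) = logsumexp_i(log w_i + log f_i(x)), as in Mixture.logp
mix_logp = logsumexp(np.log(w) + comp_logp, axis=-1)
total_logp = mix_logp.sum()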
114 changes: 114 additions & 0 deletions pymc3/tests/test_mixture.py
@@ -0,0 +1,114 @@
import numpy as np
from numpy.testing import assert_allclose

from .helpers import SeededTest
from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample


# Generate data
def generate_normal_mixture_data(w, mu, sd, size=1000):
component = np.random.choice(w.size, size=size, p=w)
x = np.random.normal(mu[component], sd[component], size=size)

return x


def generate_poisson_mixture_data(w, mu, size=1000):
component = np.random.choice(w.size, size=size, p=w)
x = np.random.poisson(mu[component], size=size)

return x


class TestMixture(SeededTest):
@classmethod
def setUpClass(cls):
super(TestMixture, cls).setUpClass()

cls.norm_w = np.array([0.75, 0.25])
cls.norm_mu = np.array([0., 5.])
cls.norm_sd = np.ones_like(cls.norm_mu)
cls.norm_x = generate_normal_mixture_data(cls.norm_w, cls.norm_mu, cls.norm_sd, size=1000)

cls.pois_w = np.array([0.4, 0.6])
cls.pois_mu = np.array([5., 20.])
cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)

def test_mixture_list_of_normals(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.norm_w))

mu = Normal('mu', 0., 10., shape=self.norm_w.size)
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

x_obs = Mixture('x_obs', w,
[Normal.dist(mu[0], tau=tau[0]),
Normal.dist(mu[1], tau=tau[1])],
observed=self.norm_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.norm_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.norm_mu),
rtol=0.1, atol=0.1)

def test_normal_mixture(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.norm_w))

mu = Normal('mu', 0., 10., shape=self.norm_w.size)
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

x_obs = NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.norm_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.norm_mu),
rtol=0.1, atol=0.1)

def test_poisson_mixture(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.pois_w))

mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

x_obs = Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.pois_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.pois_mu),
rtol=0.1, atol=0.1)

def test_mixture_list_of_poissons(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.pois_w))

mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

x_obs = Mixture('x_obs', w,
[Poisson.dist(mu[0]), Poisson.dist(mu[1])],
observed=self.pois_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.pois_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.pois_mu),
rtol=0.1, atol=0.1)
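
Mixture.random in mixture.py above draws a sample by first picking a component index from w and then selecting that component's draw; the generate_*_mixture_data helpers at the top of this test file follow the same scheme. A plain NumPy restatement with illustrative values, mirroring _comp_samples plus the fancy indexing in Mixture.random:

import numpy as np

w = np.array([0.4, 0.6])   # mixture weights
mu = np.array([5., 20.])   # Poisson component means
size = 1000

# stage 1: a component index for each draw
idx = np.random.choice(w.size, size=size, p=w)

# stage 2: one draw from every component, then select by index
comp_samples = np.column_stack([np.random.poisson(m, size=size) for m in mu])
x = comp_samples[np.arange(size), idx]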