WIP Mixture Models #1437
Merged
17 commits
2a01506  First pass at mixture modelling  (AustinRochford)
8cb4bd5  No longer necessary to reference self.comp_dists directly in logp  (AustinRochford)
6fe4efa  Add dimension internally (when necessary)  (AustinRochford)
b60a643  Import get_tau_sd  (AustinRochford)
af63fde  Misc bugfixes  (AustinRochford)
aa23b90  Add sampling to Mixtures  (AustinRochford)
fb34ceb  Differentiate between Discrete and Continuous mixtures when possible  (AustinRochford)
c30e358  Add support for 2D weights  (AustinRochford)
4dfd130  Gracefully try to calculate mean and mode defaults  (AustinRochford)
71bae8b  Add docstrings for Mixture classes  (AustinRochford)
a4e722b  Export mixture models  (AustinRochford)
0785acd  Reference self.comp_dists  (AustinRochford)
beedb34  Remove unnecessary pm.  (AustinRochford)
4222acd  Add Mixture tests  (AustinRochford)
1db7c25  Add missing imports  (AustinRochford)
2cf8121  Add marginalized Gaussian mixture model example  (AustinRochford)
ef4a817  Calculate the mode of the mixture distribution correctly  (AustinRochford)
319 changes: 319 additions & 0 deletions
319
docs/source/notebooks/marginalized_gaussian_mixture_model.ipynb
Large diffs are not rendered by default.
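The notebook diff itself is not rendered in this view. As a rough sketch only (an assumption about its shape, not the notebook's actual code), a marginalized Gaussian mixture model built with the NormalMixture class added in this PR could look like this:

import numpy as np
import pymc3 as pm

# Hypothetical sketch, not the notebook's code: simulate data from a
# two-component Gaussian mixture, then fit it with the new NormalMixture.
np.random.seed(42)
true_w = np.array([0.75, 0.25])
true_mu = np.array([0., 5.])
component = np.random.choice(true_w.size, size=1000, p=true_w)
x = np.random.normal(true_mu[component], 1., size=1000)

with pm.Model() as model:
    w = pm.Dirichlet('w', np.ones_like(true_w))        # mixture weights
    mu = pm.Normal('mu', 0., 10., shape=true_w.size)   # component means
    tau = pm.Gamma('tau', 1., 1., shape=true_w.size)   # component precisions

    # The discrete component labels are marginalized out inside
    # NormalMixture's logp, so no latent indicator variables are needed.
    x_obs = pm.NormalMixture('x_obs', w, mu, tau=tau, observed=x)

    step = pm.Metropolis()
    trace = pm.sample(5000, step)

The mixture tests later in this diff fit essentially this model, so the sketch should be close to what the example notebook demonstrates.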
@@ -0,0 +1,169 @@
import numpy as np
import theano.tensor as tt

from ..math import logsumexp
from .dist_math import bound
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .continuous import get_tau_sd, Normal


def all_discrete(comp_dists):
    """
    Determine if all distributions in comp_dists are discrete
    """
    if isinstance(comp_dists, Distribution):
        return isinstance(comp_dists, Discrete)
    else:
        return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists)


class Mixture(Distribution):
    R"""
    Mixture log-likelihood

    Often used to model subpopulation heterogeneity

    .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

    ========  ============================================
    Support   :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    ========  ============================================

    Parameters
    ----------
    w : array of floats
        w >= 0 and w <= 1
        the mixture weights
    comp_dists : multidimensional PyMC3 distribution or iterable of one-dimensional PyMC3 distributions
        the component distributions :math:`f_1, \ldots, f_n`
    """
    def __init__(self, w, comp_dists, *args, **kwargs):
        shape = kwargs.pop('shape', ())

        self.w = w
        self.comp_dists = comp_dists

        defaults = kwargs.pop('defaults', [])

        if all_discrete(comp_dists):
            dtype = kwargs.pop('dtype', 'int64')
        else:
            dtype = kwargs.pop('dtype', 'float64')

        try:
            self.mean = (w * self._comp_means()).sum(axis=-1)

            if 'mean' not in defaults:
                defaults.append('mean')
        except AttributeError:
            pass

        try:
            comp_modes = self._comp_modes()
            comp_mode_logps = self.logp(comp_modes)
            self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]

            if 'mode' not in defaults:
                defaults.append('mode')
        except AttributeError:
            pass

        super(Mixture, self).__init__(shape, dtype, defaults=defaults,
                                      *args, **kwargs)

    def _comp_logp(self, value):
        comp_dists = self.comp_dists

        try:
            value_ = value if value.ndim > 1 else tt.shape_padright(value)

            return comp_dists.logp(value_)
        except AttributeError:
            return tt.stack([comp_dist.logp(value) for comp_dist in comp_dists],
                            axis=1)

    def _comp_means(self):
        try:
            return self.comp_dists.mean
        except AttributeError:
            return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
                            axis=1)

    def _comp_modes(self):
        try:
            return self.comp_dists.mode
        except AttributeError:
            return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
                            axis=1)

    def _comp_samples(self, point=None, size=None, repeat=None):
        try:
            samples = self.comp_dists.random(point=point, size=size, repeat=repeat)
        except AttributeError:
            samples = np.column_stack([comp_dist.random(point=point, size=size, repeat=repeat)
                                       for comp_dist in self.comp_dists])

        return np.squeeze(samples)

    def logp(self, value):
        w = self.w

        return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1).sum(),
                     w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1))

    def random(self, point=None, size=None, repeat=None):
        def random_choice(*args, **kwargs):
            w = kwargs.pop('w')
            w /= w.sum(axis=-1, keepdims=True)
            k = w.shape[-1]

            if w.ndim > 1:
                return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
            else:
                return np.random.choice(k, p=w, *args, **kwargs)

        w = draw_values([self.w], point=point)

        w_samples = generate_samples(random_choice,
                                     w=w,
                                     broadcast_shape=w.shape[:-1] or (1,),
                                     dist_shape=self.shape,
                                     size=size).squeeze()
        comp_samples = self._comp_samples(point=point, size=size, repeat=repeat)

        if comp_samples.ndim > 1:
            return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
        else:
            return np.squeeze(comp_samples[w_samples])


class NormalMixture(Mixture):
    R"""
    Normal mixture log-likelihood

    .. math:: f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)

    ========  =======================================
    Support   :math:`x \in \mathbb{R}`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    Variance  :math:`\sum_{i = 1}^n w_i^2 \sigma^2_i`
    ========  =======================================

    Parameters
    ----------
    w : array of floats
        w >= 0 and w <= 1
        the mixture weights
    mu : array of floats
        the component means
    sd : array of floats
        the component standard deviations
    tau : array of floats
        the component precisions
    """
    def __init__(self, w, mu, *args, **kwargs):
        _, sd = get_tau_sd(tau=kwargs.pop('tau', None),
                           sd=kwargs.pop('sd', None))

        super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
                                            *args, **kwargs)
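For intuition about the two central methods above: Mixture.logp marginalizes the component labels with a numerically stable log-sum-exp over the component log densities, and Mixture.random first draws a component index per sample according to w and then keeps that component's draw. A small NumPy/SciPy sketch of both steps (illustrative only, not code from this diff):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Two-component normal mixture used for illustration.
w = np.array([0.75, 0.25])      # mixture weights (non-negative, sum to 1)
mu = np.array([0., 5.])         # component means
sd = np.array([1., 1.])         # component standard deviations

# 1. Mixture.logp: log f(x) = logsumexp_i(log w_i + log f_i(x)),
#    mirroring logsumexp(tt.log(w) + self._comp_logp(value), axis=-1).
x = np.array([-0.3, 4.8, 2.1])
comp_logp = norm.logpdf(x[:, None], mu, sd)           # shape (len(x), n_components)
mix_logp = logsumexp(np.log(w) + comp_logp, axis=-1)
print(mix_logp.sum())                                 # total log-likelihood of x

# 2. Mixture.random: choose a component per draw according to w, then keep
#    that component's sample, as in comp_samples[np.arange(...), w_samples].
rng = np.random.default_rng(0)
n = 10
component = rng.choice(w.size, size=n, p=w)
comp_samples = rng.normal(mu, sd, size=(n, w.size))
samples = comp_samples[np.arange(n), component]
print(samples)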
@@ -0,0 +1,114 @@
import numpy as np
from numpy.testing import assert_allclose

from .helpers import SeededTest
from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample


# Generate data
def generate_normal_mixture_data(w, mu, sd, size=1000):
    component = np.random.choice(w.size, size=size, p=w)
    x = np.random.normal(mu[component], sd[component], size=size)

    return x


def generate_poisson_mixture_data(w, mu, size=1000):
    component = np.random.choice(w.size, size=size, p=w)
    x = np.random.poisson(mu[component], size=size)

    return x


class TestMixture(SeededTest):
    @classmethod
    def setUpClass(cls):
        super(TestMixture, cls).setUpClass()

        cls.norm_w = np.array([0.75, 0.25])
        cls.norm_mu = np.array([0., 5.])
        cls.norm_sd = np.ones_like(cls.norm_mu)
        cls.norm_x = generate_normal_mixture_data(cls.norm_w, cls.norm_mu, cls.norm_sd, size=1000)

        cls.pois_w = np.array([0.4, 0.6])
        cls.pois_mu = np.array([5., 20.])
        cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)

    def test_mixture_list_of_normals(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.norm_w))

            mu = Normal('mu', 0., 10., shape=self.norm_w.size)
            tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

            x_obs = Mixture('x_obs', w,
                            [Normal.dist(mu[0], tau=tau[0]),
                             Normal.dist(mu[1], tau=tau[1])],
                            observed=self.norm_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.norm_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.norm_mu),
                        rtol=0.1, atol=0.1)

    def test_normal_mixture(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.norm_w))

            mu = Normal('mu', 0., 10., shape=self.norm_w.size)
            tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

            x_obs = NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.norm_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.norm_mu),
                        rtol=0.1, atol=0.1)

    def test_poisson_mixture(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.pois_w))

            mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

            x_obs = Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.pois_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.pois_mu),
                        rtol=0.1, atol=0.1)

    def test_mixture_list_of_poissons(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.pois_w))

            mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

            x_obs = Mixture('x_obs', w,
                            [Poisson.dist(mu[0]), Poisson.dist(mu[1])],
                            observed=self.pois_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.pois_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.pois_mu),
                        rtol=0.1, atol=0.1)
The simplicity of creating a NormalMixture here is really validating. Nicely done.