Skip to content

Commit b521e64

Browse files
Implement BetaNegativeBinomial distribution (#258)
* Also broadcast parameters in Skellam when size is not provided Co-authored-by: Ricardo Vieira <[email protected]>
1 parent e8ce688 commit b521e64

File tree

4 files changed

+205
-3
lines changed

4 files changed

+205
-3
lines changed

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Distributions
3131
Chi
3232
DiscreteMarkovChain
3333
GeneralizedPoisson
34+
BetaNegativeBinomial
3435
GenExtreme
3536
R2D2M2CP
3637
Skellam

pymc_experimental/distributions/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,22 @@
1818
"""
1919

2020
from pymc_experimental.distributions.continuous import Chi, GenExtreme
21-
from pymc_experimental.distributions.discrete import GeneralizedPoisson, Skellam
21+
from pymc_experimental.distributions.discrete import (
22+
BetaNegativeBinomial,
23+
GeneralizedPoisson,
24+
Skellam,
25+
)
2226
from pymc_experimental.distributions.histogram_utils import histogram_approximation
2327
from pymc_experimental.distributions.multivariate import R2D2M2CP
2428
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
2529

2630
__all__ = [
31+
"BetaNegativeBinomial",
2732
"DiscreteMarkovChain",
2833
"GeneralizedPoisson",
2934
"GenExtreme",
3035
"R2D2M2CP",
36+
"Skellam",
3137
"histogram_approximation",
3238
"Chi",
3339
]

pymc_experimental/distributions/discrete.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import numpy as np
1616
import pymc as pm
17-
from pymc.distributions.dist_math import check_parameters, factln, logpow
17+
from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
1818
from pymc.distributions.shape_utils import rv_size_is_none
1919
from pytensor import tensor as pt
2020
from pytensor.tensor.random.op import RandomVariable
@@ -173,6 +173,125 @@ def logp(value, mu, lam):
173173
)
174174

175175

176+
class BetaNegativeBinomial:
    R"""
    Beta Negative Binomial distribution.

    The pmf of this distribution is

    .. math::

        f(x \mid \alpha, \beta, r) = \frac{B(r + x, \alpha + \beta)}{B(r, \alpha)} \frac{\Gamma(x + \beta)}{x! \Gamma(\beta)}

    where :math:`B` is the Beta function and :math:`\Gamma` is the Gamma function.

    For more information, see https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution.

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        from scipy.special import betaln, gammaln
        def factln(x):
            return gammaln(x + 1)
        def logp(x, alpha, beta, r):
            return (
                betaln(r + x, alpha + beta)
                - betaln(r, alpha)
                + gammaln(x + beta)
                - factln(x)
                - gammaln(beta)
            )
        plt.style.use('arviz-darkgrid')
        x = np.arange(0, 25)
        params = [
            (1, 1, 1),
            (1, 1, 10),
            (1, 10, 1),
            (1, 10, 10),
            (10, 10, 10),
        ]
        for alpha, beta, r in params:
            pmf = np.exp(logp(x, alpha, beta, r))
            plt.plot(x, pmf, "-o", label=r'$\alpha$ = {}, $\beta$ = {}, $r$ = {}'.format(alpha, beta, r))
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ========  ======================================
    Support   :math:`x \in \mathbb{N}_0`
    Mean      :math:`{\begin{cases}{\frac {r\beta }{\alpha -1}}&{\text{if}}\ \alpha >1\\\infty &{\text{otherwise}}\ \end{cases}}`
    Variance  :math:`{\displaystyle {\begin{cases}{\frac {r\beta (r+\alpha -1)(\beta +\alpha -1)}{(\alpha -2){(\alpha -1)}^{2}}}&{\text{if}}\ \alpha >2\\\infty &{\text{otherwise}}\ \end{cases}}}`
    ========  ======================================

    Parameters
    ----------
    alpha : tensor_like of float
        shape of the beta distribution (alpha > 0).
    beta : tensor_like of float
        shape of the beta distribution (beta > 0).
    r : tensor_like of float
        number of successes until the experiment is stopped
        (integer but can be extended to real, r > 0)
    """

    @staticmethod
    def beta_negative_binomial_dist(alpha, beta, r, size):
        """Build the random graph: ``p ~ Beta(alpha, beta)``, ``x ~ NegativeBinomial(n=r, p=p)``.

        This is the ``dist`` callable handed to ``pm.CustomDist``.
        """
        # Without an explicit size, broadcast the parameters against each
        # other first so both component distributions see one common shape.
        if rv_size_is_none(size):
            alpha, beta, r = pt.broadcast_arrays(alpha, beta, r)

        p = pm.Beta.dist(alpha, beta, size=size)
        # The leading positional parameters of NegativeBinomial.dist are
        # (mu, alpha); p and n have to be passed by keyword, otherwise the
        # draws come from the wrong parametrization.
        return pm.NegativeBinomial.dist(p=p, n=r, size=size)

    @staticmethod
    def beta_negative_binomial_logp(value, alpha, beta, r):
        """Log-pmf of the distribution evaluated at ``value``.

        Returns ``-inf`` where ``value < 0`` and wraps the result in
        ``check_parameters`` asserting ``alpha > 0``, ``beta > 0``, ``r > 0``.
        """
        res = (
            betaln(r + value, alpha + beta)
            - betaln(r, alpha)
            + pt.gammaln(value + beta)
            - factln(value)
            - pt.gammaln(beta)
        )
        # Negative values are outside the support.
        res = pt.switch(
            pt.lt(value, 0),
            -np.inf,
            res,
        )

        return check_parameters(
            res,
            alpha > 0,
            beta > 0,
            r > 0,
            msg="alpha > 0, beta > 0, r > 0",
        )

    def __new__(cls, name, alpha, beta, r, **kwargs):
        """Create the named variable through ``pm.CustomDist``."""
        return pm.CustomDist(
            name,
            alpha,
            beta,
            r,
            dist=cls.beta_negative_binomial_dist,
            logp=cls.beta_negative_binomial_logp,
            class_name="BetaNegativeBinomial",
            **kwargs,
        )

    @classmethod
    def dist(cls, alpha, beta, r, **kwargs):
        """Create the unnamed variable through ``pm.CustomDist.dist``."""
        return pm.CustomDist.dist(
            alpha,
            beta,
            r,
            dist=cls.beta_negative_binomial_dist,
            logp=cls.beta_negative_binomial_logp,
            class_name="BetaNegativeBinomial",
            **kwargs,
        )
293+
294+
176295
class Skellam:
177296
R"""
178297
Skellam distribution.
@@ -228,6 +347,9 @@ class Skellam:
228347

229348
@staticmethod
def skellam_dist(mu1, mu2, size):
    """Return the difference of two Poisson variables with rates ``mu1`` and ``mu2``."""
    # Without an explicit size, broadcast the two rates against each other
    # first so both Poisson terms are built from same-shaped parameters.
    if rv_size_is_none(size):
        mu1, mu2 = pt.broadcast_arrays(mu1, mu2)

    minuend = pm.Poisson.dist(mu=mu1, size=size)
    subtrahend = pm.Poisson.dist(mu=mu2, size=size)
    return minuend - subtrahend
232354

233355
@staticmethod

pymc_experimental/tests/distributions/test_discrete.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929
)
3030
from pytensor import config
3131

32-
from pymc_experimental.distributions import GeneralizedPoisson, Skellam
32+
from pymc_experimental.distributions import (
33+
BetaNegativeBinomial,
34+
GeneralizedPoisson,
35+
Skellam,
36+
)
3337

3438

3539
class TestGeneralizedPoisson:
@@ -122,6 +126,75 @@ def test_moment(self, mu, lam, size, expected):
122126
assert_moment_is_expected(model, expected)
123127

124128

129+
class TestBetaNegativeBinomial:
    """
    Wrapper class so that tests of experimental additions can be dropped into
    PyMC directly on adoption.
    """

    def test_logp(self):
        """

        Beta Negative Binomial logp function test values taken from R package as
        there is currently no implementation in scipy.
        https://github.com/scipy/scipy/issues/17330

        The test values can be generated in R with the following code
        (``dbnbinom`` takes ``size`` before ``alpha`` and ``beta``, so the
        arguments are passed by name):

        .. code-block:: r

            library(extraDistr)

            create.test.rows <- function(alpha, beta, r, x) {
                logp <- dbnbinom(x, size=r, alpha=alpha, beta=beta, log=TRUE)
                paste0("(", paste(alpha, beta, r, x, logp, sep=", "), ")")
            }

            x <- c(0, 1, 250, 5000)
            print(create.test.rows(1, 1, 1, x), quote=FALSE)
            print(create.test.rows(1, 1, 10, x), quote=FALSE)
            print(create.test.rows(1, 10, 1, x), quote=FALSE)
            print(create.test.rows(10, 1, 1, x), quote=FALSE)
            print(create.test.rows(10, 10, 10, x), quote=FALSE)

        """
        alpha, beta, r, value = pt.scalars("alpha", "beta", "r", "value")
        logp = pm.logp(BetaNegativeBinomial.dist(alpha, beta, r), value)
        logp_fn = pytensor.function([value, alpha, beta, r], logp)

        # Each row is (alpha, beta, r, value, expected_logp).
        tests = [
            # 1, 1, 1
            (1, 1, 1, 0, -0.693147180559945),
            (1, 1, 1, 1, -1.79175946922805),
            (1, 1, 1, 250, -11.0548820266432),
            (1, 1, 1, 5000, -17.0349862828565),
            # 1, 1, 10
            (1, 1, 10, 0, -2.39789527279837),
            (1, 1, 10, 1, -2.58021682959232),
            (1, 1, 10, 250, -8.82261694534392),
            (1, 1, 10, 5000, -14.7359968760473),
            # 1, 10, 1
            (1, 10, 1, 0, -2.39789527279837),
            (1, 10, 1, 1, -2.58021682959232),
            (1, 10, 1, 250, -8.82261694534418),
            (1, 10, 1, 5000, -14.7359968760446),
            # 10, 1, 1
            (10, 1, 1, 0, -0.0953101798043248),
            (10, 1, 1, 1, -2.58021682959232),
            (10, 1, 1, 250, -43.5891148758123),
            (10, 1, 1, 5000, -76.2953173311091),
            # 10, 10, 10
            (10, 10, 10, 0, -5.37909807285049),
            (10, 10, 10, 1, -4.17512526852455),
            (10, 10, 10, 250, -21.781591505836),
            (10, 10, 10, 5000, -53.4836799634603),
        ]
        for test_alpha, test_beta, test_r, test_value, expected_logp in tests:
            np.testing.assert_allclose(
                logp_fn(test_value, test_alpha, test_beta, test_r),
                expected_logp,
                # Identify the failing row; the loop otherwise hides it.
                err_msg=(
                    f"alpha={test_alpha}, beta={test_beta}, "
                    f"r={test_r}, value={test_value}"
                ),
            )
196+
197+
125198
class TestSkellam:
126199
def test_logp(self):
127200
check_logp(

0 commit comments

Comments
 (0)