Skip to content

Commit 7b08fc1

Browse files
authored
Add icdf functions for Lognormal, Half Cauchy and Half Normal distributions (#6766)
1 parent 14e673f commit 7b08fc1

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

pymc/distributions/continuous.py

+20
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from pytensor.tensor.var import TensorConstant
5858

5959
from pymc.logprob.abstract import _logcdf_helper, _logprob_helper
60+
from pymc.logprob.basic import icdf
6061

6162
try:
6263
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -856,6 +857,11 @@ def logcdf(value, loc, sigma):
856857
msg="sigma > 0",
857858
)
858859

860+
def icdf(value, loc, sigma):
861+
res = icdf(Normal.dist(loc, sigma), (value + 1.0) / 2.0)
862+
res = check_icdf_value(res, value)
863+
return res
864+
859865

860866
class WaldRV(RandomVariable):
861867
name = "wald"
@@ -1714,12 +1720,17 @@ def logcdf(value, mu, sigma):
17141720
-np.inf,
17151721
normal_lcdf(mu, sigma, pt.log(value)),
17161722
)
1723+
17171724
return check_parameters(
17181725
res,
17191726
sigma > 0,
17201727
msg="sigma > 0",
17211728
)
17221729

1730+
def icdf(value, mu, sigma):
1731+
res = pt.exp(icdf(Normal.dist(mu, sigma), value))
1732+
return res
1733+
17231734

17241735
Lognormal = LogNormal
17251736

@@ -2121,6 +2132,15 @@ def logcdf(value, loc, beta):
21212132
msg="beta > 0",
21222133
)
21232134

2135+
def icdf(value, loc, beta):
2136+
res = loc + beta * pt.tan(np.pi * (value) / 2.0)
2137+
res = check_icdf_value(res, value)
2138+
return check_parameters(
2139+
res,
2140+
beta > 0,
2141+
msg="beta > 0",
2142+
)
2143+
21242144

21252145
class Gamma(PositiveContinuous):
21262146
r"""

tests/distributions/test_continuous.py

+23
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,11 @@ def test_half_normal(self):
299299
{"sigma": Rplus},
300300
lambda value, sigma: st.halfnorm.logcdf(value, scale=sigma),
301301
)
302+
check_icdf(
303+
pm.HalfNormal,
304+
{"sigma": Rplus},
305+
lambda q, sigma: st.halfnorm.ppf(q, scale=sigma),
306+
)
302307

303308
def test_chisquared_logp(self):
304309
check_logp(
@@ -502,6 +507,21 @@ def test_lognormal(self):
502507
{"mu": R, "sigma": Rplusbig},
503508
lambda value, mu, sigma: st.lognorm.logcdf(value, sigma, 0, np.exp(mu)),
504509
)
510+
check_icdf(
511+
pm.LogNormal,
512+
{"mu": R, "tau": Rplusbig},
513+
lambda q, mu, tau: floatX(st.lognorm.ppf(q, tau**-0.5, 0, np.exp(mu))),
514+
)
515+
# Because we exponentiate the normal quantile function, setting sigma >= 9.5
516+
# return extreme values that results in relative errors above 4 digits
517+
# we circumvent it by keeping it below or equal to 9.
518+
custom_rplusbig = Domain([0, 0.5, 0.9, 0.99, 1, 1.5, 2, 9, np.inf])
519+
check_icdf(
520+
pm.LogNormal,
521+
{"mu": R, "sigma": custom_rplusbig},
522+
lambda q, mu, sigma: floatX(st.lognorm.ppf(q, sigma, 0, np.exp(mu))),
523+
decimal=select_by_precision(float64=4, float32=3),
524+
)
505525

506526
def test_studentt_logp(self):
507527
check_logp(
@@ -567,6 +587,9 @@ def test_half_cauchy(self):
567587
{"beta": Rplusbig},
568588
lambda value, beta: st.halfcauchy.logcdf(value, scale=beta),
569589
)
590+
check_icdf(
591+
pm.HalfCauchy, {"beta": Rplusbig}, lambda q, beta: st.halfcauchy.ppf(q, scale=beta)
592+
)
570593

571594
def test_gamma_logp(self):
572595
check_logp(

0 commit comments

Comments
 (0)