Skip to content

Commit 986738f

Browse files
Luke LBricardoV94
Luke LB
authored andcommitted
added normal logcdf func and new test domains
1 parent 72534c7 commit 986738f

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
8585
check_parameters,
8686
clipped_beta_rvs,
8787
i0e,
88+
log_diff_normal_cdf,
8889
log_normal,
8990
logpow,
9091
normal_lccdf,
@@ -743,6 +744,31 @@ def logp(value, mu, sigma, lower, upper):
743744

744745
return logp
745746

747+
def logcdf(value, mu, sigma, lower, upper):
748+
logcdf = log_diff_normal_cdf(mu, sigma, value, lower) - log_diff_normal_cdf(
749+
mu, sigma, upper, lower
750+
)
751+
752+
is_lower_bounded = not (
753+
isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value))
754+
)
755+
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
756+
757+
if is_lower_bounded:
758+
logcdf = pt.switch(value < lower, -np.inf, logcdf)
759+
760+
if is_upper_bounded:
761+
logcdf = pt.switch(value <= upper, logcdf, 0.0)
762+
763+
if is_lower_bounded and is_upper_bounded:
764+
logcdf = check_parameters(
765+
logcdf,
766+
pt.le(lower, upper),
767+
msg="lower_bound <= upper_bound",
768+
)
769+
770+
return logcdf
771+
746772

747773
@_default_transform.register(TruncatedNormal)
748774
def truncated_normal_default_transform(op, rv):

tests/distributions/test_continuous.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Circ,
3838
Domain,
3939
R,
40+
Rminusbig,
4041
Rplus,
4142
Rplusbig,
4243
Rplusunif,
@@ -934,6 +935,11 @@ def scipy_logp(value, mu, sigma, lower, upper):
934935
value, (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma
935936
)
936937

938+
def scipy_logcdf(value, mu, sigma, lower, upper):
939+
return st.truncnorm.logcdf(
940+
value, (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma
941+
)
942+
937943
check_logp(
938944
pm.TruncatedNormal,
939945
R,
@@ -961,6 +967,33 @@ def scipy_logp(value, mu, sigma, lower, upper):
961967
skip_paramdomain_outside_edge_test=True,
962968
)
963969

970+
check_logcdf(
971+
pm.TruncatedNormal,
972+
R,
973+
{"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig},
974+
scipy_logcdf,
975+
decimal=select_by_precision(float64=6, float32=1),
976+
skip_paramdomain_outside_edge_test=True,
977+
)
978+
979+
check_logcdf(
980+
pm.TruncatedNormal,
981+
R,
982+
{"mu": R, "sigma": Rplusbig, "upper": Rplusbig},
983+
ft.partial(scipy_logcdf, lower=-np.inf),
984+
decimal=select_by_precision(float64=6, float32=1),
985+
skip_paramdomain_outside_edge_test=True,
986+
)
987+
988+
check_logcdf(
989+
pm.TruncatedNormal,
990+
R,
991+
{"mu": R, "sigma": Rplusbig, "lower": -Rplusbig},
992+
ft.partial(scipy_logcdf, upper=np.inf),
993+
decimal=select_by_precision(float64=6, float32=1),
994+
skip_paramdomain_outside_edge_test=True,
995+
)
996+
964997
# This is a regression test for #6128: Check that having one out-of-bound value
965998
# in an input array does not set all logp values to -inf
966999
dist = pm.TruncatedNormal.dist(mu=1, sigma=2, lower=0, upper=3)

0 commit comments

Comments
 (0)