diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index ad52ff72b0..fa39ac0370 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1058,6 +1058,15 @@ def logcdf(value, lower, upper): msg="lower <= upper", ) + def icdf(value, lower, upper): + res = pt.ceil(value * (upper - lower + 1)).astype("int64") + lower - 1 + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + lower <= upper, + msg="lower <= upper", + ) + class Categorical(Discrete): R""" diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index f3152268ea..78dbd7999b 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -29,7 +29,7 @@ import pymc as pm from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit -from pymc.logprob.abstract import logcdf +from pymc.logprob.abstract import icdf, logcdf from pymc.logprob.joint_logprob import logp from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import floatX @@ -118,6 +118,12 @@ def test_discrete_unif(self): Domain([-10, 0, 10], "int64"), {"lower": -Rplusdunif, "upper": Rplusdunif}, ) + check_icdf( + pm.DiscreteUniform, + {"lower": -Rplusdunif, "upper": Rplusdunif}, + lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1), + skip_paramdomain_outside_edge_test=True, + ) # Custom logp / logcdf check for invalid parameters invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0) with pytensor.config.change_flags(mode=Mode("py")): @@ -125,6 +131,8 @@ def test_discrete_unif(self): logp(invalid_dist, 0.5).eval() with pytest.raises(ParameterValueError): logcdf(invalid_dist, 2).eval() + with pytest.raises(ParameterValueError): + icdf(invalid_dist, np.array(1)).eval() def test_geometric(self): check_logp(