Skip to content

Implement icdf for Univariate distribution #6528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 19, 2023
5 changes: 5 additions & 0 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ def logcdf(value, lower, upper):
msg="lower <= upper",
)

def icdf(value, lower, upper):
res = lower + (upper - lower) * value
res = check_icdf_value(res, value)
return check_icdf_parameters(res, lower < upper)


@_default_transform.register(Uniform)
def uniform_default_transform(op, rv):
Expand Down
31 changes: 18 additions & 13 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def check_icdf(
pymc_dist: Distribution,
paramdomains: Dict[str, Domain],
scipy_icdf: Callable,
skip_paramdomain_outside_edge_test=False,
decimal: Optional[int] = None,
n_samples: int = 100,
) -> None:
Expand All @@ -548,7 +549,7 @@ def check_icdf(
paramdomains : Dictionary of Parameter : Domain pairs
Supported domains of distribution parameters
scipy_icdf : Scipy icdf method
Scipy icdf (ppp) method of equivalent pymc_dist distribution
Scipy icdf (ppf) method of equivalent pymc_dist distribution
decimal : int, optional
Level of precision with which pymc_dist and scipy_icdf are compared.
Defaults to 6 for float64 and 3 for float32
Expand All @@ -557,6 +558,9 @@ def check_icdf(
are compared between pymc and scipy methods. If n_samples is below the
total number of combinations, a random subset is evaluated. Setting
n_samples = -1, will return all possible combinations. Defaults to 100
skip_paradomain_outside_edge_test : Bool
Whether to run test 2., which checks that pymc distribution icdf
returns nan for invalid parameter values outside the supported domain edge

"""
if decimal is None:
Expand Down Expand Up @@ -586,19 +590,20 @@ def check_icdf(
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
valid_params["q"] = valid_value

# Test pymc distribution raises ParameterValueError for parameters outside the
# supported domain edges (excluding edges)
invalid_params = find_invalid_scalar_params(paramdomains)
for invalid_param, invalid_edges in invalid_params.items():
for invalid_edge in invalid_edges:
if invalid_edge is None:
continue
if not skip_paramdomain_outside_edge_test:
# Test pymc distribution raises ParameterValueError for parameters outside the
# supported domain edges (excluding edges)
invalid_params = find_invalid_scalar_params(paramdomains)
for invalid_param, invalid_edges in invalid_params.items():
for invalid_edge in invalid_edges:
if invalid_edge is None:
continue

point = valid_params.copy()
point[invalid_param] = invalid_edge
with pytest.raises(ParameterValueError):
pymc_icdf(**point)
pytest.fail(f"test_params={point}")
point = valid_params.copy()
point[invalid_param] = invalid_edge
with pytest.raises(ParameterValueError):
pymc_icdf(**point)
pytest.fail(f"test_params={point}")

# Test that values below 0 or above 1 evaluate to nan
invalid_values = find_invalid_scalar_params({"q": domain})["q"]
Expand Down
12 changes: 10 additions & 2 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

import pymc as pm

from pymc.distributions.continuous import Normal, get_tau_sigma, interpolated
from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated
from pymc.distributions.dist_math import clipped_beta_rvs
from pymc.logprob.abstract import logcdf
from pymc.logprob.abstract import icdf, logcdf
from pymc.logprob.joint_logprob import logp
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import floatX
Expand Down Expand Up @@ -176,13 +176,21 @@ def test_uniform(self):
lambda value, lower, upper: st.uniform.logcdf(value, lower, upper - lower),
skip_paramdomain_outside_edge_test=True,
)
check_icdf(
pm.Uniform,
{"lower": -Rplusunif, "upper": Rplusunif},
lambda q, lower, upper: st.uniform.ppf(q=q, loc=lower, scale=upper - lower),
skip_paramdomain_outside_edge_test=True,
)
# Custom logp / logcdf check for invalid parameters
invalid_dist = pm.Uniform.dist(lower=1, upper=0)
with pytensor.config.change_flags(mode=Mode("py")):
with pytest.raises(ParameterValueError):
logp(invalid_dist, np.array(0.5)).eval()
with pytest.raises(ParameterValueError):
logcdf(invalid_dist, np.array(0.5)).eval()
with pytest.raises(ParameterValueError):
icdf(invalid_dist, np.array(0.5)).eval()

def test_triangular(self):
check_logp(
Expand Down