diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 9cffa3c111..b45eaeb96f 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -345,6 +345,11 @@ def logcdf(value, lower, upper): msg="lower <= upper", ) + def icdf(value, lower, upper): + res = lower + (upper - lower) * value + res = check_icdf_value(res, value) + return check_icdf_parameters(res, lower < upper) + @_default_transform.register(Uniform) def uniform_default_transform(op, rv): diff --git a/pymc/testing.py b/pymc/testing.py index ea3ccfc46f..3bb222222f 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -526,6 +526,7 @@ def check_icdf( pymc_dist: Distribution, paramdomains: Dict[str, Domain], scipy_icdf: Callable, + skip_paramdomain_outside_edge_test=False, decimal: Optional[int] = None, n_samples: int = 100, ) -> None: @@ -548,7 +549,7 @@ def check_icdf( paramdomains : Dictionary of Parameter : Domain pairs Supported domains of distribution parameters scipy_icdf : Scipy icdf method - Scipy icdf (ppp) method of equivalent pymc_dist distribution + Scipy icdf (ppf) method of equivalent pymc_dist distribution decimal : int, optional Level of precision with which pymc_dist and scipy_icdf are compared. Defaults to 6 for float64 and 3 for float32 @@ -557,6 +558,9 @@ def check_icdf( are compared between pymc and scipy methods. If n_samples is below the total number of combinations, a random subset is evaluated. Setting n_samples = -1, will return all possible combinations. Defaults to 100 + skip_paradomain_outside_edge_test : Bool + Whether to run test 2., which checks that pymc distribution icdf + returns nan for invalid parameter values outside the supported domain edge """ if decimal is None: @@ -586,19 +590,20 @@ def check_icdf( valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} valid_params["q"] = valid_value - # Test pymc distribution raises ParameterValueError for parameters outside the - # supported domain edges (excluding edges) - invalid_params = find_invalid_scalar_params(paramdomains) - for invalid_param, invalid_edges in invalid_params.items(): - for invalid_edge in invalid_edges: - if invalid_edge is None: - continue + if not skip_paramdomain_outside_edge_test: + # Test pymc distribution raises ParameterValueError for parameters outside the + # supported domain edges (excluding edges) + invalid_params = find_invalid_scalar_params(paramdomains) + for invalid_param, invalid_edges in invalid_params.items(): + for invalid_edge in invalid_edges: + if invalid_edge is None: + continue - point = valid_params.copy() - point[invalid_param] = invalid_edge - with pytest.raises(ParameterValueError): - pymc_icdf(**point) - pytest.fail(f"test_params={point}") + point = valid_params.copy() + point[invalid_param] = invalid_edge + with pytest.raises(ParameterValueError): + pymc_icdf(**point) + pytest.fail(f"test_params={point}") # Test that values below 0 or above 1 evaluate to nan invalid_values = find_invalid_scalar_params({"q": domain})["q"] diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 8b4d2ef2b0..8b4484a66c 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -26,9 +26,9 @@ import pymc as pm -from pymc.distributions.continuous import Normal, get_tau_sigma, interpolated +from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated from pymc.distributions.dist_math import clipped_beta_rvs -from pymc.logprob.abstract import logcdf +from pymc.logprob.abstract import icdf, logcdf from pymc.logprob.joint_logprob import logp from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import floatX @@ -176,6 +176,12 @@ def test_uniform(self): lambda value, lower, upper: st.uniform.logcdf(value, lower, upper - lower), skip_paramdomain_outside_edge_test=True, ) + check_icdf( + pm.Uniform, + {"lower": -Rplusunif, "upper": Rplusunif}, + lambda q, lower, upper: st.uniform.ppf(q=q, loc=lower, scale=upper - lower), + skip_paramdomain_outside_edge_test=True, + ) # Custom logp / logcdf check for invalid parameters invalid_dist = pm.Uniform.dist(lower=1, upper=0) with pytensor.config.change_flags(mode=Mode("py")): @@ -183,6 +189,8 @@ def test_uniform(self): logp(invalid_dist, np.array(0.5)).eval() with pytest.raises(ParameterValueError): logcdf(invalid_dist, np.array(0.5)).eval() + with pytest.raises(ParameterValueError): + icdf(invalid_dist, np.array(0.5)).eval() def test_triangular(self): check_logp(