diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index d597299997..0beba75b0f 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1309,6 +1309,16 @@ def logcdf(value, a, b): msg="a > 0, b > 0", ) + def icdf(value, a, b): + res = pt.exp(pt.reciprocal(a) * pt.log1mexp(pt.reciprocal(b) * pt.log1p(-value))) + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + a > 0, + b > 0, + msg="a > 0, b > 0", + ) + class Exponential(PositiveContinuous): r""" diff --git a/pymc/testing.py b/pymc/testing.py index 86f2910a2b..2cafa9e970 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -657,6 +657,90 @@ def check_selfconsistency_discrete_logcdf( ) +def check_selfconsistency_continuous_icdf( + distribution: Distribution, + paramdomains: Dict[str, Domain], + decimal: Optional[int] = None, + n_samples: int = 100, +) -> None: + """ + Check that the icdf and logcdf functions of the distribution are consistent for a sample of probability values. + """ + if decimal is None: + decimal = select_by_precision(float64=6, float32=3) + + dist = create_dist_from_paramdomains(distribution, paramdomains) + value = dist.type() + value.name = "value" + + dist_icdf = icdf(dist, value) + dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf) + + dist_logcdf = logcdf(dist, value) + dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf) + + domains = paramdomains.copy() + domains["value"] = Domain(np.linspace(0, 1, 10)) + + for point in product(domains, n_samples=n_samples): + point = dict(point) + value = point.pop("value") + + with pytensor.config.change_flags(mode=Mode("py")): + npt.assert_almost_equal( + value, + np.exp(dist_logcdf_fn(**point, value=dist_icdf_fn(**point, value=value))), + decimal=decimal, + err_msg=f"point: {point}, value: {value}", + ) + + +def check_selfconsistency_discrete_icdf( + distribution: Distribution, + domain: Domain, + paramdomains: Dict[str, Domain], + decimal: Optional[int] = None, + n_samples: int = 100, +) -> None: + """ + Check that the icdf and logcdf functions of the distribution are + consistent for a sample of values in the domain of the + distribution. + """ + + def ftrunc(values, decimal=0): + return np.trunc(values * 10**decimal) / (10**decimal) + + if decimal is None: + decimal = select_by_precision(float64=6, float32=3) + + dist = create_dist_from_paramdomains(distribution, paramdomains) + + value = pt.TensorType(dtype="float64", shape=[])("value") + + dist_icdf = icdf(dist, value) + dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf) + + dist_logcdf = logcdf(dist, value) + dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf) + + domains = paramdomains.copy() + domains["value"] = domain + + for point in product(domains, n_samples=n_samples): + point = dict(point) + value = point.pop("value") + + with pytensor.config.change_flags(mode=Mode("py")): + expected_value = value + computed_value = dist_icdf_fn( + **point, value=ftrunc(np.exp(dist_logcdf_fn(**point, value=value)), decimal=decimal) + ) + assert ( + expected_value == computed_value + ), f"expected_value = {expected_value}, computed_value = {computed_value}, {point}" + + def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index b84a77e049..926fec5347 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -46,6 +46,7 @@ check_icdf, check_logcdf, check_logp, + check_selfconsistency_continuous_icdf, continuous_random_tester, seeded_numpy_distribution_builder, seeded_scipy_distribution_builder, @@ -424,6 +425,10 @@ def scipy_log_cdf(value, a, b): {"a": Rplus, "b": Rplus}, scipy_log_cdf, ) + check_selfconsistency_continuous_icdf( + pm.Kumaraswamy, + {"a": Rplusbig, "b": Rplusbig}, + ) def test_exponential(self): check_logp( diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index e9543c0746..d8812124df 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -51,6 +51,7 @@ check_icdf, check_logcdf, check_logp, + check_selfconsistency_discrete_icdf, check_selfconsistency_discrete_logcdf, seeded_numpy_distribution_builder, seeded_scipy_distribution_builder, @@ -119,6 +120,11 @@ def test_discrete_unif(self): lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1), skip_paramdomain_outside_edge_test=True, ) + check_selfconsistency_discrete_icdf( + pm.DiscreteUniform, + Rdunif, + {"lower": -Rplusdunif, "upper": Rplusdunif}, + ) # Custom logp / logcdf check for invalid parameters invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0) with pytensor.config.change_flags(mode=Mode("py")): @@ -152,6 +158,11 @@ def test_geometric(self): {"p": Unit}, st.geom.ppf, ) + check_selfconsistency_discrete_icdf( + pm.Geometric, + Nat, + {"p": Unit}, + ) def test_hypergeometric(self): def modified_scipy_hypergeom_logcdf(value, N, k, n):