diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 0b8d494276..5f751b5bf1 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -448,7 +448,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg backward_value = op.transform_elemwise.backward(value, *other_inputs) - # Some transformations, like squaring may produce multiple backward values + # Fail if transformation is not injective + # A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many if isinstance(backward_value, tuple): raise NotImplementedError @@ -469,6 +470,11 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs) input_icdf = _icdf_helper(measurable_input, value) icdf = op.transform_elemwise.forward(input_icdf, *other_inputs) + # Fail if transformation is not injective + # A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many + if isinstance(op.transform_elemwise.backward(icdf, *other_inputs), tuple): + raise NotImplementedError + return icdf diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 785e2599fb..9960dff948 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -49,7 +49,7 @@ from pymc.distributions.transforms import _default_transform, log, logodds from pymc.logprob.abstract import MeasurableVariable, _logprob -from pymc.logprob.basic import conditional_logp, logp +from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp from pymc.logprob.transforms import ( ArccoshTransform, ArcsinhTransform, @@ -1080,3 +1080,37 @@ def test_check_jac_det(transform): elemwise=True, rv_var=pt.random.normal(0.5, 1, name="base_rv"), ) + + +def test_logcdf_measurable_transform(): + x = pt.exp(pt.random.uniform(0, 1)) + value = x.type() + logcdf_fn = pytensor.function([value], logcdf(x, value)) + + assert logcdf_fn(0) == -np.inf + np.testing.assert_almost_equal(logcdf_fn(np.exp(0.5)), np.log(0.5)) + np.testing.assert_almost_equal(logcdf_fn(5), 0) + + +def test_logcdf_measurable_non_injective_fails(): + x = pt.abs(pt.random.uniform(0, 1)) + value = x.type() + with pytest.raises(NotImplementedError): + logcdf(x, value) + + +def test_icdf_measurable_transform(): + x = pt.exp(pt.random.uniform(0, 1)) + value = x.type() + icdf_fn = pytensor.function([value], icdf(x, value)) + + np.testing.assert_almost_equal(icdf_fn(1e-16), 1) + np.testing.assert_almost_equal(icdf_fn(0.5), np.exp(0.5)) + np.testing.assert_almost_equal(icdf_fn(1 - 1e-16), np.e) + + +def test_icdf_measurable_non_injective_fails(): + x = pt.abs(pt.random.uniform(0, 1)) + value = x.type() + with pytest.raises(NotImplementedError): + icdf(x, value)