Skip to content

Logprob derivation for Min of continuous IID variables #6846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable
Expand Down Expand Up @@ -122,7 +126,102 @@ def max_logprob(op, values, base_rv, **kwargs):
logcdf = _logcdf_helper(base_rv, value)

[n] = constant_fold([base_rv.size])

logprob = (n - 1) * logcdf + logprob + pt.math.log(n)

return logprob


class MeasurableMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is (neg(max(neg(x)))."""


MeasurableVariable.register(MeasurableMaxNeg)


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableMaxNeg):
return None # pragma: no cover

base_var = node.inputs[0]

if base_var.owner is None:
return None

if not rv_map_feature.request_measurable(node.inputs):
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
return None

# negation is rv * (-1). Hence the scalar_op must be Mul
try:
if not (
isinstance(base_var.owner.op.scalar_op, Mul)
and len(base_var.owner.inputs) == 2
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
):
return None
except NotScalarConstantError:
return None

base_rv = base_var.owner.inputs[0]

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

# TODO: We are currently only supporting continuous rvs
if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.inputs[3:]:
if params.type.ndim != 0:
return None

# Check whether axis is supported or not
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None

measurable_min = MeasurableMaxNeg(list(axis))
min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs

return min_rv


measurable_ir_rewrites_db.register(
"find_measurable_max_neg",
find_measurable_max_neg,
"basic",
"min",
)


@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_var, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
"""
(value,) = values
base_rv = base_var.owner.inputs[0]

logprob = _logprob_helper(base_rv, -value)
logcdf = _logcdf_helper(base_rv, -value)

[n] = constant_fold([base_rv.size])
logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)

return logprob
127 changes: 105 additions & 22 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import pymc as pm

from pymc import logp
from pymc.logprob import conditional_logp
from pymc.testing import assert_no_rvs


Expand All @@ -58,55 +57,90 @@ def test_argmax():
x_max_logprob = logp(x_max, x_max_value)


def test_max_non_iid_fails():
"""Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected"""
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_iid_fails(pt_op):
"""Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_max_non_rv_fails():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_rv_fails(pt_op):
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_max_multivariate_rv_fails():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_multivariate_rv_fails(pt_op):
_alpha = pt.scalar()
_k = pt.iscalar()
x = pm.StickBreakingWeights.dist(_alpha, _k)
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_max_categorical():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_categorical(pt_op):
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_non_supp_axis_max():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_supp_axis(pt_op):
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
x = pt.random.normal(0, 1, size=(3, 3))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -147,3 +181,52 @@ def test_max_logprob(shape, value, axis):
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)


@pytest.mark.parametrize(
"shape, value, axis",
[
(3, 0.85, -1),
(3, 0.01, 0),
(2, 0.2, None),
(4, 0.5, 0),
((3, 4), 0.9, None),
((3, 4), 0.75, (1, 0)),
],
)
def test_min_logprob(shape, value, axis):
"""Test whether the logprob for ```pt.mix``` produces the corrected
The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
U_1, \\dots, U_n \\stackrel{\text{i.i.d.}}{\\sim} \text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \text{Beta}(k, n + 1- k)
for all 1<=k<=n
"""
x = pt.random.uniform(0, 1, size=shape)
x.name = "x"
x_min = pt.min(x, axis=axis)
x_min_value = pt.scalar("x_min_value")
x_min_logprob = logp(x_min, x_min_value)

assert_no_rvs(x_min_logprob)

test_value = value

n = np.prod(shape)
beta_rv = pt.random.beta(1, n, name="beta")
beta_vv = beta_rv.clone()
beta_rv_logprob = logp(beta_rv, beta_vv)

np.testing.assert_allclose(
beta_rv_logprob.eval({beta_vv: test_value}),
(x_min_logprob.eval({x_min_value: test_value})),
rtol=1e-06,
)


def test_min_non_mul_elemwise_fails():
"""Test whether the logprob for ```pt.min``` for non-mul elemwise RVs is rejected correctly"""
x = pt.log(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
x_min = pt.min(x, axis=-1)
x_min_value = pt.vector("x_min_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_min_logprob = logp(x_min, x_min_value)