diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index be8d688d80..dcd1e1df5f 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -41,6 +41,10 @@ from pytensor.graph.basic import Node from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.scalar.basic import Mul +from pytensor.tensor.basic import get_underlying_scalar_constant_value +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Max from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.var import TensorVariable @@ -122,7 +126,102 @@ def max_logprob(op, values, base_rv, **kwargs): logcdf = _logcdf_helper(base_rv, value) [n] = constant_fold([base_rv.size]) - logprob = (n - 1) * logcdf + logprob + pt.math.log(n) return logprob + + +class MeasurableMaxNeg(Max): + """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph. + This shows up in the graph of min, which is (neg(max(neg(x))).""" + + +MeasurableVariable.register(MeasurableMaxNeg) + + +@node_rewriter(tracks=[Max]) +def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]: + rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + if isinstance(node.op, MeasurableMaxNeg): + return None # pragma: no cover + + base_var = node.inputs[0] + + if base_var.owner is None: + return None + + if not rv_map_feature.request_measurable(node.inputs): + return None + + # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise + if not isinstance(base_var.owner.op, Elemwise): + return None + + # negation is rv * (-1). Hence the scalar_op must be Mul + try: + if not ( + isinstance(base_var.owner.op.scalar_op, Mul) + and len(base_var.owner.inputs) == 2 + and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1 + ): + return None + except NotScalarConstantError: + return None + + base_rv = base_var.owner.inputs[0] + + # Non-univariate distributions and non-RVs must be rejected + if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0): + return None + + # TODO: We are currently only supporting continuous rvs + if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"): + return None + + # univariate i.i.d. test which also rules out other distributions + for params in base_rv.owner.inputs[3:]: + if params.type.ndim != 0: + return None + + # Check whether axis is supported or not + axis = set(node.op.axis) + base_var_dims = set(range(base_var.ndim)) + if axis != base_var_dims: + return None + + measurable_min = MeasurableMaxNeg(list(axis)) + min_rv_node = measurable_min.make_node(base_var) + min_rv = min_rv_node.outputs + + return min_rv + + +measurable_ir_rewrites_db.register( + "find_measurable_max_neg", + find_measurable_max_neg, + "basic", + "min", +) + + +@_logprob.register(MeasurableMaxNeg) +def max_neg_logprob(op, values, base_var, **kwargs): + r"""Compute the log-likelihood graph for the `Max` operation. + The formula that we use here is : + \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x)) + where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively. + """ + (value,) = values + base_rv = base_var.owner.inputs[0] + + logprob = _logprob_helper(base_rv, -value) + logcdf = _logcdf_helper(base_rv, -value) + + [n] = constant_fold([base_rv.size]) + logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n) + + return logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 5a3818716d..ff51199491 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -43,7 +43,6 @@ import pymc as pm from pymc import logp -from pymc.logprob import conditional_logp from pymc.testing import assert_no_rvs @@ -58,55 +57,90 @@ def test_argmax(): x_max_logprob = logp(x_max, x_max_value) -def test_max_non_iid_fails(): - """Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected""" +@pytest.mark.parametrize( + "pt_op", + [ + pt.max, + pt.min, + ], +) +def test_non_iid_fails(pt_op): + """Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected""" x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_m = pt_op(x, axis=-1) + x_m_value = pt.vector("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): - x_max_logprob = logp(x_max, x_max_value) + x_max_logprob = logp(x_m, x_m_value) -def test_max_non_rv_fails(): +@pytest.mark.parametrize( + "pt_op", + [ + pt.max, + pt.min, + ], +) +def test_non_rv_fails(pt_op): """Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected""" x = pt.exp(pt.random.beta(0, 1, size=(3,))) x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_m = pt_op(x, axis=-1) + x_m_value = pt.vector("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): - x_max_logprob = logp(x_max, x_max_value) + x_max_logprob = logp(x_m, x_m_value) -def test_max_multivariate_rv_fails(): +@pytest.mark.parametrize( + "pt_op", + [ + pt.max, + pt.min, + ], +) +def test_multivariate_rv_fails(pt_op): _alpha = pt.scalar() _k = pt.iscalar() x = pm.StickBreakingWeights.dist(_alpha, _k) x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_m = pt_op(x, axis=-1) + x_m_value = pt.vector("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): - x_max_logprob = logp(x_max, x_max_value) + x_max_logprob = logp(x_m, x_m_value) -def test_max_categorical(): +@pytest.mark.parametrize( + "pt_op", + [ + pt.max, + pt.min, + ], +) +def test_categorical(pt_op): """Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected""" x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_m = pt_op(x, axis=-1) + x_m_value = pt.vector("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): - x_max_logprob = logp(x_max, x_max_value) + x_max_logprob = logp(x_m, x_m_value) -def test_non_supp_axis_max(): +@pytest.mark.parametrize( + "pt_op", + [ + pt.max, + pt.min, + ], +) +def test_non_supp_axis(pt_op): """Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected""" x = pt.random.normal(0, 1, size=(3, 3)) x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_m = pt_op(x, axis=-1) + x_m_value = pt.vector("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): - x_max_logprob = logp(x_max, x_max_value) + x_max_logprob = logp(x_m, x_m_value) @pytest.mark.parametrize( @@ -147,3 +181,52 @@ def test_max_logprob(shape, value, axis): (x_max_logprob.eval({x_max_value: test_value})), rtol=1e-06, ) + + +@pytest.mark.parametrize( + "shape, value, axis", + [ + (3, 0.85, -1), + (3, 0.01, 0), + (2, 0.2, None), + (4, 0.5, 0), + ((3, 4), 0.9, None), + ((3, 4), 0.75, (1, 0)), + ], +) +def test_min_logprob(shape, value, axis): + """Test whether the logprob for ```pt.mix``` produces the corrected + The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here: + U_1, \\dots, U_n \\stackrel{\text{i.i.d.}}{\\sim} \text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \text{Beta}(k, n + 1- k) + for all 1<=k<=n + """ + x = pt.random.uniform(0, 1, size=shape) + x.name = "x" + x_min = pt.min(x, axis=axis) + x_min_value = pt.scalar("x_min_value") + x_min_logprob = logp(x_min, x_min_value) + + assert_no_rvs(x_min_logprob) + + test_value = value + + n = np.prod(shape) + beta_rv = pt.random.beta(1, n, name="beta") + beta_vv = beta_rv.clone() + beta_rv_logprob = logp(beta_rv, beta_vv) + + np.testing.assert_allclose( + beta_rv_logprob.eval({beta_vv: test_value}), + (x_min_logprob.eval({x_min_value: test_value})), + rtol=1e-06, + ) + + +def test_min_non_mul_elemwise_fails(): + """Test whether the logprob for ```pt.min``` for non-mul elemwise RVs is rejected correctly""" + x = pt.log(pt.random.beta(0, 1, size=(3,))) + x.name = "x" + x_min = pt.min(x, axis=-1) + x_min_value = pt.vector("x_min_value") + with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): + x_min_logprob = logp(x_min, x_min_value)