From 3cfc91c1b46271ec66bf53205cac1dc857834efc Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 7 Aug 2023 00:28:44 +0530 Subject: [PATCH 1/9] Adding support for Discrete distribution for max logprob --- pymc/logprob/order.py | 50 ++++++++++++++++++++++++++++++------- tests/logprob/test_order.py | 9 +++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index f76428f83c..be536f8f7e 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -66,6 +66,13 @@ class MeasurableMax(Max): MeasurableVariable.register(MeasurableMax) +class MeasurableMaxDiscrete(Max): + """A placeholder used to specify a log-likelihood for a cmax sub-graph.""" + + +MeasurableVariable.register(MeasurableMaxDiscrete) + + @node_rewriter([Max]) def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) @@ -87,10 +94,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0): return None - # TODO: We are currently only supporting continuous rvs - if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): - return None - # univariate i.i.d. test which also rules out other distributions for params in base_var.owner.inputs[3:]: if params.type.ndim != 0: @@ -102,11 +105,20 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if axis != base_var_dims: return None - measurable_max = MeasurableMax(list(axis)) - max_rv_node = measurable_max.make_node(base_var) - max_rv = max_rv_node.outputs + # logprob for discrete distribution + if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): + measurable_max = MeasurableMaxDiscrete(list(axis)) + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs + + return max_rv + # logprob for continuous distribution + else: + measurable_max = MeasurableMax(list(axis)) + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs - return max_rv + return max_rv measurable_ir_rewrites_db.register( @@ -131,6 +143,26 @@ def max_logprob(op, values, base_rv, **kwargs): return logprob +@_logprob.register(MeasurableMaxDiscrete) +def max_logprob_discrete(op, values, base_rv, **kwargs): + r"""Compute the log-likelihood graph for the `Max` operation. + + The formula that we use here is : + \ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n) + where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively. + """ + (value,) = values + logprob = _logprob_helper(base_rv, value) + logcdf = _logcdf_helper(base_rv, value) + logcdf_prev = _logcdf_helper(base_rv, value - 1) + + n = base_rv.size + + logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n) + + return logprob + + class MeasurableMaxNeg(Max): """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph. This shows up in the graph of min, which is (neg(max(neg(x))).""" @@ -224,4 +256,4 @@ def max_neg_logprob(op, values, base_var, **kwargs): [n] = constant_fold([base_rv.size]) logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n) - return logprob + return logprob \ No newline at end of file diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index ff51199491..9552e11b2d 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -230,3 +230,12 @@ def test_min_non_mul_elemwise_fails(): x_min_value = pt.vector("x_min_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_min_logprob = logp(x_min, x_min_value) + +def test_max_discrete(): + x = pm.DiscreteUniform.dist(0, 1, size=(3,)) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.scalar("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + + x_max_logprob.eval({x_max_value: 0.85}) From eb2bfebef25a6a69249a557b574edb5803f7a6ac Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 7 Aug 2023 00:31:45 +0530 Subject: [PATCH 2/9] Adding support for Discrete distribution for max logprob --- pymc/logprob/order.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index be536f8f7e..5978008045 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -256,4 +256,4 @@ def max_neg_logprob(op, values, base_var, **kwargs): [n] = constant_fold([base_rv.size]) logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n) - return logprob \ No newline at end of file + return logprob From 322ea87ea6d9242cfc9a580a609282fa7d0fd19e Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 13 Aug 2023 13:09:41 +0530 Subject: [PATCH 3/9] Test for discrete --- tests/logprob/test_order.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 9552e11b2d..e20ba08374 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -43,6 +43,7 @@ import pymc as pm from pymc import logp +from pymc.logprob.abstract import _logcdf_helper, _logprob_helper from pymc.testing import assert_no_rvs @@ -238,4 +239,16 @@ def test_max_discrete(): x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) - x_max_logprob.eval({x_max_value: 0.85}) + discrete_logprob = _logprob_helper(x, x_max_value) + discrete_logcdf = _logcdf_helper(x, x_max_value) + discrete_logcdf_prev = _logcdf_helper(x, x_max_value - 1) + n = x.size + discrete_logprob = pt.log((pt.exp(discrete_logcdf)) ** n - (pt.exp(discrete_logcdf_prev)) ** n) + + test_value = 0.85 + + np.testing.assert_allclose( + discrete_logprob.eval({x_max_value: test_value}), + (x_max_logprob.eval({x_max_value: test_value})), + rtol=1e-06, + ) From 88b4abcea69eb8161625374a0170296e1b7a91c8 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 4 Sep 2023 18:59:27 +0530 Subject: [PATCH 4/9] Add support for discrete rvs --- pymc/logprob/order.py | 28 +++++++++++++++------------- tests/logprob/test_order.py | 25 ++++++++++++++----------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 5978008045..8db21ea703 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -49,6 +49,8 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable +import pymc as pm + from pymc.logprob.abstract import ( MeasurableVariable, _logcdf_helper, @@ -67,7 +69,7 @@ class MeasurableMax(Max): class MeasurableMaxDiscrete(Max): - """A placeholder used to specify a log-likelihood for a cmax sub-graph.""" + """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables""" MeasurableVariable.register(MeasurableMaxDiscrete) @@ -105,14 +107,14 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if axis != base_var_dims: return None - # logprob for discrete distribution - if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): - measurable_max = MeasurableMaxDiscrete(list(axis)) - max_rv_node = measurable_max.make_node(base_var) - max_rv = max_rv_node.outputs + # distinguish measurable discrete and continuous (because logprob is different) + if base_var.owner.op.dtype.startswith("int"): + if isinstance(base_var.owner.op, RandomVariable): + measurable_max = MeasurableMaxDiscrete(list(axis)) + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs - return max_rv - # logprob for continuous distribution + return max_rv else: measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) @@ -148,17 +150,17 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): r"""Compute the log-likelihood graph for the `Max` operation. The formula that we use here is : - \ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n) - where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively. + .. math:: + \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n) + where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables. """ (value,) = values - logprob = _logprob_helper(base_rv, value) logcdf = _logcdf_helper(base_rv, value) logcdf_prev = _logcdf_helper(base_rv, value - 1) - n = base_rv.size + [n] = constant_fold([base_rv.size]) - logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n) + logprob = pm.math.logdiffexp(n * logcdf, n * logcdf_prev) return logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index e20ba08374..efe5a118ef 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -39,6 +39,7 @@ import numpy as np import pytensor.tensor as pt import pytest +import scipy.stats as sp import pymc as pm @@ -232,23 +233,25 @@ def test_min_non_mul_elemwise_fails(): with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_min_logprob = logp(x_min, x_min_value) -def test_max_discrete(): - x = pm.DiscreteUniform.dist(0, 1, size=(3,)) - x.name = "x" - x_max = pt.max(x, axis=-1) + +@pytest.mark.parametrize( + "mu, size, value, axis", + [(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)], +) +def test_max_discrete(mu, size, value, axis): + x = pm.Poisson.dist(name="x", mu=mu, size=(size)) + x_max = pt.max(x, axis=axis) x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) - discrete_logprob = _logprob_helper(x, x_max_value) - discrete_logcdf = _logcdf_helper(x, x_max_value) - discrete_logcdf_prev = _logcdf_helper(x, x_max_value - 1) - n = x.size - discrete_logprob = pt.log((pt.exp(discrete_logcdf)) ** n - (pt.exp(discrete_logcdf_prev)) ** n) + test_value = value - test_value = 0.85 + n = size + exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n + exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n np.testing.assert_allclose( - discrete_logprob.eval({x_max_value: test_value}), + np.log(exp_rv - exp_rv_prev), (x_max_logprob.eval({x_max_value: test_value})), rtol=1e-06, ) From fb73bde3aadd323eb8f73691cc0059e40fa3eab5 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 10 Sep 2023 12:27:10 +0530 Subject: [PATCH 5/9] Suppport for discrete max/min --- pymc/logprob/order.py | 13 ++++++------- tests/logprob/test_order.py | 6 +++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 8db21ea703..c541537e6e 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -111,16 +111,15 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if base_var.owner.op.dtype.startswith("int"): if isinstance(base_var.owner.op, RandomVariable): measurable_max = MeasurableMaxDiscrete(list(axis)) - max_rv_node = measurable_max.make_node(base_var) - max_rv = max_rv_node.outputs - - return max_rv + else: + return None else: measurable_max = MeasurableMax(list(axis)) - max_rv_node = measurable_max.make_node(base_var) - max_rv = max_rv_node.outputs - return max_rv + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs + + return max_rv measurable_ir_rewrites_db.register( diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index efe5a118ef..c0fca8b198 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -236,7 +236,7 @@ def test_min_non_mul_elemwise_fails(): @pytest.mark.parametrize( "mu, size, value, axis", - [(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)], + [(2, 3, 0.85, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)], ) def test_max_discrete(mu, size, value, axis): x = pm.Poisson.dist(name="x", mu=mu, size=(size)) @@ -247,8 +247,8 @@ def test_max_discrete(mu, size, value, axis): test_value = value n = size - exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n - exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n + exp_rv = sp.poisson(mu).cdf(test_value) ** n + exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n np.testing.assert_allclose( np.log(exp_rv - exp_rv_prev), From b45ef2a4a10fab5805b76adc1fc1bf88bb0d652b Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 10 Sep 2023 12:41:33 +0530 Subject: [PATCH 6/9] Suppport for discrete for order_stats --- tests/logprob/test_order.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index c0fca8b198..e5c4736bfd 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -44,7 +44,6 @@ import pymc as pm from pymc import logp -from pymc.logprob.abstract import _logcdf_helper, _logprob_helper from pymc.testing import assert_no_rvs From ba726dcfcef66245c1819f1889e0f5322b5da73f Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 26 Sep 2023 15:55:02 +0530 Subject: [PATCH 7/9] Improved imports --- pymc/logprob/order.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index c541537e6e..18c84526b2 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -49,8 +49,6 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable -import pymc as pm - from pymc.logprob.abstract import ( MeasurableVariable, _logcdf_helper, @@ -58,6 +56,7 @@ _logprob_helper, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.math import logdiffexp from pymc.pytensorf import constant_fold @@ -159,7 +158,7 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): [n] = constant_fold([base_rv.size]) - logprob = pm.math.logdiffexp(n * logcdf, n * logcdf_prev) + logprob = logdiffexp(n * logcdf, n * logcdf_prev) return logprob From 9fddde65899352a8c9e21d340bdd35c6757f2996 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 17 Oct 2023 11:04:27 +0530 Subject: [PATCH 8/9] changed test values for discrete order logprob --- tests/logprob/test_order.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index e5c4736bfd..8eae026c0b 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -235,7 +235,7 @@ def test_min_non_mul_elemwise_fails(): @pytest.mark.parametrize( "mu, size, value, axis", - [(2, 3, 0.85, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)], + [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)], ) def test_max_discrete(mu, size, value, axis): x = pm.Poisson.dist(name="x", mu=mu, size=(size)) From 894ff276b016c93eccd08ca6c9234e72ca540b87 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 17 Oct 2023 18:49:41 +0530 Subject: [PATCH 9/9] Reducing redundant code in discrete order --- pymc/logprob/order.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 18c84526b2..35b84542db 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -108,10 +108,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens # distinguish measurable discrete and continuous (because logprob is different) if base_var.owner.op.dtype.startswith("int"): - if isinstance(base_var.owner.op, RandomVariable): - measurable_max = MeasurableMaxDiscrete(list(axis)) - else: - return None + measurable_max = MeasurableMaxDiscrete(list(axis)) else: measurable_max = MeasurableMax(list(axis))