From 77f24ddbf67b33dafed5e03c674372c7c2baefde Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 11 Jun 2023 16:12:11 +0530 Subject: [PATCH 01/18] Order_stats --- pymc/logprob/__init__.py | 1 + pymc/logprob/order.py | 137 ++++++++++++++++++++++++++++++++++++ tests/logprob/test_order.py | 55 +++++++++++++++ 3 files changed, 193 insertions(+) create mode 100644 pymc/logprob/order.py create mode 100644 tests/logprob/test_order.py diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 0ddea90b6f..7c3c666917 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -49,6 +49,7 @@ import pymc.logprob.cumsum import pymc.logprob.checks import pymc.logprob.mixture +import pymc.logprob.order import pymc.logprob.scan import pymc.logprob.tensor import pymc.logprob.transforms diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py new file mode 100644 index 0000000000..28dc2b24ab --- /dev/null +++ b/pymc/logprob/order.py @@ -0,0 +1,137 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# MIT License +# +# Copyright (c) 2021-2022 aesara-devs +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +from typing import List, Optional + +import numpy as np +import math +import pytensor.tensor as pt +import pytensor + +from pytensor.graph.basic import Node +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.math import Max, MaxAndArgmax + +from pymc.logprob.abstract import ( + MeasurableVariable, + _logcdf, + _logprob, +) +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import ignore_logprob + + +class MeasurableMax(MaxAndArgmax): + """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.""" + +MeasurableVariable.register(MeasurableMax) + +@node_rewriter([MaxAndArgmax]) +def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMax]]: + + rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + if isinstance(node.op, MeasurableMax): + return None # pragma: no cover + + + max_var = node.outputs[0] + base_var = node.inputs[0] + + + if not ( + base_var.owner + and isinstance(base_var.owner.op, MeasurableVariable) + and base_var not in rv_map_feature.rv_values + ): + return None + + if(base_var.owner.inputs[3].type.ndim != 0): + return None + + + # Make base_var unmeasurable + unmeasurable_base_var = ignore_logprob(base_var) + axis = node.op.axis + measurable_max = MeasurableMax(list(axis)) + max_rv_node = measurable_max.make_node(unmeasurable_base_var) + max_rv = max_rv_node.outputs + + max_rv[0].name = max_var.name + max_rv[1].name = node.outputs[1].name + + return max_rv + + +measurable_ir_rewrites_db.register( + "find_measurable_max", + find_measurable_max, + "basic", + "max", +) + + +@_logprob.register(MeasurableMax) +def max_logprob(op, values, base_rv, **kwargs): + + (value,) = values + + base_rv_op = base_rv.owner.op + base_rv_inputs = base_rv.owner.inputs + + logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs) + 
logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs) + + if base_rv_op.name: + logprob.name = f"{base_rv_op}_logprob" + logcdf.name = f"{base_rv_op}_logcdf" + + size_var = base_rv.owner.inputs[1] + string_size = str(size_var) + for b in (0, len(string_size)-1): + if(string_size[b] == '}'): + a = string_size[b-1] + try : + n = int(a) + except ValueError: + return None + + logprob = (n-1)*logcdf + logprob + + return logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py new file mode 100644 index 0000000000..008a7088b6 --- /dev/null +++ b/tests/logprob/test_order.py @@ -0,0 +1,55 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2021-2022 aesara-devs +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import numpy as np +import pytensor +import pytensor.tensor as pt +import pytest +import scipy.stats as st + +from pymc import logp +from pymc.logprob.basic import factorized_joint_logprob +from pymc.testing import assert_no_rvs + + +def test_max(): + x = pt.random.normal(0, 1, size=(3,)) + x_name = "x" + x_max = pt.max_and_argmax(x, axis=-1) + x_max_value = pt.vector("x_max_value") + x_max_logprob = logp(x_max[0], x_max_value) + + assert_no_rvs(x_max_logprob) \ No newline at end of file From 21544fdd5b5dd810ee4adeff060a81b805450c67 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 11 Jun 2023 16:59:50 +0530 Subject: [PATCH 02/18] logprob derivation for Max --- .github/workflows/tests.yml | 1 + pymc/logprob/order.py | 18 ++++-------------- scripts/run_mypy.py | 1 + 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e5ee774ad8..8f932e0747 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -110,6 +110,7 @@ jobs: tests/logprob/test_composite_logprob.py tests/logprob/test_cumsum.py tests/logprob/test_mixture.py + tests/logprob/test_order.py tests/logprob/test_rewriting.py tests/logprob/test_scan.py tests/logprob/test_tensor.py diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 28dc2b24ab..5102839f90 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -52,7 +52,6 @@ _logprob, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db -from 
pymc.logprob.utils import ignore_logprob class MeasurableMax(MaxAndArgmax): @@ -71,30 +70,21 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas return None # pragma: no cover - max_var = node.outputs[0] base_var = node.inputs[0] - - if not ( - base_var.owner - and isinstance(base_var.owner.op, MeasurableVariable) - and base_var not in rv_map_feature.rv_values - ): - return None if(base_var.owner.inputs[3].type.ndim != 0): return None + if not rv_map_feature.request_measurable(node.inputs): + return None - # Make base_var unmeasurable - unmeasurable_base_var = ignore_logprob(base_var) + axis = node.op.axis measurable_max = MeasurableMax(list(axis)) - max_rv_node = measurable_max.make_node(unmeasurable_base_var) + max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs - max_rv[0].name = max_var.name - max_rv[1].name = node.outputs[1].name return max_rv diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 72f7013007..0dc73ef2a0 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -33,6 +33,7 @@ pymc/logprob/censoring.py pymc/logprob/basic.py pymc/logprob/mixture.py +pymc/logprob/order.py pymc/logprob/rewriting.py pymc/logprob/scan.py pymc/logprob/tensor.py From c01dc623ffdf4cbf3bdfacec5ee1da3b9e2d7afa Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 11 Jun 2023 17:07:56 +0530 Subject: [PATCH 03/18] pre-commit changes --- pymc/logprob/order.py | 37 ++++++++++++------------------------- tests/logprob/test_order.py | 9 ++------- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 5102839f90..8792f612c2 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -36,32 +36,24 @@ from typing import List, Optional -import numpy as np -import math -import pytensor.tensor as pt -import pytensor - from pytensor.graph.basic import Node from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import 
node_rewriter -from pytensor.tensor.math import Max, MaxAndArgmax +from pytensor.tensor.math import MaxAndArgmax -from pymc.logprob.abstract import ( - MeasurableVariable, - _logcdf, - _logprob, -) +from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob from pymc.logprob.rewriting import measurable_ir_rewrites_db class MeasurableMax(MaxAndArgmax): """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.""" + MeasurableVariable.register(MeasurableMax) + @node_rewriter([MaxAndArgmax]) def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMax]]: - rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover @@ -69,23 +61,19 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas if isinstance(node.op, MeasurableMax): return None # pragma: no cover - base_var = node.inputs[0] - - if(base_var.owner.inputs[3].type.ndim != 0): + if base_var.owner.inputs[3].type.ndim != 0: return None - + if not rv_map_feature.request_measurable(node.inputs): return None - axis = node.op.axis - measurable_max = MeasurableMax(list(axis)) + measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs - return max_rv @@ -99,7 +87,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas @_logprob.register(MeasurableMax) def max_logprob(op, values, base_rv, **kwargs): - (value,) = values base_rv_op = base_rv.owner.op @@ -114,14 +101,14 @@ def max_logprob(op, values, base_rv, **kwargs): size_var = base_rv.owner.inputs[1] string_size = str(size_var) - for b in (0, len(string_size)-1): - if(string_size[b] == '}'): - a = string_size[b-1] - try : + for b in (0, len(string_size) - 1): + if string_size[b] == "}": + a = string_size[b - 1] + try: n = int(a) except ValueError: return None - logprob = (n-1)*logcdf + logprob + logprob = (n - 1) * logcdf + logprob 
return logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 008a7088b6..f188906c15 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -34,14 +34,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import numpy as np -import pytensor import pytensor.tensor as pt -import pytest -import scipy.stats as st from pymc import logp -from pymc.logprob.basic import factorized_joint_logprob from pymc.testing import assert_no_rvs @@ -51,5 +46,5 @@ def test_max(): x_max = pt.max_and_argmax(x, axis=-1) x_max_value = pt.vector("x_max_value") x_max_logprob = logp(x_max[0], x_max_value) - - assert_no_rvs(x_max_logprob) \ No newline at end of file + + assert_no_rvs(x_max_logprob) From d200fa18eeca711917aa92a03c0e72a97cbc9d51 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 18 Jun 2023 18:19:50 +0530 Subject: [PATCH 04/18] Draft changes --- pymc/logprob/order.py | 59 ++++++++++++++++++++++---------- pymc/logprob/rewriting.py | 3 ++ tests/logprob/test_order.py | 67 +++++++++++++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 19 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 8792f612c2..8f1a6e37b1 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -36,23 +36,26 @@ from typing import List, Optional +import pytensor.tensor as pt + from pytensor.graph.basic import Node from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.tensor.math import MaxAndArgmax +from pytensor.tensor.math import Max +from pytensor.tensor.random.op import RandomVariable from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob from pymc.logprob.rewriting import measurable_ir_rewrites_db -class MeasurableMax(MaxAndArgmax): - """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.""" +class MeasurableMax(Max): + """A placeholder used to specify a 
log-likelihood for a cmax sub-graph.""" MeasurableVariable.register(MeasurableMax) -@node_rewriter([MaxAndArgmax]) +@node_rewriter([Max]) def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMax]]: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: @@ -63,13 +66,28 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas base_var = node.inputs[0] - if base_var.owner.inputs[3].type.ndim != 0: - return None + if isinstance(base_var.owner.op, RandomVariable): + for op in base_var.owner.inputs[3:]: + if op.type.ndim != 0: + return None if not rv_map_feature.request_measurable(node.inputs): return None axis = node.op.axis + for x in range(base_var.ndim): + if x not in axis: + return None + + axis = set(node.op.axis) + base_var_dims = set(range(base_var.ndim)) + if not axis.issubset(base_var_dims): + return None + + for ndim_param, param in zip(base_var.owner.op.ndims_params, base_var.owner.inputs[3:]): + if param.type.ndim != ndim_param: + return None + measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs @@ -87,6 +105,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas @_logprob.register(MeasurableMax) def max_logprob(op, values, base_rv, **kwargs): + r"""Compute the log-likelihood graph for the `Max` operation. + + The formula that we use here is : + \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(F(x)) + \ln(f(x)) + where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively. 
+ """ (value,) = values base_rv_op = base_rv.owner.op @@ -99,16 +123,17 @@ def max_logprob(op, values, base_rv, **kwargs): logprob.name = f"{base_rv_op}_logprob" logcdf.name = f"{base_rv_op}_logcdf" - size_var = base_rv.owner.inputs[1] - string_size = str(size_var) - for b in (0, len(string_size) - 1): - if string_size[b] == "}": - a = string_size[b - 1] - try: - n = int(a) - except ValueError: - return None - - logprob = (n - 1) * logcdf + logprob + # size_var = base_rv.owner.inputs[1] + # string_size = str(size_var) + # for b in (0, len(string_size) - 1): + # if string_size[b] == "}": + # a = string_size[b - 1] + # try: + # n = int(a) + # except ValueError: + # return None + n = value.size + + logprob = (n - 1) * logcdf + logprob + pt.math.log(n) return logprob diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index d2dc6630fd..a2b4c3bfa2 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -57,6 +57,7 @@ EquilibriumGraphRewriter, GraphRewriter, node_rewriter, + out2in, ) from pytensor.graph.rewriting.db import ( LocalGroupDB, @@ -70,6 +71,7 @@ from pytensor.tensor.random.rewriting import local_subtensor_rv_lift from pytensor.tensor.rewriting.basic import register_canonicalize from pytensor.tensor.rewriting.shape import ShapeFeature +from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -358,6 +360,7 @@ def incsubtensor_rv_replace(fgraph, node): logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic") +logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic") # These rewrites convert un-measurable variables into their measurable forms, # but they need to be reapplied, because some of the measurable forms require diff --git a/tests/logprob/test_order.py 
b/tests/logprob/test_order.py index f188906c15..962241c858 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -34,7 +34,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import re + +import numpy as np +import pytensor import pytensor.tensor as pt +import pytest +import scipy.stats as st + +import pymc as pm from pymc import logp from pymc.testing import assert_no_rvs @@ -43,8 +51,63 @@ def test_max(): x = pt.random.normal(0, 1, size=(3,)) x_name = "x" - x_max = pt.max_and_argmax(x, axis=-1) + x_max = pt.max(x, axis=-1) x_max_value = pt.vector("x_max_value") - x_max_logprob = logp(x_max[0], x_max_value) + x_max_logprob = logp(x_max, x_max_value) assert_no_rvs(x_max_logprob) + + +def test_argmax(): + x = pt.random.normal(0, 1, size=(3,)) + x_name = "x" + x_max = pt.argmax(x, axis=-1) + x_max_value = pt.vector("x_max_value") + + with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")): + x_max_logprob = logp(x_max, x_max_value) + + +def test_max_non_iid_fails(): + x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) + x_name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): + x_max_logprob = logp(x_max, x_max_value) + + +# def test_max_non_rv_fails(): +# x = pt.exp(pm.Beta.dist(1, 1)) +# x_name = "x" +# x_max = pt.max(x, axis=-1) +# x_max_value = pt.vector("x_max_value") +# # with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): +# x_max_logprob = logp(x_max, x_max_value) + + +def test_max_categorical(): + x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) + x_name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + # REASON for this failing is lines 71 - 74 and not the expected 89-91 which would pass? 
+ with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): + x_max_logprob = logp(x_max, x_max_value) + + +def test_max_logprob(): + x = pt.random.uniform(0, 1, size=(3,)) + x_name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + # pytensor.dprint(x_max_logprob) + beta_rv = pt.random.beta(0, 1, name="beta") + # pytensor.dprint(beta_rv) + + # assert np.isclose( + # expected_beta.eval(), + # x_max_logprob.eval({x_max_value: np.ones((3,)).max(axis= -1)}), + # ) + assert beta_rv in x_max_logprob From bccdd5e3d2c26b5f760900a70909fef4f8c60f8c Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 22 Jun 2023 17:07:02 +0530 Subject: [PATCH 05/18] Non RVS rejected --- pymc/logprob/order.py | 43 ++++++++++++++------------------- tests/logprob/test_order.py | 47 ++++++++++++++++++------------------- 2 files changed, 40 insertions(+), 50 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 8f1a6e37b1..1b792a2d06 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -44,7 +44,12 @@ from pytensor.tensor.math import Max from pytensor.tensor.random.op import RandomVariable -from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob +from pymc.logprob.abstract import ( + MeasurableVariable, + _logcdf_helper, + _logprob, + _logprob_helper, +) from pymc.logprob.rewriting import measurable_ir_rewrites_db @@ -66,6 +71,13 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas base_var = node.inputs[0] + if base_var.owner is None: + return None + + if not isinstance(base_var.owner.op, RandomVariable): + return None + + # univariate iid test which also rules out other distributions if isinstance(base_var.owner.op, RandomVariable): for op in base_var.owner.inputs[3:]: if op.type.ndim != 0: @@ -74,11 +86,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas if not 
rv_map_feature.request_measurable(node.inputs): return None - axis = node.op.axis - for x in range(base_var.ndim): - if x not in axis: - return None - axis = set(node.op.axis) base_var_dims = set(range(base_var.ndim)) if not axis.issubset(base_var_dims): @@ -113,26 +120,10 @@ def max_logprob(op, values, base_rv, **kwargs): """ (value,) = values - base_rv_op = base_rv.owner.op - base_rv_inputs = base_rv.owner.inputs - - logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs) - logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs) - - if base_rv_op.name: - logprob.name = f"{base_rv_op}_logprob" - logcdf.name = f"{base_rv_op}_logcdf" - - # size_var = base_rv.owner.inputs[1] - # string_size = str(size_var) - # for b in (0, len(string_size) - 1): - # if string_size[b] == "}": - # a = string_size[b - 1] - # try: - # n = int(a) - # except ValueError: - # return None - n = value.size + logprob = _logprob_helper(base_rv, value) + logcdf = _logcdf_helper(base_rv, value) + + n = base_rv.size logprob = (n - 1) * logcdf + logprob + pt.math.log(n) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 962241c858..ed10fcd530 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -37,7 +37,6 @@ import re import numpy as np -import pytensor import pytensor.tensor as pt import pytest import scipy.stats as st @@ -50,7 +49,7 @@ def test_max(): x = pt.random.normal(0, 1, size=(3,)) - x_name = "x" + x.name = "x" x_max = pt.max(x, axis=-1) x_max_value = pt.vector("x_max_value") x_max_logprob = logp(x_max, x_max_value) @@ -60,7 +59,7 @@ def test_max(): def test_argmax(): x = pt.random.normal(0, 1, size=(3,)) - x_name = "x" + x.name = "x" x_max = pt.argmax(x, axis=-1) x_max_value = pt.vector("x_max_value") @@ -70,44 +69,44 @@ def test_argmax(): def test_max_non_iid_fails(): x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) - x_name = "x" + x.name = "x" x_max = pt.max(x, axis=-1) x_max_value = pt.vector("x_max_value") with 
pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_max, x_max_value) -# def test_max_non_rv_fails(): -# x = pt.exp(pm.Beta.dist(1, 1)) -# x_name = "x" -# x_max = pt.max(x, axis=-1) -# x_max_value = pt.vector("x_max_value") -# # with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): -# x_max_logprob = logp(x_max, x_max_value) +def test_max_non_rv_fails(): + x = pt.exp(pt.random.normal(0, 1, size=(3,))) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): + x_max_logprob = logp(x_max, x_max_value) def test_max_categorical(): x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) - x_name = "x" + x.name = "x" x_max = pt.max(x, axis=-1) x_max_value = pt.vector("x_max_value") - # REASON for this failing is lines 71 - 74 and not the expected 89-91 which would pass? with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_max, x_max_value) def test_max_logprob(): x = pt.random.uniform(0, 1, size=(3,)) - x_name = "x" + x.name = "x" x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) - # pytensor.dprint(x_max_logprob) - beta_rv = pt.random.beta(0, 1, name="beta") - # pytensor.dprint(beta_rv) - - # assert np.isclose( - # expected_beta.eval(), - # x_max_logprob.eval({x_max_value: np.ones((3,)).max(axis= -1)}), - # ) - assert beta_rv in x_max_logprob + + test_value = 0.85 # or a vector of your choice + + beta_rv = pt.random.beta(3, 1, name="beta") + beta_vv = beta_rv.clone() + beta_rv_logprob = logp(beta_rv, beta_vv) + + np.testing.assert_allclose( + beta_rv_logprob.eval({beta_vv: test_value}), (x_max_logprob.eval({x_max_value: test_value})) + ) From a6cdf60dee6305b27be52fa7d5c8f83be2924eb1 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi 
Date: Fri, 23 Jun 2023 21:25:15 +0530 Subject: [PATCH 06/18] Add edge test cases --- pymc/logprob/order.py | 14 ++++++------- tests/logprob/test_order.py | 39 ++++++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 1b792a2d06..859c695185 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -74,27 +74,25 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas if base_var.owner is None: return None + # NonRVS must be rejected if not isinstance(base_var.owner.op, RandomVariable): return None # univariate iid test which also rules out other distributions if isinstance(base_var.owner.op, RandomVariable): - for op in base_var.owner.inputs[3:]: - if op.type.ndim != 0: + for params in base_var.owner.inputs[3:]: + if params.type.ndim != 0: return None if not rv_map_feature.request_measurable(node.inputs): return None + # Check whether axis is supported or not axis = set(node.op.axis) base_var_dims = set(range(base_var.ndim)) - if not axis.issubset(base_var_dims): + if axis != base_var_dims: return None - for ndim_param, param in zip(base_var.owner.op.ndims_params, base_var.owner.inputs[3:]): - if param.type.ndim != ndim_param: - return None - measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs @@ -116,7 +114,7 @@ def max_logprob(op, values, base_rv, **kwargs): The formula that we use here is : \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(F(x)) + \ln(f(x)) - where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively. + where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively. 
""" (value,) = values diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index ed10fcd530..81e62d7626 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -48,6 +48,7 @@ def test_max(): + """Test whether the logprob for ```pt.max``` is implemented""" x = pt.random.normal(0, 1, size=(3,)) x.name = "x" x_max = pt.max(x, axis=-1) @@ -57,7 +58,19 @@ def test_max(): assert_no_rvs(x_max_logprob) +def test_axis_max(): + """Test whether the rewrite takes into account ```None``` axis""" + x = pt.random.normal(0, 1) + x.name = "x" + x_max = pt.max(x, axis=None) + x_max_value = pt.vector("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + + assert_no_rvs(x_max_logprob) + + def test_argmax(): + """Test whether the logprob for ```pt.argmax``` is rejected correctly""" x = pt.random.normal(0, 1, size=(3,)) x.name = "x" x_max = pt.argmax(x, axis=-1) @@ -68,6 +81,7 @@ def test_argmax(): def test_max_non_iid_fails(): + """Test whether the logprob for ```pt.max``` for non i.i.d is rejected correctly""" x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) x.name = "x" x_max = pt.max(x, axis=-1) @@ -77,7 +91,8 @@ def test_max_non_iid_fails(): def test_max_non_rv_fails(): - x = pt.exp(pt.random.normal(0, 1, size=(3,))) + """Test whether the logprob for ```pt.max``` for non RVs is rejected correctly""" + x = pt.exp(pt.random.beta(0, 1, size=(3,))) x.name = "x" x_max = pt.max(x, axis=-1) x_max_value = pt.vector("x_max_value") @@ -86,6 +101,7 @@ def test_max_non_rv_fails(): def test_max_categorical(): + """Test whether the logprob for ```pt.max``` for unsupported distributions is rejected correctly""" x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) x.name = "x" x_max = pt.max(x, axis=-1) @@ -94,7 +110,28 @@ def test_max_categorical(): x_max_logprob = logp(x_max, x_max_value) +def test_non_supp_axis_max(): + """Test whether the logprob for ```pt.max``` for unsupported axis is rejected correctly""" + x = pt.random.normal(0, 1, size=(3, 3)) 
+ x.name = "x" + x_max_none = pt.max(x, axis=None) + x_max_value = pt.vector("x_max_value") + x_max_logprob = logp(x_max_none, x_max_value) + + x_max = pt.max(x, axis=-1) + with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): + x_max_logprob = logp(x_max, x_max_value) + + assert_no_rvs(x_max_logprob) + + def test_max_logprob(): + """Test whether the logprob for ```pt.max``` produces the correct result + + The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here: + U_1, \\dots, U_n \\stackrel{\\text{i.i.d.}}{\\sim} \\text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \\text{Beta}(k, n + 1- k) + for all 1<=k<=n + """ x = pt.random.uniform(0, 1, size=(3,)) x.name = "x" x_max = pt.max(x, axis=-1) From 616acb41d43f23fb639b6bd2146bcf011e906ded Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 29 Jun 2023 22:35:15 +0530 Subject: [PATCH 07/18] Return type rectified --- pymc/logprob/order.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 859c695185..dd2d09758b 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -43,6 +43,7 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.tensor.math import Max from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.var import TensorVariable from pymc.logprob.abstract import ( MeasurableVariable, @@ -61,7 +62,7 @@ class MeasurableMax(Max): @node_rewriter([Max]) -def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMax]]: +def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover From 1c41c3204b38b0f0bb919950dd1227bc4524b02f Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 3 Jul 2023 02:43:54 +0530 Subject: [PATCH 08/18] Doc updated --- docs/source/api/logprob.rst | 8 +++ 
docs/source/api/order_stats.rst | 41 ++++++++++++++++ pymc/logprob/order.py | 87 +++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 docs/source/api/order_stats.rst diff --git a/docs/source/api/logprob.rst b/docs/source/api/logprob.rst index d337c65c61..e4eda87f5f 100644 --- a/docs/source/api/logprob.rst +++ b/docs/source/api/logprob.rst @@ -21,3 +21,11 @@ Conditional probability conditional_logp transformed_conditional_logp + +Logarithmic probability +----------------------- + +.. toctree:: + :maxdepth: 2 + + order_stats diff --git a/docs/source/api/order_stats.rst b/docs/source/api/order_stats.rst new file mode 100644 index 0000000000..c520361b25 --- /dev/null +++ b/docs/source/api/order_stats.rst @@ -0,0 +1,41 @@ +================ +Order_Statistics +================ + +------------ +Introduction +------------ + +Users can derive the nth Order Statistic using PyMC for their custom distributions and the logarithmic probability related to them. + +In PyMC users can derive their own custom distributions. Custom distribution refers to the ability to define and use probability distributions that are not included in the standard set of distributions provided. +While PyMC provides a wide range of common probability distributions (e.g., Normal, Bernoulli, etc.), there may be cases where you need to use a distribution that is not available by default. In such cases, you can create your own custom distribution using the pm.CustomDist class provided by PyMC. +The simplest way to define a custom distribution is shown in the following example: + +.. 
code-block:: python + + import numpy as np + import pymc as pm + from pytensor.tensor import TensorVariable + + def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: + return -(value - mu)**2 + + with pm.Model(): + mu = pm.Normal('mu',0,1) + pm.CustomDist( + 'custom_dist', + mu, + logp=logp, + observed=np.random.randn(100), + ) + idata = pm.sample(100) + +Here, we create a CustomDist that wraps a black-box logp function. This variable cannot be used in prior or posterior predictive sampling because no random function was provided. + +------------------------ +`Max` +------------------------ +Using PyMC and Pytensor, users can extract the maximum of a distribution and derive the log-probablity corresponding to this operation. + +.. autofunction:: pymc.logprob.order.max_logprob diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index dd2d09758b..c666f8509a 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -113,9 +113,96 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens def max_logprob(op, values, base_rv, **kwargs): r"""Compute the log-likelihood graph for the `Max` operation. + Parameters + ---------- + op : Max-Op + values : tensor_like + rv : TensorVariable + + Returns + ------- + logprob : TensorVariable + + Examples + -------- + It is often desirable to find the Maximum from the distribution of random variables. + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.random.normal(0, 1, size=(3,)) + x.name = "x" + print(x.eval()) + #[0.61748772 1.08723759 0.98970957] + + x_max = pt.max(x, axis=None) + print(x_max.eval()) + # 1.087237592696084 + + It is only but natural that one might expect to derive the logarithmic probability corresponding to the Max operation. + The formula that we use here is : \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(F(x)) + \ln(f(x)) where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively. 
+ + An example corresponding to this is illustrated below: + + .. code-block:: python + + import pytensor.tensor as pt + from pymc import logp + + x = pt.random.uniform(0, 1, size=(3,)) + x.name = "x" + # [0.09081509 0.84761712 0.59030273] + + x_max = pt.max(x, axis=-1) + # 0.8476171198716373 + + x_max_value = pt.scalar("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + test_value = x_max.eval() + + x_max_logprob.eval({x_max_value: test_value}) + # 0.7679597791946853 + + Currently our implementation has certain limitations which are mandated through some constraints. + + We only consider a distribution of RandomVariables and the logp function fails for NonRVs. + + .. code-block:: python + + import pytensor.tensor as pt + from pymc import logp + + x = pt.exp(pt.random.beta(0, 1, size=(3,))) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + + The above code gives a Runtime error stating logprob method was not implemented as x in this case is a Non random variable distribution. + + We only consider independent and identically distributed random variables. + In probability theory and statistics, a collection of random variables is independent and identically distributed if each random variable has the same probability distribution as the others and all are mutually independent. + Hence the logp method fails for non-ids. + + .. code-block:: python + + import pytensor.tensor as pt + from pymc import logp + + x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + + The above code gives a Runtime error stating logprob method was not implemented as x in this case is a Non-iid distribution. 
+ + Note: We assume a very fluid definition of iid here.We assume only univariate distributions to be iids which rejects any multivariate distribution even though it might be iid by definition. + """ (value,) = values From d77be51483b2f4efecd639a5b9c228415f133a52 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 3 Jul 2023 20:04:08 +0530 Subject: [PATCH 09/18] Documentation changes --- docs/source/api/order_stats.rst | 3 ++- pymc/logprob/order.py | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/docs/source/api/order_stats.rst b/docs/source/api/order_stats.rst index c520361b25..fdd89bc299 100644 --- a/docs/source/api/order_stats.rst +++ b/docs/source/api/order_stats.rst @@ -6,7 +6,8 @@ Order_Statistics Introduction ------------ -Users can derive the nth Order Statistic using PyMC for their custom DIstributions and the logarithmic probablity related to them. +In statistics, the kth order statistic of a statistical sample is equal to its kth-smallest value. +In this section, we'll tackle how users can find the Logarithmic probability corresponding to the nth order statistic (maximum value) using PyMC for their own Custom distributions. In PyMC users can derive their own custom distributions. Custom distribution refers to the ability to define and use probability distributions that are not included in the standard set of distributions provided. While PyMC provides a wide range of common probability distributions (e.g., Normal, Bernoulli, etc.), there may be cases where you need to use a distribution that is not available by default. In such cases, you can create your own custom distribution using the pm.DensityDist class provided by PyMC. 
diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index c666f8509a..35a768f4af 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -55,7 +55,7 @@ class MeasurableMax(Max): - """A placeholder used to specify a log-likelihood for a cmax sub-graph.""" + """A placeholder used to specify a log-likelihood for a max sub-graph.""" MeasurableVariable.register(MeasurableMax) @@ -79,7 +79,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if not isinstance(base_var.owner.op, RandomVariable): return None - # univariate iid test which also rules out other distributions + # univariate i.i.d. test which also rules out other distributions if isinstance(base_var.owner.op, RandomVariable): for params in base_var.owner.inputs[3:]: if params.type.ndim != 0: @@ -125,7 +125,10 @@ def max_logprob(op, values, base_rv, **kwargs): Examples -------- - It is often desirable to find the Maximum from the distribution of random variables. + It is often desirable to find the log-probability of the maximum of i.i.d. random variables. + + The "max of i.i.d. random variables" refers to finding the maximum value among a collection of random variables that are independent and identically distributed. + The example below illustrates how to find the Maximum from the distribution of random variables. .. code-block:: python @@ -140,7 +143,7 @@ def max_logprob(op, values, base_rv, **kwargs): print(x_max.eval()) # 1.087237592696084 - It is only but natural that one might expect to derive the logarithmic probability corresponding to the Max operation. + The log-probability of the maximum of i.i.d. random variables is a measure of the likelihood of observing a specific maximum value in a set of independent and identically distributed random variables. 
The formula that we use here is : \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(F(x)) + \ln(f(x)) @@ -182,11 +185,13 @@ def max_logprob(op, values, base_rv, **kwargs): x_max_value = pt.vector("x_max_value") x_max_logprob = logp(x_max, x_max_value) - The above code gives a Runtime error stating logprob method was not implemented as x in this case is a Non random variable distribution. + The above code gives a Runtime error stating logprob method was not implemented as x in this case is not a pure random variable. + A pure random variable in PyMC represents an unknown quantity in a Bayesian model and is associated with a prior distribution that is combined with the likelihood of observed data to obtain the posterior distribution through Bayesian inference + + We assume only univariate distributions as for multivariate variables, the concept of ordering is ambiguous since a "depth function" is required . - We only consider independent and identically distributed random variables. + We only consider independent and identically distributed random variables, for now. In probability theory and statistics, a collection of random variables is independent and identically distributed if each random variable has the same probability distribution as the others and all are mutually independent. - Hence the logp method fails for non-ids. .. code-block:: python @@ -201,7 +206,8 @@ def max_logprob(op, values, base_rv, **kwargs): The above code gives a Runtime error stating logprob method was not implemented as x in this case is a Non-iid distribution. - Note: We assume a very fluid definition of iid here.We assume only univariate distributions to be iids which rejects any multivariate distribution even though it might be iid by definition. + Note: We assume a very fluid definition of i.i.d. here. We say that an RV belongs to an i.i.d. if that RVs do not have different stochastic ancestors. 
+ """ (value,) = values From 5d9da436833f52745800f0d5c43a568a00855351 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 4 Jul 2023 14:55:05 +0530 Subject: [PATCH 10/18] Final max commit --- pymc/logprob/order.py | 100 +----------------------------------------- 1 file changed, 1 insertion(+), 99 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 35a768f4af..7a5a8add26 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -111,105 +111,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens @_logprob.register(MeasurableMax) def max_logprob(op, values, base_rv, **kwargs): - r"""Compute the log-likelihood graph for the `Max` operation. - - Parameters - ---------- - op : Max-Op - values : tensor_like - rv : TensorVariable - - Returns - ------- - logprob : TensorVariable - - Examples - -------- - It is often desirable to find the log-probability of the maximum of i.i.d. random variables. - - The "max of i.i.d. random variables" refers to finding the maximum value among a collection of random variables that are independent and identically distributed. - The example below illustrates how to find the Maximum from the distribution of random variables. - - .. code-block:: python - - import pytensor.tensor as pt - - x = pt.random.normal(0, 1, size=(3,)) - x.name = "x" - print(x.eval()) - #[0.61748772 1.08723759 0.98970957] - - x_max = pt.max(x, axis=None) - print(x_max.eval()) - # 1.087237592696084 - - The log-probability of the maximum of i.i.d. random variables is a measure of the likelihood of observing a specific maximum value in a set of independent and identically distributed random variables. - - The formula that we use here is : - \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(F(x)) + \ln(f(x)) - where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively. - - An example corresponding to this is illustrated below: - - .. 
code-block:: python - - import pytensor.tensor as pt - from pymc import logp - - x = pt.random.uniform(0, 1, size=(3,)) - x.name = "x" - # [0.09081509 0.84761712 0.59030273] - - x_max = pt.max(x, axis=-1) - # 0.8476171198716373 - - x_max_value = pt.scalar("x_max_value") - x_max_logprob = logp(x_max, x_max_value) - test_value = x_max.eval() - - x_max_logprob.eval({x_max_value: test_value}) - # 0.7679597791946853 - - Currently our implementation has certain limitations which are mandated through some constraints. - - We only consider a distribution of RandomVariables and the logp function fails for NonRVs. - - .. code-block:: python - - import pytensor.tensor as pt - from pymc import logp - - x = pt.exp(pt.random.beta(0, 1, size=(3,))) - x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") - x_max_logprob = logp(x_max, x_max_value) - - The above code gives a Runtime error stating logprob method was not implemented as x in this case is not a pure random variable. - A pure random variable in PyMC represents an unknown quantity in a Bayesian model and is associated with a prior distribution that is combined with the likelihood of observed data to obtain the posterior distribution through Bayesian inference - - We assume only univariate distributions as for multivariate variables, the concept of ordering is ambiguous since a "depth function" is required . - - We only consider independent and identically distributed random variables, for now. - In probability theory and statistics, a collection of random variables is independent and identically distributed if each random variable has the same probability distribution as the others and all are mutually independent. - - .. 
code-block:: python - - import pytensor.tensor as pt - from pymc import logp - - x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) - x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") - x_max_logprob = logp(x_max, x_max_value) - - The above code gives a Runtime error stating logprob method was not implemented as x in this case is a Non-iid distribution. - - Note: We assume a very fluid definition of i.i.d. here. We say that an RV belongs to an i.i.d. if that RVs do not have different stochastic ancestors. - - - """ + r"""Compute the log-likelihood graph for the `Max` operation.""" (value,) = values logprob = _logprob_helper(base_rv, value) From 698d818c91f4ebc56eadbf8b7a060f1ae12d6b47 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 4 Jul 2023 17:54:26 +0530 Subject: [PATCH 11/18] Deriving Logprob for max --- docs/source/api/logprob.rst | 8 ------- docs/source/api/order_stats.rst | 42 --------------------------------- tests/logprob/test_order.py | 21 ++++++++--------- 3 files changed, 10 insertions(+), 61 deletions(-) delete mode 100644 docs/source/api/order_stats.rst diff --git a/docs/source/api/logprob.rst b/docs/source/api/logprob.rst index e4eda87f5f..d337c65c61 100644 --- a/docs/source/api/logprob.rst +++ b/docs/source/api/logprob.rst @@ -21,11 +21,3 @@ Conditional probability conditional_logp transformed_conditional_logp - -Logarithmic probability ------------------------ - -.. toctree:: - :maxdepth: 2 - - order_stats diff --git a/docs/source/api/order_stats.rst b/docs/source/api/order_stats.rst deleted file mode 100644 index fdd89bc299..0000000000 --- a/docs/source/api/order_stats.rst +++ /dev/null @@ -1,42 +0,0 @@ -================ -Order_Statistics -================ - ------------- -Introduction ------------- - -In statistics, the kth order statistic of a statistical sample is equal to its kth-smallest value. 
-In this section, we'll tackle how users can find the Logarithmic probability corresponding to the nth order statistic (maximum value) using PyMC for their own Custom distributions. - -In PyMC users can derive their own custom distributions. Custom distribution refers to the ability to define and use probability distributions that are not included in the standard set of distributions provided. -While PyMC provides a wide range of common probability distributions (e.g., Normal, Bernoulli, etc.), there may be cases where you need to use a distribution that is not available by default. In such cases, you can create your own custom distribution using the pm.DensityDist class provided by PyMC. -Simplest way to define a Custom Distribution can be better understood from the following example: - -.. code-block:: python - - import numpy as np - import pymc as pm - from pytensor.tensor import TensorVariable - - def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: - return -(value - mu)**2 - - with pm.Model(): - mu = pm.Normal('mu',0,1) - pm.CustomDist( - 'custom_dist', - mu, - logp=logp, - observed=np.random.randn(100), - ) - idata = pm.sample(100) - -Here, we create a CustomDist that wraps a black-box logp function. This variable cannot be used in prior or posterior predictive sampling because no random function was provided. - ------------------------- -`Max` ------------------------- -Using PyMC and Pytensor, users can extract the maximum of a distribution and derive the log-probablity corresponding to this operation. - -.. 
autofunction:: pymc.logprob.order.max_logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 81e62d7626..2986d8043f 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -39,7 +39,6 @@ import numpy as np import pytensor.tensor as pt import pytest -import scipy.stats as st import pymc as pm @@ -58,11 +57,16 @@ def test_max(): assert_no_rvs(x_max_logprob) -def test_axis_max(): +@pytest.mark.parametrize( + "x, axis", + [ + (pt.random.normal(0, 1), -1), + (pt.random.normal(0, 1), None), + ], +) +def test_axis_max(x, axis): """Test whether the rewrite takes into account ```None``` axis""" - x = pt.random.normal(0, 1) - x.name = "x" - x_max = pt.max(x, axis=None) + x_max = pt.max(x, axis) x_max_value = pt.vector("x_max_value") x_max_logprob = logp(x_max, x_max_value) @@ -114,16 +118,11 @@ def test_non_supp_axis_max(): """Test whether the logprob for ```pt.max``` for unsupported axis is rejected correctly""" x = pt.random.normal(0, 1, size=(3, 3)) x.name = "x" - x_max_none = pt.max(x, axis=None) - x_max_value = pt.vector("x_max_value") - x_max_logprob = logp(x_max_none, x_max_value) - x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_max, x_max_value) - assert_no_rvs(x_max_logprob) - def test_max_logprob(): """Test whether the logprob for ```pt.max``` produces the corrected From b16a0f3a17bfd77079c58b924f6d232e7f6f81b4 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 9 Jul 2023 23:28:02 +0530 Subject: [PATCH 12/18] Include suggestions --- pymc/logprob/order.py | 13 +++++------ tests/logprob/test_order.py | 44 ++++++++++++++++++------------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 7a5a8add26..210644fbe3 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -75,18 +75,17 @@ def 
find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if base_var.owner is None: return None - # NonRVS must be rejected + # Non-RVS must be rejected if not isinstance(base_var.owner.op, RandomVariable): return None # univariate i.i.d. test which also rules out other distributions - if isinstance(base_var.owner.op, RandomVariable): - for params in base_var.owner.inputs[3:]: - if params.type.ndim != 0: - return None + for params in base_var.owner.inputs[3:]: + if params.type.ndim != 0: + return None - if not rv_map_feature.request_measurable(node.inputs): - return None + # if not rv_map_feature.request_measurable(node.inputs): + # return None # Check whether axis is supported or not axis = set(node.op.axis) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 2986d8043f..9f3bc89a93 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -43,20 +43,10 @@ import pymc as pm from pymc import logp +from pymc.logprob import conditional_logp from pymc.testing import assert_no_rvs -def test_max(): - """Test whether the logprob for ```pt.max``` is implemented""" - x = pt.random.normal(0, 1, size=(3,)) - x.name = "x" - x_max = pt.max(x, axis=-1) - x_max_value = pt.vector("x_max_value") - x_max_logprob = logp(x_max, x_max_value) - - assert_no_rvs(x_max_logprob) - - @pytest.mark.parametrize( "x, axis", [ @@ -70,11 +60,9 @@ def test_axis_max(x, axis): x_max_value = pt.vector("x_max_value") x_max_logprob = logp(x_max, x_max_value) - assert_no_rvs(x_max_logprob) - def test_argmax(): - """Test whether the logprob for ```pt.argmax``` is rejected correctly""" + """Test whether the logprob for ```pt.argmax``` is correctly rejected""" x = pt.random.normal(0, 1, size=(3,)) x.name = "x" x_max = pt.argmax(x, axis=-1) @@ -85,7 +73,7 @@ def test_argmax(): def test_max_non_iid_fails(): - """Test whether the logprob for ```pt.max``` for non i.i.d is rejected correctly""" + """Test whether the logprob for ```pt.max``` for non 
i.i.d is correctly rejected""" x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) x.name = "x" x_max = pt.max(x, axis=-1) @@ -95,7 +83,7 @@ def test_max_non_iid_fails(): def test_max_non_rv_fails(): - """Test whether the logprob for ```pt.max``` for non RVs is rejected correctly""" + """Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected""" x = pt.exp(pt.random.beta(0, 1, size=(3,))) x.name = "x" x_max = pt.max(x, axis=-1) @@ -105,7 +93,7 @@ def test_max_non_rv_fails(): def test_max_categorical(): - """Test whether the logprob for ```pt.max``` for unsupported distributions is rejected correctly""" + """Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected""" x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) x.name = "x" x_max = pt.max(x, axis=-1) @@ -115,7 +103,7 @@ def test_max_categorical(): def test_non_supp_axis_max(): - """Test whether the logprob for ```pt.max``` for unsupported axis is rejected correctly""" + """Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected""" x = pt.random.normal(0, 1, size=(3, 3)) x.name = "x" x_max = pt.max(x, axis=-1) @@ -124,22 +112,34 @@ def test_non_supp_axis_max(): x_max_logprob = logp(x_max, x_max_value) -def test_max_logprob(): +@pytest.mark.parametrize( + "n, value", + [ + (3, 0.85), + (3, 0.01), + (2, 0.2), + (4, 0.5), + (11, 0.9), # interestingly this fails + ], +) +def test_max_logprob(n, value): """Test whether the logprob for ```pt.max``` produces the corrected The fact that order statistics of i.i.d. 
uniform RVs ~ Beta is used here: U_1, \\dots, U_n \\stackrel{\text{i.i.d.}}{\\sim} \text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \text{Beta}(k, n + 1- k) for all 1<=k<=n """ - x = pt.random.uniform(0, 1, size=(3,)) + x = pt.random.uniform(0, 1, size=(n,)) x.name = "x" x_max = pt.max(x, axis=-1) x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) - test_value = 0.85 # or a vector of your choice + assert_no_rvs(x_max_logprob) + + test_value = value - beta_rv = pt.random.beta(3, 1, name="beta") + beta_rv = pt.random.beta(n, 1, name="beta") beta_vv = beta_rv.clone() beta_rv_logprob = logp(beta_rv, beta_vv) From 7332b678d5f87068941cc4a0359a365fcd82abe4 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 10 Jul 2023 20:11:14 +0530 Subject: [PATCH 13/18] Logprob for max --- pymc/logprob/order.py | 12 ++++++++---- tests/logprob/test_order.py | 38 ++++++++++++++----------------------- 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 210644fbe3..97cd799fb5 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -75,19 +75,23 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if base_var.owner is None: return None + if not rv_map_feature.request_measurable(node.inputs): + return None + # Non-RVS must be rejected if not isinstance(base_var.owner.op, RandomVariable): return None + # TODO: We are currently only supporting continuous rvs + if base_var.owner.op.dtype.startswith("int"): + return None + # univariate i.i.d. 
test which also rules out other distributions for params in base_var.owner.inputs[3:]: if params.type.ndim != 0: return None - # if not rv_map_feature.request_measurable(node.inputs): - # return None - - # Check whether axis is supported or not + # Check whether axis covers all dimensions axis = set(node.op.axis) base_var_dims = set(range(base_var.ndim)) if axis != base_var_dims: diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 9f3bc89a93..869f113f9e 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -47,20 +47,6 @@ from pymc.testing import assert_no_rvs -@pytest.mark.parametrize( - "x, axis", - [ - (pt.random.normal(0, 1), -1), - (pt.random.normal(0, 1), None), - ], -) -def test_axis_max(x, axis): - """Test whether the rewrite takes into account ```None``` axis""" - x_max = pt.max(x, axis) - x_max_value = pt.vector("x_max_value") - x_max_logprob = logp(x_max, x_max_value) - - def test_argmax(): """Test whether the logprob for ```pt.argmax``` is correctly rejected""" x = pt.random.normal(0, 1, size=(3,)) @@ -113,25 +99,26 @@ def test_non_supp_axis_max(): @pytest.mark.parametrize( - "n, value", + "shape, value, axis", [ - (3, 0.85), - (3, 0.01), - (2, 0.2), - (4, 0.5), - (11, 0.9), # interestingly this fails + (3, 0.85, -1), + (3, 0.01, 0), + (2, 0.2, None), + (4, 0.5, 0), + ((3, 4), 0.9, None), + ((3, 4), 0.75, (1, 0)), ], ) -def test_max_logprob(n, value): +def test_max_logprob(shape, value, axis): """Test whether the logprob for ```pt.max``` produces the corrected The fact that order statistics of i.i.d. 
uniform RVs ~ Beta is used here: U_1, \\dots, U_n \\stackrel{\text{i.i.d.}}{\\sim} \text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \text{Beta}(k, n + 1- k) for all 1<=k<=n """ - x = pt.random.uniform(0, 1, size=(n,)) + x = pt.random.uniform(0, 1, size=shape) x.name = "x" - x_max = pt.max(x, axis=-1) + x_max = pt.max(x, axis=axis) x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) @@ -139,10 +126,13 @@ def test_max_logprob(n, value): test_value = value + n = np.prod(shape) beta_rv = pt.random.beta(n, 1, name="beta") beta_vv = beta_rv.clone() beta_rv_logprob = logp(beta_rv, beta_vv) np.testing.assert_allclose( - beta_rv_logprob.eval({beta_vv: test_value}), (x_max_logprob.eval({x_max_value: test_value})) + beta_rv_logprob.eval({beta_vv: test_value}), + (x_max_logprob.eval({x_max_value: test_value})), + rtol=1e-06, ) From f185c8a43699d1883bc600510db799e96d9c73c0 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 15 Jul 2023 16:13:57 +0530 Subject: [PATCH 14/18] Derive logprob for max --- pymc/logprob/order.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 97cd799fb5..03f99d7e62 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -78,8 +78,11 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if not rv_map_feature.request_measurable(node.inputs): return None - # Non-RVS must be rejected - if not isinstance(base_var.owner.op, RandomVariable): + # Non-univariate non-RVs must be rejected + if ( + not isinstance(base_var.owner.op, RandomVariable) + and base_var.owner.inputs[0].owner.op.ndim_supp == 0 + ): return None # TODO: We are currently only supporting continuous rvs From 356340b2e4f08c631b7314ddccb385af84b66382 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 16 Jul 2023 04:51:28 +0530 Subject: [PATCH 15/18] Deriving logprob for max --- pymc/logprob/order.py | 4 ++-- 1 file changed, 2 insertions(+), 
2 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 03f99d7e62..b5ee5c3ca6 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -81,12 +81,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens # Non-univariate non-RVs must be rejected if ( not isinstance(base_var.owner.op, RandomVariable) - and base_var.owner.inputs[0].owner.op.ndim_supp == 0 + and base_var.owner.inputs[0].owner.op.ndim_supp != 0 ): return None # TODO: We are currently only supporting continuous rvs - if base_var.owner.op.dtype.startswith("int"): + if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): return None # univariate i.i.d. test which also rules out other distributions From 816fe4ad2ac1e1253ca2e4c77bc61a62c7ca96c5 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 17 Jul 2023 19:24:48 +0530 Subject: [PATCH 16/18] Reject multivarate and nonrvs for logp of max --- pymc/logprob/order.py | 7 ++----- tests/logprob/test_order.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index b5ee5c3ca6..4033bf674c 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -78,11 +78,8 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if not rv_map_feature.request_measurable(node.inputs): return None - # Non-univariate non-RVs must be rejected - if ( - not isinstance(base_var.owner.op, RandomVariable) - and base_var.owner.inputs[0].owner.op.ndim_supp != 0 - ): + # Non-univariate distributions and non-RVs must be rejected + if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0): return None # TODO: We are currently only supporting continuous rvs diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 869f113f9e..5a3818716d 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -78,6 +78,17 
@@ def test_max_non_rv_fails(): x_max_logprob = logp(x_max, x_max_value) +def test_max_multivariate_rv_fails(): + _alpha = pt.scalar() + _k = pt.iscalar() + x = pm.StickBreakingWeights.dist(_alpha, _k) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.vector("x_max_value") + with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): + x_max_logprob = logp(x_max, x_max_value) + + def test_max_categorical(): """Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected""" x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) From aebe469475523f6c55a03dcf037c101f1d30b8fb Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 4 Jul 2023 14:46:32 +0530 Subject: [PATCH 17/18] Guidelines --- tests/logprob/test_checks.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index 7356bd0bb1..a4e72cda61 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -33,7 +33,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import re import numpy as np import pytensor @@ -44,7 +43,7 @@ from scipy import stats from pymc.distributions import Dirichlet -from pymc.logprob.basic import conditional_logp +from pymc.logprob.joint_logprob import factorized_joint_logprob from tests.distributions.test_multivariate import dirichlet_logpdf @@ -59,7 +58,7 @@ def test_specify_shape_logprob(): # 2. Request logp x_vv = x_rv.clone() - [x_logp] = conditional_logp({x_rv: x_vv}).values() + [x_logp] = factorized_joint_logprob({x_rv: x_vv}).values() # 3. 
Test logp x_logp_fn = pytensor.function([last_dim, x_vv], x_logp) @@ -81,19 +80,17 @@ def test_assert_logprob(): rv = pt.random.normal() assert_op = Assert("Test assert") # Example: Add assert that rv must be positive - assert_rv = assert_op(rv, rv > 0) + assert_rv = assert_op(rv > 0, rv) assert_rv.name = "assert_rv" assert_vv = assert_rv.clone() - assert_logp = conditional_logp({assert_rv: assert_vv})[assert_vv] + assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv] # Check valid value is correct and doesn't raise # Since here the value to the rv satisfies the condition, no error is raised. valid_value = 3.0 - np.testing.assert_allclose( - assert_logp.eval({assert_vv: valid_value}), - stats.norm.logpdf(valid_value), - ) + with pytest.raises(AssertionError, match="Test assert"): + assert_logp.eval({assert_vv: valid_value}) # Check invalid value # Since here the value to the rv is negative, an exception is raised as the condition is not met From 8bf01546cf41240b1d4e6e3ab79ba7ec99c8aa2d Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Wed, 26 Jul 2023 21:27:17 +0530 Subject: [PATCH 18/18] logprob for maximum derived --- tests/logprob/test_checks.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index a4e72cda61..7356bd0bb1 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -33,6 +33,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import re import numpy as np import pytensor @@ -43,7 +44,7 @@ from scipy import stats from pymc.distributions import Dirichlet -from pymc.logprob.joint_logprob import factorized_joint_logprob +from pymc.logprob.basic import conditional_logp from tests.distributions.test_multivariate import dirichlet_logpdf @@ -58,7 +59,7 @@ def test_specify_shape_logprob(): # 2. 
Request logp x_vv = x_rv.clone() - [x_logp] = factorized_joint_logprob({x_rv: x_vv}).values() + [x_logp] = conditional_logp({x_rv: x_vv}).values() # 3. Test logp x_logp_fn = pytensor.function([last_dim, x_vv], x_logp) @@ -80,17 +81,19 @@ def test_assert_logprob(): rv = pt.random.normal() assert_op = Assert("Test assert") # Example: Add assert that rv must be positive - assert_rv = assert_op(rv > 0, rv) + assert_rv = assert_op(rv, rv > 0) assert_rv.name = "assert_rv" assert_vv = assert_rv.clone() - assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv] + assert_logp = conditional_logp({assert_rv: assert_vv})[assert_vv] # Check valid value is correct and doesn't raise # Since here the value to the rv satisfies the condition, no error is raised. valid_value = 3.0 - with pytest.raises(AssertionError, match="Test assert"): - assert_logp.eval({assert_vv: valid_value}) + np.testing.assert_allclose( + assert_logp.eval({assert_vv: valid_value}), + stats.norm.logpdf(valid_value), + ) # Check invalid value # Since here the value to the rv is negative, an exception is raised as the condition is not met