Skip to content

Commit e3961fc

Browse files
Logprob derivation for Max (#6769)
1 parent 510d3b8 commit e3961fc

File tree

6 files changed

+282
-0
lines changed

6 files changed

+282
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ jobs:
110110
tests/logprob/test_composite_logprob.py
111111
tests/logprob/test_cumsum.py
112112
tests/logprob/test_mixture.py
113+
tests/logprob/test_order.py
113114
tests/logprob/test_rewriting.py
114115
tests/logprob/test_scan.py
115116
tests/logprob/test_tensor.py

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import pymc.logprob.cumsum
5050
import pymc.logprob.checks
5151
import pymc.logprob.mixture
52+
import pymc.logprob.order
5253
import pymc.logprob.scan
5354
import pymc.logprob.tensor
5455
import pymc.logprob.transforms

pymc/logprob/order.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
37+
from typing import List, Optional
38+
39+
import pytensor.tensor as pt
40+
41+
from pytensor.graph.basic import Node
42+
from pytensor.graph.fg import FunctionGraph
43+
from pytensor.graph.rewriting.basic import node_rewriter
44+
from pytensor.tensor.math import Max
45+
from pytensor.tensor.random.op import RandomVariable
46+
from pytensor.tensor.var import TensorVariable
47+
48+
from pymc.logprob.abstract import (
49+
MeasurableVariable,
50+
_logcdf_helper,
51+
_logprob,
52+
_logprob_helper,
53+
)
54+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
55+
56+
57+
class MeasurableMax(Max):
58+
"""A placeholder used to specify a log-likelihood for a max sub-graph."""
59+
60+
61+
MeasurableVariable.register(MeasurableMax)
62+
63+
64+
@node_rewriter([Max])
65+
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
66+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
67+
if rv_map_feature is None:
68+
return None # pragma: no cover
69+
70+
if isinstance(node.op, MeasurableMax):
71+
return None # pragma: no cover
72+
73+
base_var = node.inputs[0]
74+
75+
if base_var.owner is None:
76+
return None
77+
78+
if not rv_map_feature.request_measurable(node.inputs):
79+
return None
80+
81+
# Non-univariate distributions and non-RVs must be rejected
82+
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
83+
return None
84+
85+
# TODO: We are currently only supporting continuous rvs
86+
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
87+
return None
88+
89+
# univariate i.i.d. test which also rules out other distributions
90+
for params in base_var.owner.inputs[3:]:
91+
if params.type.ndim != 0:
92+
return None
93+
94+
# Check whether axis covers all dimensions
95+
axis = set(node.op.axis)
96+
base_var_dims = set(range(base_var.ndim))
97+
if axis != base_var_dims:
98+
return None
99+
100+
measurable_max = MeasurableMax(list(axis))
101+
max_rv_node = measurable_max.make_node(base_var)
102+
max_rv = max_rv_node.outputs
103+
104+
return max_rv
105+
106+
107+
measurable_ir_rewrites_db.register(
108+
"find_measurable_max",
109+
find_measurable_max,
110+
"basic",
111+
"max",
112+
)
113+
114+
115+
@_logprob.register(MeasurableMax)
116+
def max_logprob(op, values, base_rv, **kwargs):
117+
r"""Compute the log-likelihood graph for the `Max` operation."""
118+
(value,) = values
119+
120+
logprob = _logprob_helper(base_rv, value)
121+
logcdf = _logcdf_helper(base_rv, value)
122+
123+
n = base_rv.size
124+
125+
logprob = (n - 1) * logcdf + logprob + pt.math.log(n)
126+
127+
return logprob

pymc/logprob/rewriting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
EquilibriumGraphRewriter,
5858
GraphRewriter,
5959
node_rewriter,
60+
out2in,
6061
)
6162
from pytensor.graph.rewriting.db import (
6263
LocalGroupDB,
@@ -70,6 +71,7 @@
7071
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
7172
from pytensor.tensor.rewriting.basic import register_canonicalize
7273
from pytensor.tensor.rewriting.shape import ShapeFeature
74+
from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
7375
from pytensor.tensor.subtensor import (
7476
AdvancedIncSubtensor,
7577
AdvancedIncSubtensor1,
@@ -358,6 +360,7 @@ def incsubtensor_rv_replace(fgraph, node):
358360
logprob_rewrites_db = SequenceDB()
359361
logprob_rewrites_db.name = "logprob_rewrites_db"
360362
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
363+
logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")
361364

362365
# These rewrites convert un-measurable variables into their measurable forms,
363366
# but they need to be reapplied, because some of the measurable forms require

scripts/run_mypy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
pymc/logprob/censoring.py
3434
pymc/logprob/basic.py
3535
pymc/logprob/mixture.py
36+
pymc/logprob/order.py
3637
pymc/logprob/rewriting.py
3738
pymc/logprob/scan.py
3839
pymc/logprob/tensor.py

tests/logprob/test_order.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
37+
import re
38+
39+
import numpy as np
40+
import pytensor.tensor as pt
41+
import pytest
42+
43+
import pymc as pm
44+
45+
from pymc import logp
46+
from pymc.logprob import conditional_logp
47+
from pymc.testing import assert_no_rvs
48+
49+
50+
def test_argmax():
51+
"""Test whether the logprob for ```pt.argmax``` is correctly rejected"""
52+
x = pt.random.normal(0, 1, size=(3,))
53+
x.name = "x"
54+
x_max = pt.argmax(x, axis=-1)
55+
x_max_value = pt.vector("x_max_value")
56+
57+
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")):
58+
x_max_logprob = logp(x_max, x_max_value)
59+
60+
61+
def test_max_non_iid_fails():
62+
"""Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected"""
63+
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
64+
x.name = "x"
65+
x_max = pt.max(x, axis=-1)
66+
x_max_value = pt.vector("x_max_value")
67+
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
68+
x_max_logprob = logp(x_max, x_max_value)
69+
70+
71+
def test_max_non_rv_fails():
72+
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
73+
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
74+
x.name = "x"
75+
x_max = pt.max(x, axis=-1)
76+
x_max_value = pt.vector("x_max_value")
77+
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
78+
x_max_logprob = logp(x_max, x_max_value)
79+
80+
81+
def test_max_multivariate_rv_fails():
82+
_alpha = pt.scalar()
83+
_k = pt.iscalar()
84+
x = pm.StickBreakingWeights.dist(_alpha, _k)
85+
x.name = "x"
86+
x_max = pt.max(x, axis=-1)
87+
x_max_value = pt.vector("x_max_value")
88+
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
89+
x_max_logprob = logp(x_max, x_max_value)
90+
91+
92+
def test_max_categorical():
93+
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
94+
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
95+
x.name = "x"
96+
x_max = pt.max(x, axis=-1)
97+
x_max_value = pt.vector("x_max_value")
98+
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
99+
x_max_logprob = logp(x_max, x_max_value)
100+
101+
102+
def test_non_supp_axis_max():
103+
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
104+
x = pt.random.normal(0, 1, size=(3, 3))
105+
x.name = "x"
106+
x_max = pt.max(x, axis=-1)
107+
x_max_value = pt.vector("x_max_value")
108+
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
109+
x_max_logprob = logp(x_max, x_max_value)
110+
111+
112+
@pytest.mark.parametrize(
113+
"shape, value, axis",
114+
[
115+
(3, 0.85, -1),
116+
(3, 0.01, 0),
117+
(2, 0.2, None),
118+
(4, 0.5, 0),
119+
((3, 4), 0.9, None),
120+
((3, 4), 0.75, (1, 0)),
121+
],
122+
)
123+
def test_max_logprob(shape, value, axis):
124+
"""Test whether the logprob for ```pt.max``` produces the corrected
125+
126+
The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
127+
U_1, \\dots, U_n \\stackrel{\text{i.i.d.}}{\\sim} \text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \text{Beta}(k, n + 1- k)
128+
for all 1<=k<=n
129+
"""
130+
x = pt.random.uniform(0, 1, size=shape)
131+
x.name = "x"
132+
x_max = pt.max(x, axis=axis)
133+
x_max_value = pt.scalar("x_max_value")
134+
x_max_logprob = logp(x_max, x_max_value)
135+
136+
assert_no_rvs(x_max_logprob)
137+
138+
test_value = value
139+
140+
n = np.prod(shape)
141+
beta_rv = pt.random.beta(n, 1, name="beta")
142+
beta_vv = beta_rv.clone()
143+
beta_rv_logprob = logp(beta_rv, beta_vv)
144+
145+
np.testing.assert_allclose(
146+
beta_rv_logprob.eval({beta_vv: test_value}),
147+
(x_max_logprob.eval({x_max_value: test_value})),
148+
rtol=1e-06,
149+
)

0 commit comments

Comments
 (0)