Skip to content

Commit 6d2a289

Browse files
Logprob derivation for Min of continuous IID variables (#6846)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent f249f12 commit 6d2a289

File tree

2 files changed

+205
-23
lines changed

2 files changed

+205
-23
lines changed

pymc/logprob/order.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
from pytensor.graph.basic import Node
4242
from pytensor.graph.fg import FunctionGraph
4343
from pytensor.graph.rewriting.basic import node_rewriter
44+
from pytensor.scalar.basic import Mul
45+
from pytensor.tensor.basic import get_underlying_scalar_constant_value
46+
from pytensor.tensor.elemwise import Elemwise
47+
from pytensor.tensor.exceptions import NotScalarConstantError
4448
from pytensor.tensor.math import Max
4549
from pytensor.tensor.random.op import RandomVariable
4650
from pytensor.tensor.var import TensorVariable
@@ -122,7 +126,102 @@ def max_logprob(op, values, base_rv, **kwargs):
122126
logcdf = _logcdf_helper(base_rv, value)
123127

124128
[n] = constant_fold([base_rv.size])
125-
126129
logprob = (n - 1) * logcdf + logprob + pt.math.log(n)
127130

128131
return logprob
132+
133+
134+
class MeasurableMaxNeg(Max):
    """Placeholder `Op` marking a measurable ``max(neg(x))`` sub-graph.

    ``pt.min(x)`` is lowered to ``neg(max(neg(x)))``; this op tags the inner
    ``max(neg(x))`` so that a logprob can be dispatched for it.
    """


MeasurableVariable.register(MeasurableMaxNeg)
140+
141+
142+
@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
    """Rewrite a ``max(neg(x))`` sub-graph into a `MeasurableMaxNeg`.

    ``pt.min(x)`` is implemented as ``neg(max(x * -1))``; this rewrite detects
    the inner ``max`` over the negation of a univariate, i.i.d., continuous
    `RandomVariable` and replaces it with a `MeasurableMaxNeg` node so a
    logprob can be derived for the minimum.

    Returns ``None`` whenever the pattern does not apply, so other rewrites
    can still fire.
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    if isinstance(node.op, MeasurableMaxNeg):
        return None  # pragma: no cover

    base_var = node.inputs[0]

    if base_var.owner is None:
        return None

    if not rv_map_feature.request_measurable(node.inputs):
        return None

    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
    if not isinstance(base_var.owner.op, Elemwise):
        return None

    # negation is rv * (-1). Hence the scalar_op must be Mul
    try:
        if not (
            isinstance(base_var.owner.op.scalar_op, Mul)
            and len(base_var.owner.inputs) == 2
            and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
        ):
            return None
    except NotScalarConstantError:
        return None

    base_rv = base_var.owner.inputs[0]

    # Non-univariate distributions and non-RVs must be rejected.
    # `base_rv.owner` can be None (e.g. a constant multiplied by -1), in which
    # case there is nothing measurable here either.
    if base_rv.owner is None or not (
        isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0
    ):
        return None

    # TODO: We are currently only supporting continuous rvs
    if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
        return None

    # univariate i.i.d. test which also rules out other distributions
    # (inputs[3:] are the distribution parameters; scalar params => i.i.d. draws)
    for params in base_rv.owner.inputs[3:]:
        if params.type.ndim != 0:
            return None

    # Only a full reduction over all axes is supported
    axis = set(node.op.axis)
    base_var_dims = set(range(base_var.ndim))
    if axis != base_var_dims:
        return None

    measurable_min = MeasurableMaxNeg(list(axis))
    min_rv_node = measurable_min.make_node(base_var)
    min_rv = min_rv_node.outputs

    return min_rv
201+
202+
203+
# Register the min rewrite in the measurable IR database under the "basic" and "min" tags.
measurable_ir_rewrites_db.register("find_measurable_max_neg", find_measurable_max_neg, "basic", "min")
209+
210+
211+
@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_var, **kwargs):
    r"""Compute the log-likelihood graph for the `Min` operation (as ``max(neg(x))``).

    The density of the first order statistic (the minimum) of :math:`n`
    i.i.d. draws is used:

    .. math::
        \ln(f_{(1)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))

    where :math:`f(x)` is the p.d.f. and :math:`F(x)` the c.d.f. of the base
    distribution.
    """
    (value,) = values
    # `base_var` is `neg(base_rv)`; recover the underlying RV.
    base_rv = base_var.owner.inputs[0]

    # min(x) = -max(-x): evaluate the base density/CDF at the negated value.
    logprob = _logprob_helper(base_rv, -value)
    logcdf = _logcdf_helper(base_rv, -value)

    [n] = constant_fold([base_rv.size])
    # log1mexp(logcdf) == log(1 - exp(logcdf)) but numerically stable when
    # logcdf is close to 0 (i.e. CDF close to 1).
    logprob = (n - 1) * pt.math.log1mexp(logcdf) + logprob + pt.math.log(n)

    return logprob

tests/logprob/test_order.py

Lines changed: 105 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import pymc as pm
4444

4545
from pymc import logp
46-
from pymc.logprob import conditional_logp
4746
from pymc.testing import assert_no_rvs
4847

4948

@@ -58,55 +57,90 @@ def test_argmax():
5857
x_max_logprob = logp(x_max, x_max_value)
5958

6059

@pytest.mark.parametrize("pt_op", [pt.max, pt.min])
def test_non_iid_fails(pt_op):
    """The logprob of ``pt.max``/``pt.min`` of non-i.i.d. variables must be rejected."""
    x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
    x.name = "x"
    extremum = pt_op(x, axis=-1)
    extremum_vv = pt.vector("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(extremum, extremum_vv)
6975

7076

@pytest.mark.parametrize("pt_op", [pt.max, pt.min])
def test_non_rv_fails(pt_op):
    """The logprob of ``pt.max``/``pt.min`` of a non-RV (elemwise-transformed) graph must be rejected."""
    x = pt.exp(pt.random.beta(0, 1, size=(3,)))
    x.name = "x"
    extremum = pt_op(x, axis=-1)
    extremum_vv = pt.vector("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(extremum, extremum_vv)
7992

8093

@pytest.mark.parametrize("pt_op", [pt.max, pt.min])
def test_multivariate_rv_fails(pt_op):
    """The logprob of ``pt.max``/``pt.min`` of a multivariate RV must be rejected."""
    _alpha = pt.scalar()
    _k = pt.iscalar()
    x = pm.StickBreakingWeights.dist(_alpha, _k)
    x.name = "x"
    extremum = pt_op(x, axis=-1)
    extremum_vv = pt.vector("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(extremum, extremum_vv)
90110

91111

@pytest.mark.parametrize("pt_op", [pt.max, pt.min])
def test_categorical(pt_op):
    """The logprob of ``pt.max``/``pt.min`` of discrete (categorical) RVs must be rejected."""
    x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
    x.name = "x"
    extremum = pt_op(x, axis=-1)
    extremum_vv = pt.vector("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(extremum, extremum_vv)
100127

101128

@pytest.mark.parametrize("pt_op", [pt.max, pt.min])
def test_non_supp_axis(pt_op):
    """The logprob of ``pt.max``/``pt.min`` over a partial axis must be rejected."""
    x = pt.random.normal(0, 1, size=(3, 3))
    x.name = "x"
    extremum = pt_op(x, axis=-1)
    extremum_vv = pt.vector("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(extremum, extremum_vv)
110144

111145

112146
@pytest.mark.parametrize(
@@ -147,3 +181,52 @@ def test_max_logprob(shape, value, axis):
147181
(x_max_logprob.eval({x_max_value: test_value})),
148182
rtol=1e-06,
149183
)
184+
185+
186+
@pytest.mark.parametrize(
    "shape, value, axis",
    [
        (3, 0.85, -1),
        (3, 0.01, 0),
        (2, 0.2, None),
        (4, 0.5, 0),
        ((3, 4), 0.9, None),
        ((3, 4), 0.75, (1, 0)),
    ],
)
def test_min_logprob(shape, value, axis):
    r"""Test whether the logprob for ``pt.min`` produces the correct result.

    The fact that order statistics of i.i.d. uniform RVs are Beta-distributed
    is used here:

    .. math::
        U_1, \dots, U_n \stackrel{\text{i.i.d.}}{\sim} \text{Uniform}(0, 1)
        \Rightarrow U_{(k)} \sim \text{Beta}(k, n + 1 - k)

    for all :math:`1 \le k \le n`; the minimum is :math:`k = 1`, i.e.
    ``Beta(1, n)``.
    """
    x = pt.random.uniform(0, 1, size=shape)
    x.name = "x"
    x_min = pt.min(x, axis=axis)
    x_min_value = pt.scalar("x_min_value")
    x_min_logprob = logp(x_min, x_min_value)

    assert_no_rvs(x_min_logprob)

    test_value = value

    n = np.prod(shape)
    beta_rv = pt.random.beta(1, n, name="beta")
    beta_vv = beta_rv.clone()
    beta_rv_logprob = logp(beta_rv, beta_vv)

    np.testing.assert_allclose(
        beta_rv_logprob.eval({beta_vv: test_value}),
        (x_min_logprob.eval({x_min_value: test_value})),
        rtol=1e-06,
    )
223+
224+
225+
def test_min_non_mul_elemwise_fails():
    """The logprob of ``pt.min`` of a non-mul elemwise transform (here ``log``) must be rejected."""
    x = pt.log(pt.random.beta(0, 1, size=(3,)))
    x.name = "x"
    minimum = pt.min(x, axis=-1)
    minimum_vv = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(minimum, minimum_vv)

0 commit comments

Comments
 (0)