Skip to content

Commit b88fa7f

Browse files
Fixes in logprob derivation of min
1 parent 470d474 commit b88fa7f

File tree

2 files changed

+219
-23
lines changed

2 files changed

+219
-23
lines changed

pymc/logprob/order.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from pytensor.graph.basic import Constant, Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable
@@ -122,7 +124,94 @@ def max_logprob(op, values, base_rv, **kwargs):
122124
logcdf = _logcdf_helper(base_rv, value)
123125

124126
[n] = constant_fold([base_rv.size])
125-
126127
logprob = (n - 1) * logcdf + logprob + pt.math.log(n)
127128

128129
return logprob
130+
131+
132+
class MeasurableMin(Max):
    """A placeholder used to specify a log-likelihood for a min sub-graph."""


# Register with the dispatch machinery so MeasurableMin is treated as a
# measurable variable type (mirrors how MeasurableMax is registered above).
MeasurableVariable.register(MeasurableMin)
137+
138+
139+
@node_rewriter(tracks=[Max])
def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
    """Rewrite a ``Max`` of a negated i.i.d. univariate RV into ``MeasurableMin``.

    ``pt.min(x)`` is lowered to ``-max(-x)``; by the time this rewrite runs the
    negation is expected to appear as an ``Elemwise`` multiplication of the RV
    by the constant ``-1`` — TODO confirm this holds for every canonicalization
    path. Returns ``None`` when the sub-graph does not match.
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    # Avoid rewriting a node that was already rewritten.
    if isinstance(node.op, MeasurableMin):
        return None  # pragma: no cover

    base_var = node.inputs[0]

    if base_var.owner is None:
        return None

    if not rv_map_feature.request_measurable(node.inputs):
        return None

    # Min is the Max of the negation of the same distribution, so the input
    # to Max must be an Elemwise multiplication.
    if not isinstance(base_var.owner.op, Elemwise):
        return None

    # Negation is rv * (-1), hence the scalar_op must be Mul.
    if not isinstance(base_var.owner.op.scalar_op, Mul):
        return None

    # The multiplication must be exactly `rv * (-1)`: a single RV operand and a
    # constant -1. Anything else (e.g. `2 * rv`) is not a negation, and its Max
    # is not the Min of the underlying RV.
    mul_inputs = base_var.owner.inputs
    if len(mul_inputs) != 2:
        return None

    rv_candidates = [
        inp
        for inp in mul_inputs
        if inp.owner is not None and isinstance(inp.owner.op, RandomVariable)
    ]
    if len(rv_candidates) != 1:
        return None
    [base_rv] = rv_candidates
    [scale] = [inp for inp in mul_inputs if inp is not base_rv]
    if not (isinstance(scale, Constant) and bool((scale.data == -1).all())):
        return None

    # Non-univariate distributions must be rejected.
    if base_rv.owner.op.ndim_supp != 0:
        return None

    # TODO: We are currently only supporting continuous rvs
    if base_rv.owner.op.dtype.startswith("int"):
        return None

    # Univariate i.i.d. test (scalar distribution parameters only), which also
    # rules out non-identically-distributed batches.
    for params in base_rv.owner.inputs[3:]:
        if params.type.ndim != 0:
            return None

    # Only a full reduction (max over every dimension) is supported.
    axis = set(node.op.axis)
    base_var_dims = set(range(base_var.ndim))
    if axis != base_var_dims:
        return None

    measurable_min = MeasurableMin(list(axis))
    min_rv_node = measurable_min.make_node(base_var)
    return min_rv_node.outputs
191+
192+
193+
# Register the rewrite in the logprob IR database so that min sub-graphs are
# recognized during logprob inference (tags mirror the "max" registration).
measurable_ir_rewrites_db.register(
    "find_measurable_min",
    find_measurable_min,
    "basic",
    "min",
)
199+
200+
201+
@_logprob.register(MeasurableMin)
def min_logprob(op, values, base_var, **kwargs):
    r"""Compute the log-likelihood graph for the `Min` operation.

    The formula used here is:

    .. math::
        \ln(f_{(1)}(x)) = \ln(n) + (n - 1) \ln(1 - F(x)) + \ln(f(x))

    where :math:`f(x)` is the p.d.f., :math:`F(x)` the c.d.f. of the base
    distribution, and :math:`n` the number of i.i.d. draws being minimized.
    """
    (value,) = values
    # ``base_var`` is the negated RV (``rv * -1``); recover the underlying RV
    # rather than assuming it is the first Mul input.
    base_rv = next(
        (
            inp
            for inp in base_var.owner.inputs
            if inp.owner is not None and isinstance(inp.owner.op, RandomVariable)
        ),
        base_var.owner.inputs[0],
    )

    logprob = _logprob_helper(base_rv, value)
    logcdf = _logcdf_helper(base_rv, value)

    [n] = constant_fold([base_rv.size])
    # log1mexp(logcdf) == log(1 - exp(logcdf)) == log(1 - F(x)), computed in a
    # numerically stable way for logcdf close to 0.
    logprob = (n - 1) * pt.math.log1mexp(logcdf) + logprob + pt.math.log(n)

    return logprob

tests/logprob/test_order.py

Lines changed: 129 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import pymc as pm
4444

4545
from pymc import logp
46-
from pymc.logprob import conditional_logp
4746
from pymc.testing import assert_no_rvs
4847

4948

@@ -58,55 +57,112 @@ def test_argmax():
5857
x_max_logprob = logp(x_max, x_max_value)
5958

6059

61-
def test_max_non_iid_fails():
62-
"""Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected"""
60+
@pytest.mark.parametrize("if_max", [True, False])
def test_non_iid_fails(if_max):
    """Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
    x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
    x.name = "x"
    if if_max:
        x_m = pt.max(x, axis=-1)
        x_m_value = pt.vector("x_max_value")
    else:
        # pt.min lowers to -max(-x); take the inner max node.
        x_min = pt.min(x, axis=-1)
        x_m = x_min.owner.inputs[0]
        x_m_value = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(x_m, x_m_value)
6980

7081

71-
def test_max_non_rv_fails():
82+
@pytest.mark.parametrize("if_max", [True, False])
def test_non_rv_fails(if_max):
    """Test whether the logprob for ```pt.max``` or ```pt.min``` for non-RVs is correctly rejected"""
    x = pt.exp(pt.random.beta(0, 1, size=(3,)))
    x.name = "x"
    if if_max:
        x_m = pt.max(x, axis=-1)
        x_m_value = pt.vector("x_max_value")
    else:
        # pt.min lowers to -max(-x); take the inner max node.
        x_min = pt.min(x, axis=-1)
        x_m = x_min.owner.inputs[0]
        x_m_value = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(x_m, x_m_value)
7999

80100

81-
def test_max_multivariate_rv_fails():
101+
@pytest.mark.parametrize("if_max", [True, False])
def test_multivariate_rv_fails(if_max):
    """Test whether the logprob for ```pt.max``` or ```pt.min``` of a multivariate RV is correctly rejected"""
    _alpha = pt.scalar()
    _k = pt.iscalar()
    x = pm.StickBreakingWeights.dist(_alpha, _k)
    x.name = "x"
    if if_max:
        x_m = pt.max(x, axis=-1)
        x_m_value = pt.vector("x_max_value")
    else:
        # pt.min lowers to -max(-x); take the inner max node.
        x_min = pt.min(x, axis=-1)
        x_m = x_min.owner.inputs[0]
        x_m_value = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(x_m, x_m_value)
90122

91123

92-
def test_max_categorical():
124+
@pytest.mark.parametrize("if_max", [True, False])
def test_categorical(if_max):
    """Test whether the logprob for ```pt.max``` or ```pt.min``` for unsupported distributions is correctly rejected"""
    x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
    x.name = "x"
    if if_max:
        x_m = pt.max(x, axis=-1)
        x_m_value = pt.vector("x_max_value")
    else:
        # pt.min lowers to -max(-x); take the inner max node.
        x_min = pt.min(x, axis=-1)
        x_m = x_min.owner.inputs[0]
        x_m_value = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(x_m, x_m_value)
100144

101145

102-
def test_non_supp_axis_max():
146+
@pytest.mark.parametrize("if_max", [True, False])
def test_non_supp_axis(if_max):
    """Test whether the logprob for ```pt.max``` or ```pt.min``` for an unsupported axis is correctly rejected"""
    x = pt.random.normal(0, 1, size=(3, 3))
    x.name = "x"
    if if_max:
        x_m = pt.max(x, axis=-1)
        x_m_value = pt.vector("x_max_value")
    else:
        # pt.min lowers to -max(-x); take the inner max node.
        x_min = pt.min(x, axis=-1)
        x_m = x_min.owner.inputs[0]
        x_m_value = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(x_m, x_m_value)
110166

111167

112168
@pytest.mark.parametrize(
@@ -147,3 +203,54 @@ def test_max_logprob(shape, value, axis):
147203
(x_max_logprob.eval({x_max_value: test_value})),
148204
rtol=1e-06,
149205
)
206+
207+
208+
@pytest.mark.parametrize(
    "shape, value, axis",
    [
        (3, 0.85, -1),
        (3, 0.01, 0),
        (2, 0.2, None),
        (4, 0.5, 0),
        ((3, 4), 0.9, None),
        ((3, 4), 0.75, (1, 0)),
    ],
)
def test_min_logprob(shape, value, axis):
    r"""Test whether the logprob for ```pt.min``` produces the correct result.

    The fact that order statistics of i.i.d. uniform RVs follow a Beta
    distribution is used here:

    .. math::
        U_1, \dots, U_n \stackrel{\text{i.i.d.}}{\sim} \text{Uniform}(0, 1)
        \Rightarrow U_{(k)} \sim \text{Beta}(k, n + 1 - k)

    for all :math:`1 \le k \le n`; the minimum is the :math:`k = 1` order
    statistic, i.e. :math:`\text{Beta}(1, n)`.
    """
    x = pt.random.uniform(0, 1, size=shape)
    x.name = "x"
    x_min = pt.min(x, axis=axis)
    # pt.min lowers to -max(-x); take the inner max node, for which the
    # MeasurableMin logprob is derived.
    x_min_rv = x_min.owner.inputs[0]
    x_min_value = pt.scalar("x_min_value")
    x_min_logprob = logp(x_min_rv, x_min_value)

    assert_no_rvs(x_min_logprob)

    test_value = value

    n = np.prod(shape)
    beta_rv = pt.random.beta(1, n, name="beta")
    beta_vv = beta_rv.clone()
    beta_rv_logprob = logp(beta_rv, beta_vv)

    np.testing.assert_allclose(
        beta_rv_logprob.eval({beta_vv: test_value}),
        x_min_logprob.eval({x_min_value: test_value}),
        rtol=1e-06,
    )
246+
247+
248+
def test_min_non_mul_elemwise_fails():
    """Test whether the logprob for ```pt.min``` for non-mul elemwise RVs is rejected correctly"""
    # log(beta) is an Elemwise transform, but not a multiplication by -1, so
    # the min sub-graph must not be treated as measurable.
    rv = pt.random.beta(0, 1, size=(3,))
    transformed = pt.log(rv)
    transformed.name = "x"
    inner_max = pt.min(transformed, axis=-1).owner.inputs[0]
    vv = pt.vector("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
        logp(inner_max, vv)

0 commit comments

Comments
 (0)