
Commit fea004c

Generalize logp of min to neg(max)
1 parent b88fa7f · commit fea004c

2 files changed (+14, -21 lines)


pymc/logprob/order.py

Lines changed: 7 additions & 7 deletions
@@ -129,11 +129,11 @@ def max_logprob(op, values, base_rv, **kwargs):
     return logprob
 
 
-class MeasurableMin(Max):
+class MeasurableMaxNeg(Max):
     """A placeholder used to specify a log-likelihood for a min sub-graph."""
 
 
-MeasurableVariable.register(MeasurableMin)
+MeasurableVariable.register(MeasurableMaxNeg)
 
 
 @node_rewriter(tracks=[Max])
@@ -143,7 +143,7 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
     if rv_map_feature is None:
         return None  # pragma: no cover
 
-    if isinstance(node.op, MeasurableMin):
+    if isinstance(node.op, MeasurableMaxNeg):
         return None  # pragma: no cover
 
     base_var = node.inputs[0]
@@ -183,7 +183,7 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
     if axis != base_var_dims:
         return None
 
-    measurable_min = MeasurableMin(list(axis))
+    measurable_min = MeasurableMaxNeg(list(axis))
     min_rv_node = measurable_min.make_node(base_var)
     min_rv = min_rv_node.outputs
 
@@ -198,7 +198,7 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
     )
 
 
-@_logprob.register(MeasurableMin)
+@_logprob.register(MeasurableMaxNeg)
 def min_logprob(op, values, base_var, **kwargs):
     r"""Compute the log-likelihood graph for the `Max` operation.
     The formula that we use here is :
@@ -208,8 +208,8 @@ def min_logprob(op, values, base_var, **kwargs):
     (value,) = values
     base_rv = base_var.owner.inputs[0]
 
-    logprob = _logprob_helper(base_rv, value)
-    logcdf = _logcdf_helper(base_rv, value)
+    logprob = _logprob_helper(base_rv, -value)
+    logcdf = _logcdf_helper(base_rv, -value)
 
     [n] = constant_fold([base_rv.size])
     logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)
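
After this change the minimum is handled by negating the value and reusing the order-statistic machinery of the maximum, which is what the commit title "Generalize logp of min to neg(max)" refers to. As a quick sanity check, the minimum of n i.i.d. Uniform(0, 1) draws has log-density log(n) + (n - 1) * log(1 - t) for 0 < t < 1. The snippet below is a minimal sketch of such a check, assuming a PyMC build that already includes this change; the names n, value, and t are illustrative:

import numpy as np
import pytensor.tensor as pt
from pymc import logp

# Minimum of n i.i.d. Uniform(0, 1) variables.
n = 3
x = pt.random.uniform(0, 1, size=(n,))
x_min = pt.min(x, axis=-1)
value = pt.scalar("value")

# With this commit, logp can be requested directly on the min variable.
x_min_logp = logp(x_min, value)

# Closed form for the minimum of n i.i.d. Uniform(0, 1) draws:
# logp_min(t) = log(n) + (n - 1) * log(1 - t) for 0 < t < 1.
t = 0.3
expected = np.log(n) + (n - 1) * np.log(1 - t)
np.testing.assert_allclose(x_min_logp.eval({value: t}), expected, rtol=1e-6)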

tests/logprob/test_order.py

Lines changed: 7 additions & 14 deletions
@@ -72,8 +72,7 @@ def test_non_iid_fails(if_max):
         x_m = pt.max(x, axis=-1)
         x_m_value = pt.vector("x_max_value")
     else:
-        x_min = pt.min(x, axis=-1)
-        x_m = x_min.owner.inputs[0]
+        x_m = pt.min(x, axis=-1)
         x_m_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
@@ -91,8 +90,7 @@ def test_non_rv_fails(if_max):
         x_m = pt.max(x, axis=-1)
         x_m_value = pt.vector("x_max_value")
     else:
-        x_min = pt.min(x, axis=-1)
-        x_m = x_min.owner.inputs[0]
+        x_m = pt.min(x, axis=-1)
         x_m_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
@@ -114,8 +112,7 @@ def test_multivariate_rv_fails(if_max):
         x_m = pt.max(x, axis=-1)
         x_m_value = pt.vector("x_max_value")
     else:
-        x_min = pt.min(x, axis=-1)
-        x_m = x_min.owner.inputs[0]
+        x_m = pt.min(x, axis=-1)
         x_m_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
@@ -136,8 +133,7 @@ def test_categorical(if_max):
         x_m = pt.max(x, axis=-1)
         x_m_value = pt.vector("x_max_value")
     else:
-        x_min = pt.min(x, axis=-1)
-        x_m = x_min.owner.inputs[0]
+        x_m = pt.min(x, axis=-1)
         x_m_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
@@ -158,8 +154,7 @@ def test_non_supp_axis(if_max):
         x_m = pt.max(x, axis=-1)
         x_m_value = pt.vector("x_max_value")
     else:
-        x_min = pt.min(x, axis=-1)
-        x_m = x_min.owner.inputs[0]
+        x_m = pt.min(x, axis=-1)
         x_m_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
@@ -225,9 +220,8 @@ def test_min_logprob(shape, value, axis):
     x = pt.random.uniform(0, 1, size=shape)
     x.name = "x"
     x_min = pt.min(x, axis=axis)
-    x_min_rv = x_min.owner.inputs[0]
     x_min_value = pt.scalar("x_min_value")
-    x_min_logprob = logp(x_min_rv, x_min_value)
+    x_min_logprob = logp(x_min, x_min_value)
 
     assert_no_rvs(x_min_logprob)
 
@@ -250,7 +244,6 @@ def test_min_non_mul_elemwise_fails():
     x = pt.log(pt.random.beta(0, 1, size=(3,)))
     x.name = "x"
     x_min = pt.min(x, axis=-1)
-    x_min_rv = x_min.owner.inputs[0]
     x_min_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
-        x_min_logprob = logp(x_min_rv, x_min_value)
+        x_min_logprob = logp(x_min, x_min_value)
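
With pt.min measurable directly, the tests above no longer unwrap x_min.owner.inputs[0]; logp is requested on the pt.min(...) output itself. Below is a self-contained sketch of the failure path exercised by test_min_non_mul_elemwise_fails, assuming a PyMC build that includes this change; the distribution parameters simply mirror the test:

import re
import pytest
import pytensor.tensor as pt
from pymc import logp

# A min over a log-transformed variable is not recognized by the
# measurable-min rewrite, so requesting its logprob should raise.
x = pt.log(pt.random.beta(0, 1, size=(3,)))
x_min = pt.min(x, axis=-1)
x_min_value = pt.vector("x_min_value")

with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
    logp(x_min, x_min_value)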
