Skip to content

Commit 086f7fe

Browse files
Incorporating suggestions
1 parent fea004c commit 086f7fe

File tree

2 files changed

+49
-56
lines changed

2 files changed

+49
-56
lines changed

pymc/logprob/order.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
from pytensor.graph.fg import FunctionGraph
4343
from pytensor.graph.rewriting.basic import node_rewriter
4444
from pytensor.scalar.basic import Mul
45+
from pytensor.tensor.basic import get_underlying_scalar_constant_value
4546
from pytensor.tensor.elemwise import Elemwise
47+
from pytensor.tensor.exceptions import NotScalarConstantError
4648
from pytensor.tensor.math import Max
4749
from pytensor.tensor.random.op import RandomVariable
4850
from pytensor.tensor.var import TensorVariable
@@ -130,14 +132,15 @@ def max_logprob(op, values, base_rv, **kwargs):
130132

131133

132134
class MeasurableMaxNeg(Max):
133-
"""A placeholder used to specify a log-likelihood for a min sub-graph."""
135+
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
136+
This shows up in the graph of min, which is (neg(max(neg(x)))."""
134137

135138

136139
MeasurableVariable.register(MeasurableMaxNeg)
137140

138141

139142
@node_rewriter(tracks=[Max])
140-
def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
143+
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
141144
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
142145

143146
if rv_map_feature is None:
@@ -154,12 +157,19 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
154157
if not rv_map_feature.request_measurable(node.inputs):
155158
return None
156159

157-
# Min is the Max of the negation of the same distribution. Hence, op must be Elemiwise
160+
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
158161
if not isinstance(base_var.owner.op, Elemwise):
159162
return None
160163

161-
# negation is -1*(rv). Hence the scalar_op must be Mul
162-
if not isinstance(base_var.owner.op.scalar_op, Mul):
164+
# negation is rv * (-1). Hence the scalar_op must be Mul
165+
try:
166+
if not (
167+
isinstance(base_var.owner.op.scalar_op, Mul)
168+
and len(base_var.owner.inputs) == 2
169+
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
170+
):
171+
return None
172+
except NotScalarConstantError:
163173
return None
164174

165175
base_rv = base_var.owner.inputs[0]
@@ -191,8 +201,8 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
191201

192202

193203
measurable_ir_rewrites_db.register(
194-
"find_measurable_min",
195-
find_measurable_min,
204+
"find_measurable_max_neg",
205+
find_measurable_max_neg,
196206
"basic",
197207
"min",
198208
)

tests/logprob/test_order.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -58,104 +58,87 @@ def test_argmax():
5858

5959

6060
@pytest.mark.parametrize(
61-
"if_max",
61+
"pt_op",
6262
[
63-
True,
64-
False,
63+
pt.max,
64+
pt.min,
6565
],
6666
)
67-
def test_non_iid_fails(if_max):
67+
def test_non_iid_fails(pt_op):
6868
"""Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
6969
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
7070
x.name = "x"
71-
if if_max == True:
72-
x_m = pt.max(x, axis=-1)
73-
x_m_value = pt.vector("x_max_value")
74-
else:
75-
x_m = pt.min(x, axis=-1)
76-
x_m_value = pt.vector("x_min_value")
71+
x_m = pt_op(x, axis=-1)
72+
x_m_value = pt.vector("x_value")
7773
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
7874
x_max_logprob = logp(x_m, x_m_value)
7975

8076

8177
@pytest.mark.parametrize(
82-
"if_max",
83-
[True, False],
78+
"pt_op",
79+
[
80+
pt.max,
81+
pt.min,
82+
],
8483
)
85-
def test_non_rv_fails(if_max):
84+
def test_non_rv_fails(pt_op):
8685
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
8786
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
8887
x.name = "x"
89-
if if_max == True:
90-
x_m = pt.max(x, axis=-1)
91-
x_m_value = pt.vector("x_max_value")
92-
else:
93-
x_m = pt.min(x, axis=-1)
94-
x_m_value = pt.vector("x_min_value")
88+
x_m = pt_op(x, axis=-1)
89+
x_m_value = pt.vector("x_value")
9590
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
9691
x_max_logprob = logp(x_m, x_m_value)
9792

9893

9994
@pytest.mark.parametrize(
100-
"if_max",
95+
"pt_op",
10196
[
102-
True,
103-
False,
97+
pt.max,
98+
pt.min,
10499
],
105100
)
106-
def test_multivariate_rv_fails(if_max):
101+
def test_multivariate_rv_fails(pt_op):
107102
_alpha = pt.scalar()
108103
_k = pt.iscalar()
109104
x = pm.StickBreakingWeights.dist(_alpha, _k)
110105
x.name = "x"
111-
if if_max == True:
112-
x_m = pt.max(x, axis=-1)
113-
x_m_value = pt.vector("x_max_value")
114-
else:
115-
x_m = pt.min(x, axis=-1)
116-
x_m_value = pt.vector("x_min_value")
106+
x_m = pt_op(x, axis=-1)
107+
x_m_value = pt.vector("x_value")
117108
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
118109
x_max_logprob = logp(x_m, x_m_value)
119110

120111

121112
@pytest.mark.parametrize(
122-
"if_max",
113+
"pt_op",
123114
[
124-
True,
125-
False,
115+
pt.max,
116+
pt.min,
126117
],
127118
)
128-
def test_categorical(if_max):
119+
def test_categorical(pt_op):
129120
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
130121
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
131122
x.name = "x"
132-
if if_max == True:
133-
x_m = pt.max(x, axis=-1)
134-
x_m_value = pt.vector("x_max_value")
135-
else:
136-
x_m = pt.min(x, axis=-1)
137-
x_m_value = pt.vector("x_min_value")
123+
x_m = pt_op(x, axis=-1)
124+
x_m_value = pt.vector("x_value")
138125
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
139126
x_max_logprob = logp(x_m, x_m_value)
140127

141128

142129
@pytest.mark.parametrize(
143-
"if_max",
130+
"pt_op",
144131
[
145-
True,
146-
False,
132+
pt.max,
133+
pt.min,
147134
],
148135
)
149-
def test_non_supp_axis(if_max):
136+
def test_non_supp_axis(pt_op):
150137
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
151138
x = pt.random.normal(0, 1, size=(3, 3))
152139
x.name = "x"
153-
if if_max == True:
154-
x_m = pt.max(x, axis=-1)
155-
x_m_value = pt.vector("x_max_value")
156-
else:
157-
x_m = pt.min(x, axis=-1)
158-
x_m_value = pt.vector("x_min_value")
140+
x_m = pt_op(x, axis=-1)
141+
x_m_value = pt.vector("x_value")
159142
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
160143
x_max_logprob = logp(x_m, x_m_value)
161144

0 commit comments

Comments
 (0)