Skip to content

Commit 371472d

Browse files
authored
Add logprob inference for not operations (#6689)
1 parent d4bb701 commit 371472d

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

pymc/logprob/binary.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from pytensor.graph.basic import Node
2020
from pytensor.graph.fg import FunctionGraph
2121
from pytensor.graph.rewriting.basic import node_rewriter
22-
from pytensor.scalar.basic import GE, GT, LE, LT
22+
from pytensor.scalar.basic import GE, GT, LE, LT, Invert
2323
from pytensor.tensor import TensorVariable
24-
from pytensor.tensor.math import ge, gt, le, lt
24+
from pytensor.tensor.math import ge, gt, invert, le, lt
2525

2626
from pymc.logprob.abstract import (
2727
MeasurableElemwise,
@@ -136,3 +136,57 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
136136
logcdf.name = f"{base_rv_op}_logcdf"
137137

138138
return logprob
139+
140+
141+
class MeasurableBitwise(MeasurableElemwise):
142+
"""A placeholder used to specify a log-likelihood for a bitwise operation RV sub-graph."""
143+
144+
valid_scalar_types = (Invert,)
145+
146+
147+
@node_rewriter(tracks=[invert])
148+
def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableBitwise]]:
149+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
150+
if rv_map_feature is None:
151+
return None # pragma: no cover
152+
153+
if isinstance(node.op, MeasurableBitwise):
154+
return None # pragma: no cover
155+
156+
base_var = node.inputs[0]
157+
if not (
158+
base_var.owner
159+
and isinstance(base_var.owner.op, MeasurableVariable)
160+
and base_var not in rv_map_feature.rv_values
161+
):
162+
return None
163+
164+
if not base_var.dtype.startswith("bool"):
165+
raise None
166+
167+
# Make base_var unmeasurable
168+
unmeasurable_base_var = ignore_logprob(base_var)
169+
170+
node_scalar_op = node.op.scalar_op
171+
172+
bitwise_op = MeasurableBitwise(node_scalar_op)
173+
bitwise_rv = bitwise_op.make_node(unmeasurable_base_var).default_output()
174+
bitwise_rv.name = node.outputs[0].name
175+
return [bitwise_rv]
176+
177+
178+
measurable_ir_rewrites_db.register(
179+
"find_measurable_bitwise",
180+
find_measurable_bitwise,
181+
"basic",
182+
"bitwise",
183+
)
184+
185+
186+
@_logprob.register(MeasurableBitwise)
187+
def bitwise_not_logprob(op, values, base_rv, **kwargs):
188+
(value,) = values
189+
190+
logprob = _logprob_helper(base_rv, invert(value), **kwargs)
191+
192+
return logprob

tests/logprob/test_binary.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))),
3434
],
3535
)
36-
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, inputs):
36+
def test_continuous_rv_comparison_bitwise(comparison_op, exp_logp_true, exp_logp_false, inputs):
3737
for op in comparison_op:
3838
comp_x_rv = op(*inputs)
3939

@@ -48,6 +48,17 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false,
4848
assert np.isclose(logp_fn(0), getattr(ref_scipy, exp_logp_false)(0.5))
4949
assert np.isclose(logp_fn(1), getattr(ref_scipy, exp_logp_true)(0.5))
5050

51+
bitwise_rv = pt.bitwise_not(op(*inputs))
52+
bitwise_vv = bitwise_rv.clone()
53+
54+
logprob_not = logp(bitwise_rv, bitwise_vv)
55+
assert_no_rvs(logprob_not)
56+
57+
logp_fn_not = pytensor.function([bitwise_vv], logprob_not)
58+
59+
assert np.isclose(logp_fn_not(0), getattr(ref_scipy, exp_logp_true)(0.5))
60+
assert np.isclose(logp_fn_not(1), getattr(ref_scipy, exp_logp_false)(0.5))
61+
5162

5263
@pytest.mark.parametrize(
5364
"comparison_op, exp_logp_true, exp_logp_false, inputs",
@@ -87,7 +98,7 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false,
8798
),
8899
],
89100
)
90-
def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_false):
101+
def test_discrete_rv_comparison_bitwise(inputs, comparison_op, exp_logp_true, exp_logp_false):
91102
cens_x_rv = comparison_op(*inputs)
92103

93104
cens_x_vv = cens_x_rv.clone()
@@ -100,6 +111,17 @@ def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_f
100111
assert np.isclose(logp_fn(1), exp_logp_true(3))
101112
assert np.isclose(logp_fn(0), exp_logp_false(3))
102113

114+
bitwise_rv = pt.bitwise_not(comparison_op(*inputs))
115+
bitwise_vv = bitwise_rv.clone()
116+
117+
logprob_not = logp(bitwise_rv, bitwise_vv)
118+
assert_no_rvs(logprob_not)
119+
120+
logp_fn_not = pytensor.function([bitwise_vv], logprob_not)
121+
122+
assert np.isclose(logp_fn_not(1), exp_logp_false(3))
123+
assert np.isclose(logp_fn_not(0), exp_logp_true(3))
124+
103125

104126
def test_potentially_measurable_operand():
105127
x_rv = pt.random.normal(2)

0 commit comments

Comments
 (0)