diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index d9aeb3b57d..6ad0925025 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -19,9 +19,9 @@ from pytensor.graph.basic import Node from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.scalar.basic import GE, GT, LE, LT +from pytensor.scalar.basic import GE, GT, LE, LT, Invert from pytensor.tensor import TensorVariable -from pytensor.tensor.math import ge, gt, le, lt +from pytensor.tensor.math import ge, gt, invert, le, lt from pymc.logprob.abstract import ( MeasurableElemwise, @@ -136,3 +136,57 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs): logcdf.name = f"{base_rv_op}_logcdf" return logprob + + +class MeasurableBitwise(MeasurableElemwise): + """A placeholder used to specify a log-likelihood for a bitwise operation RV sub-graph.""" + + valid_scalar_types = (Invert,) + + +@node_rewriter(tracks=[invert]) +def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableBitwise]]: + rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + if isinstance(node.op, MeasurableBitwise): + return None # pragma: no cover + + base_var = node.inputs[0] + if not ( + base_var.owner + and isinstance(base_var.owner.op, MeasurableVariable) + and base_var not in rv_map_feature.rv_values + ): + return None + + if not base_var.dtype.startswith("bool"): + raise None + + # Make base_var unmeasurable + unmeasurable_base_var = ignore_logprob(base_var) + + node_scalar_op = node.op.scalar_op + + bitwise_op = MeasurableBitwise(node_scalar_op) + bitwise_rv = bitwise_op.make_node(unmeasurable_base_var).default_output() + bitwise_rv.name = node.outputs[0].name + return [bitwise_rv] + + +measurable_ir_rewrites_db.register( + "find_measurable_bitwise", + find_measurable_bitwise, + "basic", + "bitwise", +) + + +@_logprob.register(MeasurableBitwise) +def bitwise_not_logprob(op, values, base_rv, **kwargs): + (value,) = values + + logprob = _logprob_helper(base_rv, invert(value), **kwargs) + + return logprob diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index 6f2248075c..1a91abdd6f 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -33,7 +33,7 @@ ((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))), ], ) -def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, inputs): +def test_continuous_rv_comparison_bitwise(comparison_op, exp_logp_true, exp_logp_false, inputs): for op in comparison_op: comp_x_rv = op(*inputs) @@ -48,6 +48,17 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, assert np.isclose(logp_fn(0), getattr(ref_scipy, exp_logp_false)(0.5)) assert np.isclose(logp_fn(1), getattr(ref_scipy, exp_logp_true)(0.5)) + bitwise_rv = pt.bitwise_not(op(*inputs)) + bitwise_vv = bitwise_rv.clone() + + logprob_not = logp(bitwise_rv, bitwise_vv) + assert_no_rvs(logprob_not) + + logp_fn_not = pytensor.function([bitwise_vv], logprob_not) + + assert np.isclose(logp_fn_not(0), getattr(ref_scipy, exp_logp_true)(0.5)) + assert np.isclose(logp_fn_not(1), getattr(ref_scipy, exp_logp_false)(0.5)) + @pytest.mark.parametrize( "comparison_op, exp_logp_true, exp_logp_false, inputs", @@ -87,7 +98,7 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, ), ], ) -def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_false): +def test_discrete_rv_comparison_bitwise(inputs, comparison_op, exp_logp_true, exp_logp_false): cens_x_rv = comparison_op(*inputs) cens_x_vv = cens_x_rv.clone() @@ -100,6 +111,17 @@ def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_f assert np.isclose(logp_fn(1), exp_logp_true(3)) assert np.isclose(logp_fn(0), exp_logp_false(3)) + bitwise_rv = pt.bitwise_not(comparison_op(*inputs)) + bitwise_vv = bitwise_rv.clone() + + logprob_not = logp(bitwise_rv, bitwise_vv) + assert_no_rvs(logprob_not) + + logp_fn_not = pytensor.function([bitwise_vv], logprob_not) + + assert np.isclose(logp_fn_not(1), exp_logp_false(3)) + assert np.isclose(logp_fn_not(0), exp_logp_true(3)) + def test_potentially_measurable_operand(): x_rv = pt.random.normal(2)