Skip to content

Commit a30e0d4

Browse files
ricardoV94twiecki
authored andcommitted
Allow composition of interdependent container variables
Join/MakeVector/IfElse can output multiple interdependent variables. These are potentially measurable because in the logp each output is given a distinct value variable. However, this isn't known during the IR rewrites. To circumvent this issue, we run an inner IR rewrite after giving dummy value variables to each output
1 parent 9bba026 commit a30e0d4

File tree

9 files changed

+181
-18
lines changed

9 files changed

+181
-18
lines changed

pymc/logprob/binary.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def find_measurable_comparisons(
6060
const = node.inputs[(measurable_var_idx + 1) % 2]
6161

6262
# check for potential measurability of const
63-
if not check_potential_measurability([const], rv_map_feature):
63+
if check_potential_measurability([const], rv_map_feature.rv_values.keys()):
6464
return None
6565

6666
node_scalar_op = node.op.scalar_op

pymc/logprob/mixture.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,13 @@
6666
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
6767
from pymc.logprob.rewriting import (
6868
PreserveRVMappings,
69+
assume_measured_ir_outputs,
6970
local_lift_DiracDelta,
7071
measurable_ir_rewrites_db,
7172
subtensor_ops,
7273
)
7374
from pymc.logprob.tensor import naive_bcast_rv_lift
75+
from pymc.logprob.utils import check_potential_measurability
7476

7577

7678
def is_newaxis(x):
@@ -453,19 +455,66 @@ class MeasurableIfElse(IfElse):
453455
MeasurableVariable.register(MeasurableIfElse)
454456

455457

458+
@node_rewriter([IfElse])
459+
def useless_ifelse_outputs(fgraph, node):
460+
"""Remove outputs that are shared across the IfElse branches."""
461+
# TODO: This should be a PyTensor canonicalization
462+
op = node.op
463+
if_var, *inputs = node.inputs
464+
shared_inputs = set(inputs[op.n_outs :]).intersection(inputs[: op.n_outs])
465+
if not shared_inputs:
466+
return None
467+
468+
replacements = {}
469+
for shared_inp in shared_inputs:
470+
idx = inputs.index(shared_inp)
471+
replacements[node.outputs[idx]] = shared_inp
472+
473+
# IfElse isn't needed at all
474+
if len(shared_inputs) == op.n_outs:
475+
return replacements
476+
477+
# Create subset IfElse with remaining nodes
478+
remaining_inputs = [inp for inp in inputs if inp not in shared_inputs]
479+
new_outs = (
480+
IfElse(n_outs=len(remaining_inputs) // 2).make_node(if_var, *remaining_inputs).outputs
481+
)
482+
for inp, new_out in zip(remaining_inputs, new_outs):
483+
idx = inputs.index(inp)
484+
replacements[node.outputs[idx]] = new_out
485+
486+
return replacements
487+
488+
456489
@node_rewriter([IfElse])
457490
def find_measurable_ifelse_mixture(fgraph, node):
458491
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
459492

460493
if rv_map_feature is None:
461494
return None # pragma: no cover
462495

496+
op = node.op
463497
if_var, *base_rvs = node.inputs
464498

465-
if rv_map_feature.request_measurable(base_rvs) != base_rvs:
499+
valued_rvs = rv_map_feature.rv_values.keys()
500+
if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_rvs):
466501
return None
467502

468-
return MeasurableIfElse(n_outs=node.op.n_outs).make_node(if_var, *base_rvs).outputs
503+
base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs)
504+
if len(base_rvs) != op.n_outs * 2:
505+
return None
506+
if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs):
507+
return None
508+
509+
return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs
510+
511+
512+
measurable_ir_rewrites_db.register(
513+
"useless_ifelse_outputs",
514+
useless_ifelse_outputs,
515+
"basic",
516+
"mixture",
517+
)
469518

470519

471520
measurable_ir_rewrites_db.register(

pymc/logprob/rewriting.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,16 @@
4242

4343
from pytensor import config
4444
from pytensor.compile.mode import optdb
45-
from pytensor.graph.basic import Constant, Variable, ancestors, io_toposort
45+
from pytensor.graph.basic import (
46+
Constant,
47+
Variable,
48+
ancestors,
49+
io_toposort,
50+
truncated_graph_inputs,
51+
)
4652
from pytensor.graph.features import Feature
4753
from pytensor.graph.fg import FunctionGraph
54+
from pytensor.graph.replace import clone_replace
4855
from pytensor.graph.rewriting.basic import (
4956
ChangeTracker,
5057
EquilibriumGraphRewriter,
@@ -461,3 +468,34 @@ def cleanup_ir(vars: Sequence[Variable]) -> None:
461468
fgraph = FunctionGraph(outputs=vars, clone=False)
462469
ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"]))
463470
ir_rewriter.rewrite(fgraph)
471+
472+
473+
def assume_measured_ir_outputs(
474+
inputs: Sequence[TensorVariable], outputs: Sequence[TensorVariable]
475+
) -> Sequence[TensorVariable]:
476+
"""Run IR rewrite assuming each output is measured.
477+
478+
IR variables could depend on each other in a way that looks unmeasurable without a value variable assigned to each.
479+
For instance `join([add(x, z), z])` is a potentially measurable join, but `add(x, z)` can look unmeasurable
480+
because neither `x` and `z` are valued in the IR representation.
481+
This helper runs an inner ir rewrite after giving each output a dummy value variable.
482+
We replace inputs by dummies and then undo it so that any dependency on outer variables is preserved.
483+
"""
484+
# Replace inputs by dummy variables
485+
replaced_inputs = {
486+
var: var.type()
487+
for var in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
488+
if var in inputs
489+
}
490+
cloned_outputs = clone_replace(outputs, replace=replaced_inputs)
491+
492+
dummy_rv_values = {base_var: base_var.type() for base_var in cloned_outputs}
493+
fgraph, *_ = construct_ir_fgraph(dummy_rv_values)
494+
495+
# Replace dummy variables by inputs
496+
fgraph.replace_all(
497+
tuple((repl, orig) for orig, repl in replaced_inputs.items()),
498+
import_missing=True,
499+
)
500+
501+
return fgraph.outputs

pymc/logprob/tensor.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@
5151
)
5252

5353
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
54-
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
54+
from pymc.logprob.rewriting import (
55+
PreserveRVMappings,
56+
assume_measured_ir_outputs,
57+
measurable_ir_rewrites_db,
58+
)
59+
from pymc.logprob.utils import check_potential_measurability
5560

5661

5762
@node_rewriter([BroadcastTo])
@@ -213,7 +218,12 @@ def find_measurable_stacks(
213218
else:
214219
base_vars = node.inputs
215220

216-
if rv_map_feature.request_measurable(base_vars) != base_vars:
221+
valued_rvs = rv_map_feature.rv_values.keys()
222+
if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_vars):
223+
return None
224+
225+
base_vars = assume_measured_ir_outputs(valued_rvs, base_vars)
226+
if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars):
217227
return None
218228

219229
if is_join:

pymc/logprob/transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
529529
# would be invalid
530530
other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)
531531

532-
if not check_potential_measurability(other_inputs, rv_map_feature):
532+
if check_potential_measurability(other_inputs, rv_map_feature.rv_values.keys()):
533533
return None
534534

535535
scalar_op = node.op.scalar_op

pymc/logprob/utils.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from typing import (
4040
Callable,
41+
Container,
4142
Dict,
4243
Generator,
4344
Iterable,
@@ -210,22 +211,24 @@ def indices_from_subtensor(idx_list, indices):
210211
)
211212

212213

213-
def check_potential_measurability(inputs: Tuple[TensorVariable], rv_map_feature):
214+
def check_potential_measurability(
215+
inputs: Tuple[TensorVariable], valued_rvs: Container[TensorVariable]
216+
) -> bool:
214217
if any(
215-
ancestor_node
216-
for ancestor_node in walk_model(
218+
ancestor_var
219+
for ancestor_var in walk_model(
217220
inputs,
218221
walk_past_rvs=False,
219-
stop_at_vars=set(rv_map_feature.rv_values),
222+
stop_at_vars=set(valued_rvs),
220223
)
221224
if (
222-
ancestor_node.owner
223-
and isinstance(ancestor_node.owner.op, MeasurableVariable)
224-
and ancestor_node not in rv_map_feature.rv_values
225+
ancestor_var.owner
226+
and isinstance(ancestor_var.owner.op, MeasurableVariable)
227+
and ancestor_var not in valued_rvs
225228
)
226229
):
227-
return None
228-
return True
230+
return True
231+
return False
229232

230233

231234
class ParameterValueError(ValueError):

tests/logprob/test_composite_logprob.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import pytest
4141
import scipy.stats as st
4242

43-
from pymc import logp
43+
from pymc import draw, logp
4444
from pymc.logprob.abstract import MeasurableVariable
4545
from pymc.logprob.basic import factorized_joint_logprob
4646
from pymc.logprob.censoring import MeasurableClip
@@ -218,3 +218,51 @@ def test_affine_log_transform_rv():
218218
logp_fn(a_val, b_val, y_val),
219219
st.norm(a_val, b_val).logpdf(y_val),
220220
)
221+
222+
223+
@pytest.mark.parametrize("reverse", (False, True))
224+
def test_affine_join_interdependent(reverse):
225+
x = pt.random.normal(name="x")
226+
y_rvs = []
227+
prev_rv = x
228+
for i in range(3):
229+
next_rv = pt.exp(prev_rv + pt.random.beta(3, 1, name=f"y{i}", size=(1, 2)))
230+
y_rvs.append(next_rv)
231+
prev_rv = next_rv
232+
233+
if reverse:
234+
y_rvs = y_rvs[::-1]
235+
236+
ys = pt.concatenate(y_rvs, axis=0)
237+
ys.name = "ys"
238+
239+
x_vv = x.clone()
240+
ys_vv = ys.clone()
241+
242+
logp = factorized_joint_logprob({x: x_vv, ys: ys_vv})
243+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
244+
assert_no_rvs(logp_combined)
245+
246+
y0_vv = y_rvs[0].clone()
247+
y1_vv = y_rvs[1].clone()
248+
y2_vv = y_rvs[2].clone()
249+
250+
ref_logp = factorized_joint_logprob(
251+
{x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv}
252+
)
253+
ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()])
254+
255+
rng = np.random.default_rng()
256+
x_vv_test, ys_vv_test = draw([x, ys], random_seed=rng)
257+
ys_vv_test = rng.normal(size=(3, 2))
258+
np.testing.assert_allclose(
259+
logp_combined.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
260+
ref_logp_combined.eval(
261+
{
262+
x_vv: x_vv_test,
263+
y0_vv: ys_vv_test[0:1],
264+
y1_vv: ys_vv_test[1:2],
265+
y2_vv: ys_vv_test[2:3],
266+
}
267+
),
268+
)

tests/logprob/test_mixture.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ def test_ifelse_mixture_multiple_components():
984984

985985
if_var = pt.scalar("if_var", dtype="bool")
986986
comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
987-
comp_then2 = pt.random.normal(comp_then1, size=(2, 2), name="comp_then2")
987+
comp_then2 = comp_then1 + pt.random.normal(size=(2, 2), name="comp_then2")
988988
comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
989989
comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")
990990

tests/logprob/test_utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from pymc.logprob.basic import joint_logp, logp
5252
from pymc.logprob.utils import (
5353
ParameterValueError,
54+
check_potential_measurability,
5455
dirac_delta,
5556
rvs_to_value_vars,
5657
walk_model,
@@ -194,3 +195,17 @@ def scipy_logprob(obs, c):
194195
return 0.0 if obs == c else -np.inf
195196

196197
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
198+
199+
200+
def test_check_potential_measurability():
201+
x1 = pt.random.normal()
202+
x2 = pt.random.normal()
203+
x3 = pt.scalar("x3")
204+
y = pt.exp(x1 + x2 + x3)
205+
206+
# In the first three cases, y is potentially measurable, because it has at least on unvalued RV input
207+
assert check_potential_measurability([y], {})
208+
assert check_potential_measurability([y], {x1})
209+
assert check_potential_measurability([y], {x2})
210+
# y is not potentially measurable because both RV inputs are valued
211+
assert not check_potential_measurability([y], {x1, x2})

0 commit comments

Comments
 (0)