|
66 | 66 | from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
|
67 | 67 | from pymc.logprob.rewriting import (
|
68 | 68 | PreserveRVMappings,
|
| 69 | + assume_measured_ir_outputs, |
69 | 70 | local_lift_DiracDelta,
|
70 | 71 | measurable_ir_rewrites_db,
|
71 | 72 | subtensor_ops,
|
72 | 73 | )
|
73 | 74 | from pymc.logprob.tensor import naive_bcast_rv_lift
|
| 75 | +from pymc.logprob.utils import check_potential_measurability |
74 | 76 |
|
75 | 77 |
|
76 | 78 | def is_newaxis(x):
|
@@ -453,19 +455,66 @@ class MeasurableIfElse(IfElse):
|
453 | 455 | MeasurableVariable.register(MeasurableIfElse)
|
454 | 456 |
|
455 | 457 |
|
| 458 | +@node_rewriter([IfElse]) |
| 459 | +def useless_ifelse_outputs(fgraph, node): |
| 460 | + """Remove outputs that are shared across the IfElse branches.""" |
| 461 | + # TODO: This should be a PyTensor canonicalization |
| 462 | + op = node.op |
| 463 | + if_var, *inputs = node.inputs |
| 464 | + shared_inputs = set(inputs[op.n_outs :]).intersection(inputs[: op.n_outs]) |
| 465 | + if not shared_inputs: |
| 466 | + return None |
| 467 | + |
| 468 | + replacements = {} |
| 469 | + for shared_inp in shared_inputs: |
| 470 | + idx = inputs.index(shared_inp) |
| 471 | + replacements[node.outputs[idx]] = shared_inp |
| 472 | + |
| 473 | + # IfElse isn't needed at all |
| 474 | + if len(shared_inputs) == op.n_outs: |
| 475 | + return replacements |
| 476 | + |
| 477 | + # Create subset IfElse with remaining nodes |
| 478 | + remaining_inputs = [inp for inp in inputs if inp not in shared_inputs] |
| 479 | + new_outs = ( |
| 480 | + IfElse(n_outs=len(remaining_inputs) // 2).make_node(if_var, *remaining_inputs).outputs |
| 481 | + ) |
| 482 | + for inp, new_out in zip(remaining_inputs, new_outs): |
| 483 | + idx = inputs.index(inp) |
| 484 | + replacements[node.outputs[idx]] = new_out |
| 485 | + |
| 486 | + return replacements |
| 487 | + |
| 488 | + |
456 | 489 | @node_rewriter([IfElse])
|
457 | 490 | def find_measurable_ifelse_mixture(fgraph, node):
|
458 | 491 | rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
|
459 | 492 |
|
460 | 493 | if rv_map_feature is None:
|
461 | 494 | return None # pragma: no cover
|
462 | 495 |
|
| 496 | + op = node.op |
463 | 497 | if_var, *base_rvs = node.inputs
|
464 | 498 |
|
465 |
| - if rv_map_feature.request_measurable(base_rvs) != base_rvs: |
| 499 | + valued_rvs = rv_map_feature.rv_values.keys() |
| 500 | + if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_rvs): |
466 | 501 | return None
|
467 | 502 |
|
468 |
| - return MeasurableIfElse(n_outs=node.op.n_outs).make_node(if_var, *base_rvs).outputs |
| 503 | + base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs) |
| 504 | + if len(base_rvs) != op.n_outs * 2: |
| 505 | + return None |
| 506 | + if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): |
| 507 | + return None |
| 508 | + |
| 509 | + return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs |
| 510 | + |
| 511 | + |
| 512 | +measurable_ir_rewrites_db.register( |
| 513 | + "useless_ifelse_outputs", |
| 514 | + useless_ifelse_outputs, |
| 515 | + "basic", |
| 516 | + "mixture", |
| 517 | +) |
469 | 518 |
|
470 | 519 |
|
471 | 520 | measurable_ir_rewrites_db.register(
|
|
0 commit comments