Skip to content

Allow fill_sink rewrite to accomodate changes in broadcastability #785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,7 @@ def local_fill_sink(fgraph, node):
# Check if we need to propagate the fill to the new outputs
# It's enough to check the first output, as Elemwise outputs must all have the same shapes
# Note: There are orderings that may require fewer fills.
old_bcast_pattern = node.outputs[0].type.broadcastable
models_iter = iter(models)
while old_bcast_pattern != outputs[0].type.broadcastable:
model = next(models_iter)
for model in models:
# Only apply this model if it would actually do anything
if broadcasted_by(outputs[0], model):
outputs = [fill(model, output) for output in outputs]
Expand Down
18 changes: 17 additions & 1 deletion tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor import shared
from pytensor import graph_replace, shared
from pytensor.compile import optdb
from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
Expand Down Expand Up @@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client():
[new_out] = fg.outputs
expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
assert equal_computations([new_out], [expected_out])


def test_topological_fill_sink_broadcastable_change():
"""Test rewrite doesn't fail after a graph replacement that provides a broadcastable change."""
a = vector("a", shape=(1,))
b = vector("b", shape=(1,))
zeros = pt.vector("zeros", shape=(None,))
initial_out = pt.full_like(zeros, a) + b

# Make broadcast to zeros irrelevant
out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False)

fg = FunctionGraph([a, b], [out], copy_inputs=False)
topological_fill_sink.rewrite(fg)
[new_out] = fg.outputs
assert equal_computations([new_out], [a + b])
Loading