Skip to content

Commit 8c113f1

Browse files
committed
Don't compute (or complain about) test_value in the dummy inner graph of broadcast_shape_iter
1 parent 3d21687 commit 8c113f1

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pytensor
1010
import pytensor.scalar.basic as aes
11+
from pytensor import config
1112
from pytensor.gradient import (
1213
DisconnectedType,
1314
_float_zeros_like,
@@ -1598,24 +1599,27 @@ def broadcast_shape_iter(
15981599
aes.get_scalar_type(dtype=v.dtype)()
15991600
for v in scalar_maybe_non_bcast_shapes
16001601
]
1601-
non_bcast_vec = [
1602-
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
1603-
for nbv in dummy_maybe_non_bcast_shapes
1604-
]
1605-
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
1606-
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
1607-
1608-
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
1609-
1610-
assert_dim = Assert("Could not broadcast dimensions")
1611-
assert_cond = reduce(
1612-
aes.and_,
1613-
(
1614-
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
1615-
for nbv in non_bcast_vec
1616-
),
1617-
)
1618-
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
1602+
with config.change_flags(compute_test_value="off"):
1603+
non_bcast_vec = [
1604+
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
1605+
for nbv in dummy_maybe_non_bcast_shapes
1606+
]
1607+
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
1608+
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
1609+
1610+
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
1611+
1612+
assert_dim = Assert("Could not broadcast dimensions")
1613+
assert_cond = reduce(
1614+
aes.and_,
1615+
(
1616+
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
1617+
for nbv in non_bcast_vec
1618+
),
1619+
)
1620+
assert_cond_op = Composite(
1621+
dummy_maybe_non_bcast_shapes, [assert_cond]
1622+
)
16191623

16201624
bcast_dim = assert_dim(
16211625
dim_max_op(*scalar_maybe_non_bcast_shapes),

0 commit comments

Comments
 (0)