|
8 | 8 |
|
9 | 9 | import pytensor
|
10 | 10 | import pytensor.scalar.basic as aes
|
| 11 | +from pytensor import config |
11 | 12 | from pytensor.gradient import (
|
12 | 13 | DisconnectedType,
|
13 | 14 | _float_zeros_like,
|
@@ -1598,24 +1599,27 @@ def broadcast_shape_iter(
|
1598 | 1599 | aes.get_scalar_type(dtype=v.dtype)()
|
1599 | 1600 | for v in scalar_maybe_non_bcast_shapes
|
1600 | 1601 | ]
|
1601 |
| - non_bcast_vec = [ |
1602 |
| - aes.switch(aes.eq(nbv, 1), -one_at, nbv) |
1603 |
| - for nbv in dummy_maybe_non_bcast_shapes |
1604 |
| - ] |
1605 |
| - dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) |
1606 |
| - dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max]) |
1607 |
| - |
1608 |
| - dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes) |
1609 |
| - |
1610 |
| - assert_dim = Assert("Could not broadcast dimensions") |
1611 |
| - assert_cond = reduce( |
1612 |
| - aes.and_, |
1613 |
| - ( |
1614 |
| - aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max)) |
1615 |
| - for nbv in non_bcast_vec |
1616 |
| - ), |
1617 |
| - ) |
1618 |
| - assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond]) |
| 1602 | + with config.change_flags(compute_test_value="off"): |
| 1603 | + non_bcast_vec = [ |
| 1604 | + aes.switch(aes.eq(nbv, 1), -one_at, nbv) |
| 1605 | + for nbv in dummy_maybe_non_bcast_shapes |
| 1606 | + ] |
| 1607 | + dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) |
| 1608 | + dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max]) |
| 1609 | + |
| 1610 | + dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes) |
| 1611 | + |
| 1612 | + assert_dim = Assert("Could not broadcast dimensions") |
| 1613 | + assert_cond = reduce( |
| 1614 | + aes.and_, |
| 1615 | + ( |
| 1616 | + aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max)) |
| 1617 | + for nbv in non_bcast_vec |
| 1618 | + ), |
| 1619 | + ) |
| 1620 | + assert_cond_op = Composite( |
| 1621 | + dummy_maybe_non_bcast_shapes, [assert_cond] |
| 1622 | + ) |
1619 | 1623 |
|
1620 | 1624 | bcast_dim = assert_dim(
|
1621 | 1625 | dim_max_op(*scalar_maybe_non_bcast_shapes),
|
|
0 commit comments