
Commit 7775b9e

Don't introduce Second in AlgebraicCanonizer for shape specialization (only for broadcasting)
1 parent 5bb4cf0 commit 7775b9e
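
The change only matters when a constant operand pins a static shape without actually broadcasting the other operand. A rough sketch of the two situations the commit message distinguishes (my own example, not part of the commit; assumes PyTensor's tensor API):

    import numpy as np
    import pytensor.tensor as pt

    x = pt.vector("x")
    # Shape specialization: the constant only makes x's length known; the
    # canonizer should not wrap the result in Second (fill) for this.
    shape_specialized = x * np.full((5,), 2.0)
    # Broadcasting: x is genuinely repeated along a new axis, so Second is
    # still the right way to express the output shape.
    broadcasted = x * np.full((3, 5), 2.0)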

File tree: 2 files changed, 28 additions, 4 deletions


pytensor/tensor/rewriting/math.py

Lines changed: 4 additions & 2 deletions
@@ -1135,10 +1135,12 @@ def same(x, y):
         if new.type.dtype != out.type.dtype:
             new = cast(new, out.type.dtype)
 
-        if new.type != out.type:
+        if new.type.broadcastable != out.type.broadcastable:
             new = fill_chain(new, node.inputs)[0]
 
-        if new.type == out.type:
+        if (new.type.dtype == out.type.dtype) and (
+            new.type.broadcastable == out.type.broadcastable
+        ):
             new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
             copy_stack_trace(out, new)
             return [new]
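
The key distinction in the new checks is between full type equality, which in current PyTensor includes static shape information, and equality of broadcastable patterns alone. A minimal sketch of that distinction (my own, assuming TensorType accepts a shape argument as in recent PyTensor):

    from pytensor.tensor.type import TensorType

    known = TensorType("float64", shape=(5,))       # static length known
    unknown = TensorType("float64", shape=(None,))  # static length unknown

    assert known != unknown                              # types differ only in static shape
    assert known.broadcastable == unknown.broadcastable  # both are (False,): no broadcasting

With the old `new.type != out.type` check, the first difference alone was enough to send the result through fill_chain and introduce Second; the new check only does so when the broadcastable patterns differ.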

tests/tensor/rewriting/test_math.py

Lines changed: 24 additions & 2 deletions
@@ -30,7 +30,7 @@
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import debugprint
 from pytensor.tensor import inplace
-from pytensor.tensor.basic import Alloc, join, switch
+from pytensor.tensor.basic import Alloc, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -96,7 +96,7 @@
     perform_sigm_times_exp,
     simplify_mul,
 )
-from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
+from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -979,6 +979,28 @@ def test_mismatching_types(self):
         # No rewrite was applied
         assert z_rewritten is z
 
+    def test_shape_specified_by_constant(self):
+        x = vector("x")
+        const = np.full(shape=(5,), fill_value=2.0).astype(config.floatX)
+        out = x * const
+
+        new_out = rewrite_graph(
+            out, custom_rewrite=in2out(local_mul_canonizer, name="test")
+        )
+        expected_out = np.array([2.0]).astype(config.floatX) * specify_shape(x, (5,))
+        assert equal_computations([new_out], [expected_out])
+
+    def test_broadcasted_by_constant(self):
+        x = vector("x")
+        const = np.full(shape=(3, 5), fill_value=2.0).astype(config.floatX)
+        out = x * const
+
+        new_out = rewrite_graph(
+            out, custom_rewrite=in2out(local_mul_canonizer, name="test")
+        )
+        expected_out = second(const, np.array([[2.0]], dtype=config.floatX) * x)
+        assert equal_computations([new_out], [expected_out])
+
 
 def test_local_merge_abs():
     x, y, z = matrices("xyz")
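
For readers unfamiliar with the two ops the expected graphs are built from: specify_shape only asserts a static shape on its input, while second (PyTensor's fill) broadcasts its second argument to the shape of its first. A small illustrative snippet (mine, not part of the tests):

    import pytensor.tensor as pt
    from pytensor.tensor.basic import second
    from pytensor.tensor.shape import specify_shape

    x = pt.vector("x")
    asserted = specify_shape(x, (5,))     # same data, shape merely asserted as (5,)
    filled = second(pt.zeros((3, 5)), x)  # x broadcast to the (3, 5) shape of the template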

0 commit comments
