Skip to content

Commit 885ff0c

Browse files
committed
fix a test that expected broadcasting to happen
1 parent 9e98224 commit 885ff0c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/tensor/rewriting/test_basic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,8 +1693,12 @@ def verify_op_count(f, count, cls):
16931693
],
16941694
)
16951695
def test_basic(self, expr, x_shape, y_shape):
1696-
x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
1697-
y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
1696+
x = at.tensor(
1697+
dtype="int64", broadcastable=tuple(s == 1 for s in x_shape), name="x"
1698+
)
1699+
y = at.tensor(
1700+
dtype="int64", broadcastable=tuple(s == 1 for s in y_shape), name="y"
1701+
)
16981702
z = expr(x, y)
16991703

17001704
z_opt = pytensor.function(

0 commit comments

Comments
 (0)