Skip to content

Commit fdd0eed

Browse files
committed
Fix type check is local_pow_specialize
1 parent 2f0ed25 commit fdd0eed

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node):
20722072
rval = [reciprocal(sqr(xsym))]
20732073
if rval:
20742074
rval[0] = cast(rval[0], odtype)
2075-
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
2075+
assert rval[0].type.is_super(node.outputs[0].type), (
2076+
rval[0].type,
2077+
node.outputs[0].type,
2078+
)
20762079
return rval
20772080
else:
20782081
return False

tests/tensor/rewriting/test_math.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
perform_sigm_times_exp,
9797
simplify_mul,
9898
)
99-
from pytensor.tensor.shape import Reshape, Shape_i
99+
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
100100
from pytensor.tensor.type import (
101101
TensorType,
102102
cmatrix,
@@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
16711671
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
16721672
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
16731673

1674+
twos = np.full(shape=(10,), fill_value=2.0)
1675+
f = function([v], v**twos, mode=mode)
1676+
topo = f.maker.fgraph.toposort()
1677+
assert len(topo) == 2
1678+
# Depending on the mode the SpecifyShape is lifted or not
1679+
if topo[0].op == sqr:
1680+
assert isinstance(topo[1].op, SpecifyShape)
1681+
else:
1682+
assert isinstance(topo[0].op, SpecifyShape)
1683+
assert topo[1].op == sqr
1684+
utt.assert_allclose(f(val), val**twos)
1685+
16741686

16751687
def test_local_pow_specialize_device_more_aggressive_on_cpu():
16761688
mode = config.mode

0 commit comments

Comments
 (0)