Skip to content

Commit 0d3126a

Browse files
committed
Fix type check is local_pow_specialize
1 parent 2f0ed25 commit 0d3126a

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node):
20722072
rval = [reciprocal(sqr(xsym))]
20732073
if rval:
20742074
rval[0] = cast(rval[0], odtype)
2075-
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
2075+
assert rval[0].type.is_super(node.outputs[0].type), (
2076+
rval[0].type,
2077+
node.outputs[0].type,
2078+
)
20762079
return rval
20772080
else:
20782081
return False

tests/tensor/rewriting/test_math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
perform_sigm_times_exp,
9797
simplify_mul,
9898
)
99-
from pytensor.tensor.shape import Reshape, Shape_i
99+
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
100100
from pytensor.tensor.type import (
101101
TensorType,
102102
cmatrix,
@@ -1671,6 +1671,13 @@ def test_local_pow_specialize():
16711671
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
16721672
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
16731673

1674+
twos = np.full(shape=(10,), fill_value=2.0)
1675+
f = function([v], v**twos)
1676+
ops = [node.op for node in f.maker.fgraph.toposort()]
1677+
assert isinstance(ops[0], SpecifyShape)
1678+
assert ops[1] == sqr
1679+
utt.assert_allclose(f(val), val**twos)
1680+
16741681

16751682
def test_local_pow_specialize_device_more_aggressive_on_cpu():
16761683
mode = config.mode

0 commit comments

Comments
 (0)