File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node):
20722072 rval = [reciprocal (sqr (xsym ))]
20732073 if rval :
20742074 rval [0 ] = cast (rval [0 ], odtype )
2075- assert rval [0 ].type == node .outputs [0 ].type , (rval , node .outputs )
2075+ assert rval [0 ].type .is_super (node .outputs [0 ].type ), (
2076+ rval [0 ].type ,
2077+ node .outputs [0 ].type ,
2078+ )
20762079 return rval
20772080 else :
20782081 return False
Original file line number Diff line number Diff line change 9696 perform_sigm_times_exp ,
9797 simplify_mul ,
9898)
99- from pytensor .tensor .shape import Reshape , Shape_i
99+ from pytensor .tensor .shape import Reshape , Shape_i , SpecifyShape
100100from pytensor .tensor .type import (
101101 TensorType ,
102102 cmatrix ,
@@ -1671,6 +1671,13 @@ def test_local_pow_specialize():
16711671 assert isinstance (nodes [1 ].scalar_op , aes .basic .Reciprocal )
16721672 utt .assert_allclose (f (val_no0 ), val_no0 ** (- 0.5 ))
16731673
1674+ twos = np .full (shape = (10 ,), fill_value = 2.0 )
1675+ f = function ([v ], v ** twos )
1676+ ops = [node .op for node in f .maker .fgraph .toposort ()]
1677+ assert isinstance (ops [0 ], SpecifyShape )
1678+ assert ops [1 ] == sqr
1679+ utt .assert_allclose (f (val ), val ** twos )
1680+
16741681
16751682def test_local_pow_specialize_device_more_aggressive_on_cpu ():
16761683 mode = config .mode
You can’t perform that action at this time.
0 commit comments