File tree 2 files changed +12
-2
lines changed
pytensor/tensor/rewriting
2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node):
2072
2072
rval = [reciprocal (sqr (xsym ))]
2073
2073
if rval :
2074
2074
rval [0 ] = cast (rval [0 ], odtype )
2075
- assert rval [0 ].type == node .outputs [0 ].type , (rval , node .outputs )
2075
+ assert rval [0 ].type .is_super (node .outputs [0 ].type ), (
2076
+ rval [0 ].type ,
2077
+ node .outputs [0 ].type ,
2078
+ )
2076
2079
return rval
2077
2080
else :
2078
2081
return False
Original file line number Diff line number Diff line change 96
96
perform_sigm_times_exp ,
97
97
simplify_mul ,
98
98
)
99
- from pytensor .tensor .shape import Reshape , Shape_i
99
+ from pytensor .tensor .shape import Reshape , Shape_i , SpecifyShape
100
100
from pytensor .tensor .type import (
101
101
TensorType ,
102
102
cmatrix ,
@@ -1671,6 +1671,13 @@ def test_local_pow_specialize():
1671
1671
assert isinstance (nodes [1 ].scalar_op , aes .basic .Reciprocal )
1672
1672
utt .assert_allclose (f (val_no0 ), val_no0 ** (- 0.5 ))
1673
1673
1674
+ twos = np .full (shape = (10 ,), fill_value = 2.0 )
1675
+ f = function ([v ], v ** twos )
1676
+ ops = [node .op for node in f .maker .fgraph .toposort ()]
1677
+ assert isinstance (ops [0 ], SpecifyShape )
1678
+ assert ops [1 ] == sqr
1679
+ utt .assert_allclose (f (val ), val ** twos )
1680
+
1674
1681
1675
1682
def test_local_pow_specialize_device_more_aggressive_on_cpu ():
1676
1683
mode = config .mode
You can’t perform that action at this time.
0 commit comments