|
30 | 30 | from pytensor.misc.safe_asarray import _asarray
|
31 | 31 | from pytensor.printing import debugprint
|
32 | 32 | from pytensor.tensor import inplace
|
33 |
| -from pytensor.tensor.basic import Alloc, join, switch |
| 33 | +from pytensor.tensor.basic import Alloc, join, second, switch |
34 | 34 | from pytensor.tensor.blas import Dot22, Gemv
|
35 | 35 | from pytensor.tensor.blas_c import CGemv
|
36 | 36 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
|
96 | 96 | perform_sigm_times_exp,
|
97 | 97 | simplify_mul,
|
98 | 98 | )
|
99 |
| -from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape |
| 99 | +from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape |
100 | 100 | from pytensor.tensor.type import (
|
101 | 101 | TensorType,
|
102 | 102 | cmatrix,
|
@@ -979,6 +979,28 @@ def test_mismatching_types(self):
|
979 | 979 | # No rewrite was applied
|
980 | 980 | assert z_rewritten is z
|
981 | 981 |
|
| 982 | + def test_shape_specified_by_constant(self): |
| 983 | + x = vector("x") |
| 984 | + const = np.full(shape=(5,), fill_value=2.0).astype(config.floatX) |
| 985 | + out = x * const |
| 986 | + |
| 987 | + new_out = rewrite_graph( |
| 988 | + out, custom_rewrite=in2out(local_mul_canonizer, name="test") |
| 989 | + ) |
| 990 | + expected_out = np.array([2.0]).astype(config.floatX) * specify_shape(x, (5,)) |
| 991 | + assert equal_computations([new_out], [expected_out]) |
| 992 | + |
| 993 | + def test_broadcasted_by_constant(self): |
| 994 | + x = vector("x") |
| 995 | + const = np.full(shape=(3, 5), fill_value=2.0).astype(config.floatX) |
| 996 | + out = x * const |
| 997 | + |
| 998 | + new_out = rewrite_graph( |
| 999 | + out, custom_rewrite=in2out(local_mul_canonizer, name="test") |
| 1000 | + ) |
| 1001 | + expected_out = second(const, np.array([[2.0]], dtype=config.floatX) * x) |
| 1002 | + assert equal_computations([new_out], [expected_out]) |
| 1003 | + |
982 | 1004 |
|
983 | 1005 | def test_local_merge_abs():
|
984 | 1006 | x, y, z = matrices("xyz")
|
|
0 commit comments