
Commit f1d19f7

Incorporate static shape of Alloc input
1 parent 2f8697e commit f1d19f7

3 files changed (+39, -28 lines)

pytensor/tensor/basic.py (+23, -7)
@@ -1432,17 +1432,33 @@ class Alloc(COp):
     __props__ = ()
 
     def make_node(self, value, *shape):
-        v = as_tensor_variable(value)
-        sh, static_shape = infer_static_shape(shape)
-        if v.ndim > len(sh):
+        value = as_tensor_variable(value)
+        shape, static_shape = infer_static_shape(shape)
+        if value.ndim > len(shape):
             raise TypeError(
                 "The Alloc value to use has more dimensions"
                 " than the specified dimensions",
-                v.ndim,
-                len(sh),
+                value.ndim,
+                len(shape),
             )
-        otype = TensorType(dtype=v.dtype, shape=static_shape)
-        return Apply(self, [v] + sh, [otype()])
+
+        # Combine static shape information from value and shape
+        combined_static_shape = list(static_shape).copy()
+        new_dims = len(shape) - value.type.ndim
+        extended_value_static_shape = (None,) * new_dims + (value.type.shape)
+        for i, (v_st, sh_st) in enumerate(
+            zip(extended_value_static_shape, static_shape)
+        ):
+            if (v_st not in (1, None)) and (sh_st is None):
+                combined_static_shape[i] = v_st
+            elif (v_st is not None) and (sh_st is not None):
+                if v_st != sh_st and v_st != 1:
+                    raise ValueError(
+                        f"Alloc static input shape and target shape are incompatible: {value.type.shape} vs {static_shape}"
+                    )
+
+        otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
+        return Apply(self, [value] + shape, [otype()])
 
     def perform(self, node, inputs, out_):
         (out,) = out_
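
The loop added above merges the value's static shape with the requested target shape one dimension at a time, broadcasting-style: a known non-broadcastable dimension on the value fills in an unknown target dimension, a dimension of 1 on the value defers to the target, and a hard mismatch raises the new ValueError. A minimal standalone sketch of that rule (the function merge_static_shapes and its argument names are ours for illustration, not part of the PyTensor API):

# Illustrative only: a standalone restatement of the per-dimension merge rule
# used in Alloc.make_node above. `None` marks a statically unknown dimension.
def merge_static_shapes(value_shape, target_shape):
    # The value's shape is right-aligned against the target; missing leading
    # dimensions count as unknown.
    new_dims = len(target_shape) - len(value_shape)
    value_shape = (None,) * new_dims + tuple(value_shape)

    combined = list(target_shape)
    for i, (v_st, sh_st) in enumerate(zip(value_shape, target_shape)):
        if v_st not in (1, None) and sh_st is None:
            # A known, non-broadcastable value dim pins an unknown target dim.
            combined[i] = v_st
        elif v_st is not None and sh_st is not None and v_st != sh_st and v_st != 1:
            # Both dims are known and cannot be reconciled by broadcasting.
            raise ValueError(f"incompatible static shapes: {value_shape} vs {target_shape}")
    return tuple(combined)


# Mirrors the assertions in the new test further down:
assert merge_static_shapes((None, 1, 5), (3, 4, 5)) == (3, 4, 5)
assert merge_static_shapes((None, 1, 5), (None, 1, None)) == (None, 1, 5)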

tests/tensor/rewriting/test_basic.py (-21)
@@ -272,27 +272,6 @@ class TestLocalCanonicalizeAlloc:
     def setup_method(self):
         self.rng = np.random.default_rng(utt.fetch_seed())
 
-    def test_inconsistent_constant(self):
-        x = at.as_tensor(self.rng.standard_normal((3, 7)))
-        a = at.alloc(x, 6, 7)
-
-        assert a.owner and isinstance(a.owner.op, Alloc)
-
-        # `local_useless_alloc` should attempt to replace the `Alloc` with an
-        # `Assert` and fail when the static shape information conflicts.
-        with pytest.raises(TypeError):
-            f = function([], a, mode=rewrite_mode)
-
-        x = at.as_tensor(self.rng.standard_normal((6, 7)))
-        a = at.alloc(x, 6, 7)
-
-        f = function([], a, mode=rewrite_mode)
-
-        # The rewrite should then be applied, and remove Alloc
-        assert not any(
-            isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort()
-        )
-
     def test_inconsistent_shared(self):
         # These shapes don't match!
         x = shared(self.rng.standard_normal((3, 7)))
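
The removed test relied on the inconsistency only surfacing during rewriting, as a TypeError raised while compiling the function. With the make_node change above, the same mismatch should now be rejected as soon as the Alloc node is constructed, roughly like this (a sketch, not taken from the test suite):

import numpy as np
import pytensor.tensor as at

x = at.as_tensor(np.zeros((3, 7)))  # constant with static shape (3, 7)

# Building the graph should now fail immediately with the new ValueError,
# before any rewrite or compilation step runs.
try:
    at.alloc(x, 6, 7)
except ValueError as err:
    print(err)  # "Alloc static input shape and target shape are incompatible: ..."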

tests/tensor/test_basic.py (+16)
@@ -835,6 +835,22 @@ def test_rebuild(self, func):
         assert y_new.shape.eval({x_new: x_new_test}) == (100,)
         assert y_new.eval({x_new: x_new_test}).shape == (100,)
 
+    def test_static_shape(self):
+        x = tensor(shape=(None, 1, 5))
+        d0 = scalar("d0", dtype=int)
+        d1 = scalar("d1", dtype=int)
+        assert at.alloc(x, 3, 1, 5).type.shape == (3, 1, 5)
+        assert at.alloc(x, 3, 4, 5).type.shape == (3, 4, 5)
+        assert at.alloc(x, d0, d1, 5).type.shape == (None, None, 5)
+        assert at.alloc(x, d0, 1, d1).type.shape == (None, 1, 5)
+
+        msg = "Alloc static input shape and target shape are incompatible"
+        with pytest.raises(ValueError, match=msg):
+            at.alloc(x, 3, 1, 1)
+
+        with pytest.raises(ValueError, match=msg):
+            at.alloc(x, 3, 1, 6)
+
 
 def test_infer_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
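
As the new test shows, the practical effect is that the Alloc output's type now carries whatever the value's type already knows. A quick interactive check along the same lines (imports mirror the conventions of this test suite):

import pytensor.tensor as at
from pytensor.tensor import scalar, tensor

x = tensor(shape=(None, 1, 5))   # last dimension statically known to be 5
d0 = scalar("d0", dtype=int)     # symbolic (unknown) target dimensions
d1 = scalar("d1", dtype=int)

y = at.alloc(x, d0, 1, d1)
print(y.type.shape)              # (None, 1, 5): the trailing 5 comes from x, not the target shape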
