Skip to content

Commit 2d8ea78

Browse files
michaelosthegetwiecki
authored andcommitted
Cleanup Join.make_node to infer static shapes
Closes #163
1 parent 9e4c0e4 commit 2d8ea78

File tree

2 files changed

+57
-18
lines changed

2 files changed

+57
-18
lines changed

pytensor/tensor/basic.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,15 +2217,14 @@ def make_node(self, axis, *tensors):
22172217
# except for the axis dimension.
22182218
# Initialize bcastable all false, and then fill in some trues with
22192219
# the loops.
2220-
ndim = tensors[0].type.ndim
2221-
out_shape = [None] * ndim
22222220

22232221
if not isinstance(axis, int):
22242222
try:
22252223
axis = int(get_scalar_constant_value(axis))
22262224
except NotScalarConstantError:
22272225
pass
22282226

2227+
ndim = tensors[0].type.ndim
22292228
if isinstance(axis, int):
22302229
# Basically, broadcastable -> length 1, but the
22312230
# converse does not hold. So we permit e.g. T/F/T
@@ -2241,30 +2240,55 @@ def make_node(self, axis, *tensors):
22412240
)
22422241
if axis < 0:
22432242
axis += ndim
2244-
2245-
for x in tensors:
2246-
for current_axis, s in enumerate(x.type.shape):
2247-
# Constant negative axis can no longer be negative at
2248-
# this point. It safe to compare this way.
2249-
if current_axis == axis:
2250-
continue
2251-
if s == 1:
2252-
out_shape[current_axis] = 1
2253-
try:
2254-
out_shape[axis] = None
2255-
except IndexError:
2243+
if axis > ndim - 1:
22562244
raise ValueError(
22572245
f"Axis value {axis} is out of range for the given input dimensions"
22582246
)
2247+
# NOTE: Constant negative axis can no longer be negative at this point.
2248+
2249+
in_shapes = [x.type.shape for x in tensors]
2250+
in_ndims = [len(s) for s in in_shapes]
2251+
if set(in_ndims) != {ndim}:
2252+
raise TypeError(
2253+
"Only tensors with the same number of dimensions can be joined."
2254+
f" Input ndims were: {in_ndims}."
2255+
)
2256+
2257+
# Determine output shapes from a matrix of input shapes
2258+
in_shapes = np.array(in_shapes)
2259+
out_shape = [None] * ndim
2260+
for d in range(ndim):
2261+
ins = in_shapes[:, d]
2262+
if d == axis:
2263+
# Any unknown size along the axis means we can't sum
2264+
if None in ins:
2265+
out_shape[d] = None
2266+
else:
2267+
out_shape[d] = sum(ins)
2268+
else:
2269+
inset = set(in_shapes[:, d])
2270+
# Other dims must match exactly,
2271+
# or if a mix of None and ? the output will be ?
2272+
# otherwise the input shapes are incompatible.
2273+
if len(inset) == 1:
2274+
(out_shape[d],) = inset
2275+
elif len(inset - {None}) == 1:
2276+
(out_shape[d],) = inset - {None}
2277+
else:
2278+
raise ValueError(
2279+
f"all input array dimensions other than the specified `axis` ({axis})"
2280+
" must match exactly, or be unknown (None),"
2281+
f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
2282+
)
22592283
else:
22602284
# When the axis may vary, no dimension can be guaranteed to be
22612285
# broadcastable.
22622286
out_shape = [None] * tensors[0].type.ndim
22632287

2264-
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
2265-
raise TypeError(
2266-
"Only tensors with the same number of dimensions can be joined"
2267-
)
2288+
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
2289+
raise TypeError(
2290+
"Only tensors with the same number of dimensions can be joined"
2291+
)
22682292

22692293
inputs = [as_tensor_variable(axis)] + list(tensors)
22702294

tests/tensor/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,6 +1909,21 @@ def test_mixed_ndim_error(self):
19091909
with pytest.raises(TypeError, match="same number of dimensions"):
19101910
self.join_op(0, v, m)
19111911

1912+
def test_static_shape_inference(self):
1913+
a = at.tensor(dtype="int8", shape=(2, 3))
1914+
b = at.tensor(dtype="int8", shape=(2, 5))
1915+
assert at.join(1, a, b).type.shape == (2, 8)
1916+
assert at.join(-1, a, b).type.shape == (2, 8)
1917+
1918+
# Check early informative errors from static shape info
1919+
with pytest.raises(ValueError, match="must match exactly"):
1920+
at.join(0, at.ones((2, 3)), at.ones((2, 5)))
1921+
1922+
# Check partial inference
1923+
d = at.tensor(dtype="int8", shape=(2, None))
1924+
assert at.join(1, a, b, d).type.shape == (2, None)
1925+
return
1926+
19121927
def test_split_0elem(self):
19131928
rng = np.random.default_rng(seed=utt.fetch_seed())
19141929
m = self.shared(rng.random((4, 6)).astype(self.floatX))

0 commit comments

Comments
 (0)