@@ -2217,15 +2217,14 @@ def make_node(self, axis, *tensors):
2217
2217
# except for the axis dimension.
2218
2218
# Initialize bcastable all false, and then fill in some trues with
2219
2219
# the loops.
2220
- ndim = tensors [0 ].type .ndim
2221
- out_shape = [None ] * ndim
2222
2220
2223
2221
if not isinstance (axis , int ):
2224
2222
try :
2225
2223
axis = int (get_scalar_constant_value (axis ))
2226
2224
except NotScalarConstantError :
2227
2225
pass
2228
2226
2227
+ ndim = tensors [0 ].type .ndim
2229
2228
if isinstance (axis , int ):
2230
2229
# Basically, broadcastable -> length 1, but the
2231
2230
# converse does not hold. So we permit e.g. T/F/T
@@ -2241,30 +2240,55 @@ def make_node(self, axis, *tensors):
2241
2240
)
2242
2241
if axis < 0 :
2243
2242
axis += ndim
2244
-
2245
- for x in tensors :
2246
- for current_axis , s in enumerate (x .type .shape ):
2247
- # Constant negative axis can no longer be negative at
2248
- # this point. It safe to compare this way.
2249
- if current_axis == axis :
2250
- continue
2251
- if s == 1 :
2252
- out_shape [current_axis ] = 1
2253
- try :
2254
- out_shape [axis ] = None
2255
- except IndexError :
2243
+ if axis > ndim - 1 :
2256
2244
raise ValueError (
2257
2245
f"Axis value { axis } is out of range for the given input dimensions"
2258
2246
)
2247
+ # NOTE: Constant negative axis can no longer be negative at this point.
2248
+
2249
+ in_shapes = [x .type .shape for x in tensors ]
2250
+ in_ndims = [len (s ) for s in in_shapes ]
2251
+ if set (in_ndims ) != {ndim }:
2252
+ raise TypeError (
2253
+ "Only tensors with the same number of dimensions can be joined."
2254
+ f" Input ndims were: { in_ndims } ."
2255
+ )
2256
+
2257
+ # Determine output shapes from a matrix of input shapes
2258
+ in_shapes = np .array (in_shapes )
2259
+ out_shape = [None ] * ndim
2260
+ for d in range (ndim ):
2261
+ ins = in_shapes [:, d ]
2262
+ if d == axis :
2263
+ # Any unknown size along the axis means we can't sum
2264
+ if None in ins :
2265
+ out_shape [d ] = None
2266
+ else :
2267
+ out_shape [d ] = sum (ins )
2268
+ else :
2269
+ inset = set (in_shapes [:, d ])
2270
+ # Other dims must match exactly,
2271
+ # or if a mix of None and ? the output will be ?
2272
+ # otherwise the input shapes are incompatible.
2273
+ if len (inset ) == 1 :
2274
+ (out_shape [d ],) = inset
2275
+ elif len (inset - {None }) == 1 :
2276
+ (out_shape [d ],) = inset - {None }
2277
+ else :
2278
+ raise ValueError (
2279
+ f"all input array dimensions other than the specified `axis` ({ axis } )"
2280
+ " must match exactly, or be unknown (None),"
2281
+ f" but along dimension { d } , the inputs shapes are incompatible: { ins } "
2282
+ )
2259
2283
else :
2260
2284
# When the axis may vary, no dimension can be guaranteed to be
2261
2285
# broadcastable.
2262
2286
out_shape = [None ] * tensors [0 ].type .ndim
2263
2287
2264
- if not builtins .all (x .ndim == len (out_shape ) for x in tensors ):
2265
- raise TypeError (
2266
- "Only tensors with the same number of dimensions can be joined"
2267
- )
2288
+ if not builtins .all (x .ndim == len (out_shape ) for x in tensors ):
2289
+ raise TypeError (
2290
+ "Only tensors with the same number of dimensions can be joined"
2291
+ )
2268
2292
2269
2293
inputs = [as_tensor_variable (axis )] + list (tensors )
2270
2294
0 commit comments