@@ -2439,27 +2439,7 @@ class Join(COp):
     """

     check_input = False
-    __props__ = ("view",)
-
-    def __init__(self, view=-1):
-        self.view = view
-        if view != -1:
-            # since the first input is always the axis, the tensors
-            # start from index 1.
-            self.view_map = {0: [1 + view]}
-
-    def __str__(self):
-        if self.view == -1:
-            return self.__class__.__name__
-        else:
-            classname = self.__class__.__name__
-            args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
-            return f"{classname}{{{args}}}"
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        if not hasattr(self, "view"):
-            self.view = -1
+    __props__ = ()

     def make_node(self, axis, *tensors):
         """
@@ -2476,74 +2456,62 @@ def make_node(self, axis, *tensors):
         if not tensors:
             raise ValueError("Cannot join an empty list of tensors")

+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
         tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])

-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
             )

         if len(tensors) == 1:
             out_shape = tensors[0].type.shape
         else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
             ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+
+            try:
+                # Note: this is dubious; if a user passed a constant we should propagate it
+                # to the inputs, not override it.
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static, we can't conclude anything about the output
+                # dimensions (unless we had some degenerate zero-sized arrays, which can be
+                # removed during rewrites).
+                # We could also raise an error if any dimensions are pairwise inconsistent
+                # across all the axes, as no matter the join it would be invalid.
+                # However, a dynamic axis is so rare that it is not worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]

                 # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                 out_shape = [None] * ndim
                 for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                         if None in ins:
                             out_shape[d] = None
                         else:
                             out_shape[d] = sum(ins)
                     else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                         # Other dims must match exactly,
                         # or if a mix of None and ? the output will be ?
                         # otherwise the input shapes are incompatible.
@@ -2553,54 +2521,27 @@ def make_node(self, axis, *tensors):
                             (out_shape[d],) = inset - {None}
                         else:
                             raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                 " must match exactly, or be unknown (None),"
                                 f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
                             )
-        else:
-            # When the axis may vary, no dimension can be guaranteed to be
-            # broadcastable.
-            out_shape = [None] * tensors[0].type.ndim
-
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
-
-        inputs = [as_tensor_variable(axis), *tensors]
-
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")

+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
         return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])

-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
-
-        else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
-
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
-            )
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )

     def c_code_cache_version(self):
         return (5,)

     def c_code(self, node, name, inputs, outputs, sub):
         axis, tens = inputs[0], inputs[1:]
-        view = self.view
+        view = -1
         non_empty_tensor = tens[view]
         input_1 = tens[0]
         l = len(tens)
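The static output-shape rule implemented in the new `make_node` above can be summarized with a small standalone sketch (illustrative only; `join_static_shape` is a hypothetical helper, not part of this PR): along the join axis the known sizes are summed and any unknown size makes the result unknown, while every other dimension must agree across inputs, with `None` acting as a wildcard.

def join_static_shape(shapes, axis):
    # shapes: one static shape per input, e.g. (2, None, 5); None means unknown
    ndim = len(shapes[0])
    out = []
    for d in range(ndim):
        dims = [s[d] for s in shapes]
        if d == axis:
            # the join axis is summed; any unknown input size makes it unknown
            out.append(None if None in dims else sum(dims))
        else:
            known = set(dims) - {None}
            if len(known) > 1:
                raise ValueError(f"incompatible sizes along dimension {d}: {dims}")
            # a single known size wins over None; all-None stays unknown
            out.append(known.pop() if known else None)
    return tuple(out)

# join_static_shape([(2, 5), (3, 5)], 0)     -> (5, 5)
# join_static_shape([(2, None), (3, 5)], 0)  -> (5, 5)
# join_static_shape([(None, 5), (3, 5)], 0)  -> (None, 5)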
@@ -2656,22 +2597,21 @@ def R_op(self, inputs, eval_points):
             return [None]
         return self.make_node(inputs[0], *eval_points[1:]).outputs

-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
         """The gradient wrt a join op is a `Split`, used to partition
         the gradient along the `axis` which was used for joining.
         """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs

         rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype

         if "float" in out_dtype or "complex" in out_dtype:
             # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
             # If there is only one split, it might not be in a list.
             if not isinstance(split_gz, list):
                 split_gz = [split_gz]
@@ -2684,13 +2624,12 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]

         return rval

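The `L_op` above encodes the usual fact that the adjoint of a concatenation is a split of the output gradient back into segments with the inputs' sizes along the join axis. A rough NumPy analogue of that bookkeeping (a sketch, not the PyTensor implementation):

import numpy as np

def join_grad_numpy(arrays, g_out, axis):
    # split the output gradient at the cumulative input sizes along the join axis
    sizes = [a.shape[axis] for a in arrays]
    return np.split(g_out, np.cumsum(sizes)[:-1], axis=axis)

x, y = np.ones((2, 3)), np.ones((4, 3))
g = np.arange(18.0).reshape(6, 3)            # gradient w.r.t. join([x, y], axis=0)
gx, gy = join_grad_numpy([x, y], g, axis=0)
assert gx.shape == x.shape and gy.shape == y.shape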
@@ -2710,7 +2649,8 @@ def infer_shape(self, fgraph, node, ishapes):
         # An axis < -n_dim or >= ndim would be invalid, but this is
         # not checked here. A `CheckAndRaise` `Op` would be a way of
         # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
         out_shapes = []
         for dim in range(n_dim):
             # we have to deal with 2 possible cases in here:
@@ -2733,7 +2673,7 @@ def infer_shape(self, fgraph, node, ishapes):
         return [tuple(out_shapes)]


-join_ = Join()
+_join = Join()
 pprint.assign(Join, printing.FunctionPrinter(["join"]))

@@ -2776,7 +2716,7 @@ def join(axis, *tensors_list):
     if len(tensors_list) == 1:
         return tensors_list[0]
     else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)


 @_vectorize_node.register(Join)
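For reference, a quick usage sketch of how the reworked shape inference is expected to surface to users (assuming the usual `pytensor.tensor` entry points; the printed shapes are what the new `make_node` should infer, not verified output):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 5))
y = pt.tensor("y", shape=(3, 5))

z = pt.join(0, x, y)                    # static axis: sizes along axis 0 are summed
print(z.type.shape)                     # expected (5, 5)

a = pt.join(pt.iscalar("axis"), x, x)   # dynamic axis: nothing can be inferred statically
print(a.type.shape)                     # expected (None, None)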