@@ -2439,27 +2439,7 @@ class Join(COp):
    """

    check_input = False
-    __props__ = ("view",)
-
-    def __init__(self, view=-1):
-        self.view = view
-        if view != -1:
-            # since the first input is always the axis, the tensors
-            # start from index 1.
-            self.view_map = {0: [1 + view]}
-
-    def __str__(self):
-        if self.view == -1:
-            return self.__class__.__name__
-        else:
-            classname = self.__class__.__name__
-            args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
-            return f"{classname}{{{args}}}"
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        if not hasattr(self, "view"):
-            self.view = -1
+    __props__ = ()

    def make_node(self, axis, *tensors):
        """
@@ -2476,74 +2456,62 @@ def make_node(self, axis, *tensors):
        if not tensors:
            raise ValueError("Cannot join an empty list of tensors")

+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
        tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])

-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
            raise TypeError(
                "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
            )

        if len(tensors) == 1:
            out_shape = tensors[0].type.shape
        else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
            ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+
+            try:
+                # Note: this is dubious; if a user passed a constant we should propagate it to the inputs,
+                # not override it.
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static, we can't conclude anything about the output dimensions
+                # (unless there are degenerate zero-length arrays, which can be removed during rewrites).
+                # We could also raise an error if any dimension is pairwise inconsistent across all axes,
+                # as the join would be invalid no matter which axis is used.
+                # However, a dynamic axis is so rare that it is not worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]

                # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                out_shape = [None] * ndim
                for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                        if None in ins:
                            out_shape[d] = None
                        else:
                            out_shape[d] = sum(ins)
                    else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                        # Other dims must match exactly,
                        # or if a mix of None and ? the output will be ?
                        # otherwise the input shapes are incompatible.
@@ -2553,100 +2521,71 @@ def make_node(self, axis, *tensors):
                            (out_shape[d],) = inset - {None}
                        else:
                            raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                " must match exactly, or be unknown (None),"
                                f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
                            )
-        else:
-            # When the axis may vary, no dimension can be guaranteed to be
-            # broadcastable.
-            out_shape = [None] * tensors[0].type.ndim
-
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
-
-        inputs = [as_tensor_variable(axis), *tensors]
-
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")

+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
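
To make the static-shape behaviour above concrete, here is a small hedged sketch (not part of the patch; it assumes PyTensor's public `pytensor.tensor` API): known sizes along the join axis are summed, while an unknown (`None`) size anywhere along that axis makes the output size unknown.

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4, 3))
z = pt.tensor("z", shape=(None, 3))

print(pt.join(0, x, y).type.shape)  # expected (6, 3): sizes along axis 0 are summed
print(pt.join(0, x, z).type.shape)  # expected (None, 3): unknown size along the join axis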

-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
-
-        else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
-
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
-            )
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )
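
The new `perform` defers entirely to NumPy. As a hedged illustration (plain NumPy, not part of the patch; the `dtype` argument of `np.concatenate` requires NumPy >= 1.20), the result is concatenated directly into the upcast dtype that `make_node` computed:

import numpy as np

a = np.zeros((2, 3), dtype="float32")
b = np.ones((4, 3), dtype="float64")
# concatenate along axis 0 and cast straight into the upcast output dtype
out = np.concatenate([a, b], axis=0, dtype="float64")
print(out.shape, out.dtype)  # (6, 3) float64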

    def c_code_cache_version(self):
-        return (5,)
+        return (6,)

    def c_code(self, node, name, inputs, outputs, sub):
-        axis, tens = inputs[0], inputs[1:]
-        view = self.view
-        non_empty_tensor = tens[view]
-        input_1 = tens[0]
-        l = len(tens)
-        (out,) = outputs
+        axis, *arrays = inputs
+        [out] = outputs
+        n = len(arrays)
+        ndim = node.outputs[0].type.ndim
        fail = sub["fail"]
-        adtype = node.inputs[0].type.dtype_specs()[1]

-        copy_to_list = (
-            f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
-            for i, inp in enumerate(tens)
-        )
+        # Most times axis is constant, inline it
+        # This is safe to do because the hash of the c_code includes the constant signature
+        if isinstance(node.inputs[0], Constant):
+            static_axis = int(node.inputs[0].data)
+            static_axis = normalize_axis_index(static_axis, ndim)
+            axis_def = f"{static_axis};"
+            axis_check = ""
+        else:
+            axis_dtype = node.inputs[0].type.dtype_specs()[1]
+            axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
+            axis_check = f"""
+            if (axis < 0){{
+                axis = {ndim} + axis;
+            }}
+            if (axis >= {ndim} || axis < 0) {{
+                PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
+                {fail}
+            }}
+            """

-        copy_inputs_to_list = "\n".join(copy_to_list)
-        n = len(tens)
+        copy_arrays_to_tuple = "\n".join(
+            (
+                f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});"""
+                for i, array in enumerate(arrays)
+            )
+        )

        code = f"""
-        int axis = (({adtype} *)PyArray_DATA({axis}))[0];
-        PyObject* list = PyList_New({l});
-        {copy_inputs_to_list}
-        int tensors_lens_sum;
-        if({view} != -1) {{
-            tensors_lens_sum = 0;
-
-            for(int i=0; i < {n}; i++){{
-                tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
-            }}
-            tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
-        }}
-        if({view} != -1 && tensors_lens_sum == 0) {{
-            Py_XDECREF({out});
-            Py_INCREF({non_empty_tensor});
-            {out} = {non_empty_tensor};
-        }}else{{
-            //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
-            int ndim = PyArray_NDIM({input_1});
-            if( axis < -ndim ){{
-                PyErr_Format(PyExc_IndexError,
-                             "Join axis %d out of bounds [0, %d)", axis, ndim);
-                {fail}
-            }}
-            Py_XDECREF({out});
-            {out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
-            Py_DECREF(list);
-            if(!{out}){{
-                {fail}
-            }}
+        int axis = {axis_def}
+        PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
+
+        {axis_check}
+
+        Py_XDECREF({out});
+        PyObject* arrays_tuple = PyTuple_New({n});
+        {copy_arrays_to_tuple}
+        {out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
+        Py_DECREF(arrays_tuple);
+        if(!{out}){{
+            {fail}
+        }}
        }}
        """
        return code
@@ -2656,22 +2595,21 @@ def R_op(self, inputs, eval_points):
            return [None]
        return self.make_node(inputs[0], *eval_points[1:]).outputs

-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
        """The gradient wrt a join op is a `Split`, used to partition
        the gradient along the `axis` which was used for joining.
        """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs

        rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype

        if "float" in out_dtype or "complex" in out_dtype:
            # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
            # If there is only one split, it might not be in a list.
            if not isinstance(split_gz, list):
                split_gz = [split_gz]
@@ -2684,13 +2622,12 @@ def grad(self, axis_and_tensors, grads):
                else specify_broadcastable(
                    g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
            ]
            rval = rval + split_gz
        else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]

        return rval
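
As a hedged usage sketch of the `L_op` above (assuming PyTensor's `pytensor.grad` and `pt.join`; not part of the patch), the gradient of a join is simply the upstream gradient split back into pieces that match the joined inputs along the join axis:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(2,))
y = pt.vector("y", shape=(3,))
cost = pt.join(0, x, y).sum()
# each input receives the slice of the output gradient that corresponds to it
gx, gy = pytensor.grad(cost, [x, y])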
@@ -2710,7 +2647,8 @@ def infer_shape(self, fgraph, node, ishapes):
        # An axis < -n_dim or >= ndim would be invalid, but this is
        # not checked here. A `CheckAndRaise` `Op` would be a way of
        # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
        out_shapes = []
        for dim in range(n_dim):
            # we have to deal with 2 possible cases in here :
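
A brief hedged note on the symbolic axis normalization above (plain Python, not part of the patch): a negative axis is shifted by the input rank, which is exactly what the `switch(ge(axis, 0), axis, axis + n_dim)` expression computes.

def normalize_join_dim(axis: int, n_dim: int) -> int:
    # mirrors switch(ge(axis, 0), axis, axis + n_dim)
    return axis if axis >= 0 else axis + n_dim

assert normalize_join_dim(-1, 3) == 2  # last dimension of a 3-d input
assert normalize_join_dim(1, 3) == 1   # non-negative axes pass through unchanged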
@@ -2733,7 +2671,7 @@ def infer_shape(self, fgraph, node, ishapes):
        return [tuple(out_shapes)]


-join_ = Join()
+_join = Join()
pprint.assign(Join, printing.FunctionPrinter(["join"]))

@@ -2776,7 +2714,7 @@ def join(axis, *tensors_list):
    if len(tensors_list) == 1:
        return tensors_list[0]
    else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)


@_vectorize_node.register(Join)