11
11
from numpy .lib .mixins import NDArrayOperatorsMixin
12
12
13
13
from .slicing import normalize_index
14
- from .utils import _zero_of_dtype , isscalar
14
+ from .utils import _zero_of_dtype , isscalar , PositinalArgumentPartial
15
15
from .sparse_array import SparseArray
16
16
from .compatibility import int , zip_longest , range , zip
17
17
@@ -1895,6 +1895,24 @@ def tril(x, k=0):
1895
1895
1896
1896
1897
1897
def _nary_match (* arrays ):
1898
+ """
1899
+ Matches coordinates from N different 1-D arrays. Equivalent to
1900
+ an SQL outer join.
1901
+
1902
+ Parameters
1903
+ ----------
1904
+ arrays : tuple[numpy.ndarray]
1905
+ Input arrays to match, must be sorted.
1906
+
1907
+ Returns
1908
+ -------
1909
+ matched : numpy.ndarray
1910
+ The overall matched array.
1911
+
1912
+ matched_idx : list[numpy.ndarray]
1913
+ The indices for the matched coordinates in each array.
1914
+ """
1915
+
1898
1916
matched = arrays [0 ]
1899
1917
matched_idx = [np .arange (arrays [0 ].shape [0 ],
1900
1918
dtype = np .min_scalar_type (arrays [0 ].shape [0 ] - 1 ))]
@@ -2022,28 +2040,25 @@ def elemwise(func, *args, **kwargs):
2022
2040
args = list (args )
2023
2041
posargs = []
2024
2042
pos = []
2025
- for i in range (len (args )):
2026
- if isinstance (args [i ], scipy .sparse .spmatrix ):
2027
- args [i ] = COO .from_scipy_sparse (args [i ])
2028
-
2029
- if isscalar (args [i ]) or (isinstance (args [i ], np .ndarray )
2030
- and not args [i ].shape ):
2043
+ for i , arg in enumerate (args ):
2044
+ if isinstance (arg , scipy .sparse .spmatrix ):
2045
+ args [i ] = COO .from_scipy_sparse (arg )
2046
+ elif isscalar (arg ) or (isinstance (arg , np .ndarray )
2047
+ and not arg .shape ):
2031
2048
# Faster and more reliable to pass ()-shaped ndarrays as scalars.
2032
- if isinstance (args [ i ] , np .ndarray ):
2033
- args [i ] = args [ i ] [()]
2049
+ if isinstance (arg , np .ndarray ):
2050
+ args [i ] = arg [()]
2034
2051
2035
- # The -scalars factor is there because we need to account for already
2036
- # added scalars in the function.
2037
2052
pos .append (i )
2038
2053
posargs .append (args [i ])
2039
- elif isinstance (args [ i ] , SparseArray ) and not isinstance (args [ i ] , COO ):
2040
- args [i ] = COO (args [ i ] )
2041
- elif not isinstance (args [ i ] , COO ):
2054
+ elif isinstance (arg , SparseArray ) and not isinstance (arg , COO ):
2055
+ args [i ] = COO (arg )
2056
+ elif not isinstance (arg , COO ):
2042
2057
raise ValueError ("Performing this operation would produce "
2043
2058
"a dense result: %s" % str (func ))
2044
2059
2045
2060
# Filter out scalars as they are 'baked' into the function.
2046
- func = _posarg_partial (func , pos , posargs )
2061
+ func = PositinalArgumentPartial (func , pos , posargs )
2047
2062
args = list (filter (lambda arg : not isscalar (arg ), args ))
2048
2063
2049
2064
if len (args ) == 0 :
@@ -2124,14 +2139,14 @@ def _match_coo(*args):
2124
2139
2125
2140
Parameters
2126
2141
----------
2127
- args : tuple [COO]
2142
+ args : Tuple [COO]
2128
2143
The input :obj:`COO` arrays.
2129
2144
2130
2145
Returns
2131
2146
-------
2132
- matched_idx : list [ndarray]
2147
+ matched_idx : List [ndarray]
2133
2148
The indices of matched elements in the original arrays.
2134
- matched_arrays : list [COO]
2149
+ matched_arrays : List [COO]
2135
2150
The expanded, matched :obj:`COO` objects.
2136
2151
"""
2137
2152
# If there's only a single input, return as-is.
@@ -2166,7 +2181,7 @@ def _match_coo(*args):
2166
2181
matched_coords = [c [:, idx ] for c , idx in zip (coords , matched_indices )]
2167
2182
2168
2183
# Add the matched part.
2169
- matched_coords = _get_nary_matching_coords (matched_coords , broadcast_params , matched_shape )
2184
+ matched_coords = _get_matching_coords (matched_coords , broadcast_params , matched_shape )
2170
2185
matched_data = [d [idx ] for d , idx in zip (data , matched_indices )]
2171
2186
2172
2187
matched_indices = [sidx [midx ] for sidx , midx in zip (sorted_indices , matched_indices )]
@@ -2206,7 +2221,7 @@ def _unmatch_coo(func, args, mask, **kwargs):
2206
2221
posargs = [_zero_of_dtype (arg .dtype )[()] for arg , m in zip (args , mask ) if not m ]
2207
2222
result_shape = _get_nary_broadcast_shape (* [arg .shape for arg in args ])
2208
2223
2209
- partial = _posarg_partial (func , pos , posargs )
2224
+ partial = PositinalArgumentPartial (func , pos , posargs )
2210
2225
matched_func = partial (* [a .data for a in matched_arrays ], ** kwargs )
2211
2226
2212
2227
unmatched_mask = matched_func != _zero_of_dtype (matched_func .dtype )
@@ -2468,7 +2483,7 @@ def _elemwise_unary(func, self, *args, **kwargs):
2468
2483
sorted = self .sorted )
2469
2484
2470
2485
2471
- def _get_nary_matching_coords (coords , params , shape ):
2486
+ def _get_matching_coords (coords , params , shape ):
2472
2487
"""
2473
2488
Get the matching coords across a number of broadcast operands.
2474
2489
@@ -2481,7 +2496,7 @@ def _get_nary_matching_coords(coords, params, shape):
2481
2496
Returns
2482
2497
-------
2483
2498
numpy.ndarray
2484
- The broacasted coordinates.
2499
+ The broacasted coordinates
2485
2500
"""
2486
2501
matching_coords = []
2487
2502
dims = np .zeros (len (params ), dtype = np .uint8 )
@@ -2516,35 +2531,3 @@ def _linear_loc(coords, shape, signed=False):
2516
2531
np .add (tmp , out , out = out )
2517
2532
strides *= d
2518
2533
return out
2519
-
2520
-
2521
- def _posarg_partial (func , pos , posargs ):
2522
- if not isinstance (pos , Iterable ):
2523
- pos = (pos ,)
2524
- posargs = (posargs ,)
2525
-
2526
- n_partial_args = len (pos )
2527
-
2528
- class Partial (object ):
2529
- def __call__ (self , * args , ** kwargs ):
2530
- j = 0
2531
- totargs = []
2532
-
2533
- for i in range (len (args ) + n_partial_args ):
2534
- if j >= n_partial_args or i != pos [j ]:
2535
- totargs .append (args [i - j ])
2536
- else :
2537
- totargs .append (posargs [j ])
2538
- j += 1
2539
-
2540
- return func (* totargs , ** kwargs )
2541
-
2542
- def __str__ (self ):
2543
- return str (func )
2544
-
2545
- def __repr__ (self ):
2546
- return repr (func )
2547
-
2548
- __doc__ = func .__doc__
2549
-
2550
- return Partial ()
0 commit comments