diff --git a/sparse/_coo/indexing.py b/sparse/_coo/indexing.py index 0ceaf7e5..98b0ca5a 100644 --- a/sparse/_coo/indexing.py +++ b/sparse/_coo/indexing.py @@ -30,6 +30,7 @@ def getitem(x, index): from .core import COO # If string, this is an index into an np.void + # Custom dtype. if isinstance(index, str): data = x.data[index] @@ -86,6 +87,7 @@ def getitem(x, index): i = 0 sorted = adv_idx is None or adv_idx.pos == 0 + adv_idx_added = False for ind in index: # Nothing is added to shape or coords if the index is an integer. if isinstance(ind, Integral): @@ -100,8 +102,10 @@ def getitem(x, index): sorted = False # Add the index and shape for the advanced index. elif isinstance(ind, np.ndarray): - shape.append(adv_idx.length) - coords.append(adv_idx.idx) + if not adv_idx_added: + shape.append(adv_idx.length) + coords.append(adv_idx.idx) + adv_idx_added = True i += 1 # Add a dimension for None. elif ind is None: @@ -141,20 +145,37 @@ def _mask(coords, indices, shape): if len(adv_idx) != 0: if len(adv_idx) != 1: - raise IndexError( - "Only indices with at most one iterable index are supported." + + # Ensure if multiple advanced indices are passed, all are of the same length + # Also check each advanced index to ensure each is only a one-dimensional iterable + adv_ix_len = len(adv_idx[0]) + for ai in adv_idx: + if len(ai) != adv_ix_len: + raise IndexError( + "shape mismatch: indexing arrays could not be broadcast together. Ensure all indexing arrays are of the same length." + ) + if ai.ndim != 1: + raise IndexError("Only one-dimensional iterable indices supported.") + + mask, aidxs = _compute_multi_axis_multi_mask( + coords, + _ind_ar_from_indices(indices), + np.array(adv_idx, dtype=np.intp), + np.array(adv_idx_pos, dtype=np.intp), ) + return mask, _AdvIdxInfo(aidxs, adv_idx_pos, adv_ix_len) - adv_idx = adv_idx[0] - adv_idx_pos = adv_idx_pos[0] + else: + adv_idx = adv_idx[0] + adv_idx_pos = adv_idx_pos[0] - if adv_idx.ndim != 1: - raise IndexError("Only one-dimensional iterable indices supported.") + if adv_idx.ndim != 1: + raise IndexError("Only one-dimensional iterable indices supported.") - mask, aidxs = _compute_multi_mask( - coords, _ind_ar_from_indices(indices), adv_idx, adv_idx_pos - ) - return mask, _AdvIdxInfo(aidxs, adv_idx_pos, len(adv_idx)) + mask, aidxs = _compute_multi_mask( + coords, _ind_ar_from_indices(indices), adv_idx, adv_idx_pos + ) + return mask, _AdvIdxInfo(aidxs, adv_idx_pos, len(adv_idx)) mask, is_slice = _compute_mask(coords, _ind_ar_from_indices(indices)) @@ -276,6 +297,68 @@ def _separate_adv_indices(indices): return new_idx, adv_idx, adv_idx_pos +@numba.jit(nopython=True, nogil=True) +def _compute_multi_axis_multi_mask( + coords, indices, adv_idx, adv_idx_pos +): # pragma: no cover + """ + Computes a mask with the advanced index, and also returns the advanced index + dimension. + + Parameters + ---------- + coords : np.ndarray + Coordinates of the input array. + indices : np.ndarray + The indices in slice format. + adv_idx : np.ndarray + List of advanced indices. + adv_idx_pos : np.ndarray + The position of the advanced indices. + + Returns + ------- + mask : np.ndarray + The mask. + aidxs : np.ndarray + The advanced array index. + """ + n_adv_idx = len(adv_idx_pos) + mask = numba.typed.List.empty_list(numba.types.intp) + a_indices = numba.typed.List.empty_list(numba.types.intp) + full_idx = np.empty((len(indices) + len(adv_idx_pos), 3), dtype=np.intp) + + # Get location of non-advanced indices + if len(indices) != 0: + ixx = 0 + for ix in range(coords.shape[0]): + isin = False + for ax in adv_idx_pos: + if ix == ax: + isin = True + break + if not isin: + full_idx[ix] = indices[ixx] + ixx += 1 + + for i in range(len(adv_idx[0])): + for ii in range(n_adv_idx): + full_idx[adv_idx_pos[ii]] = [adv_idx[ii][i], adv_idx[ii][i] + 1, 1] + + partial_mask, is_slice = _compute_mask(coords, full_idx) + if is_slice: + slice_mask = numba.typed.List.empty_list(numba.types.intp) + for j in range(partial_mask[0], partial_mask[1]): + slice_mask.append(j) + partial_mask = array_from_list_intp(slice_mask) + + for j in range(len(partial_mask)): + mask.append(partial_mask[j]) + a_indices.append(i) + + return array_from_list_intp(mask), array_from_list_intp(a_indices) + + @numba.jit(nopython=True, nogil=True) def _compute_multi_mask(coords, indices, adv_idx, adv_idx_pos): # pragma: no cover """ @@ -288,9 +371,9 @@ def _compute_multi_mask(coords, indices, adv_idx, adv_idx_pos): # pragma: no co Coordinates of the input array. indices : np.ndarray The indices in slice format. - adv_idx : int + adv_idx : list(int) The advanced index. - adv_idx_pos : int + adv_idx_pos : list(int) The position of the advanced index. Returns diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index 48f8f52a..42042ea0 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -1264,6 +1264,10 @@ def test_gt(): (1, Ellipsis, None), (1, 1, 1, Ellipsis), (Ellipsis, 1, None), + # With multi-axis advanced indexing + ([0, 1],) * 2, + ([0, 1], [0, 2]), + ([0, 0, 0], [0, 1, 2], [1, 2, 1]), # Pathological - Slices larger than array (slice(None, 1000)), (slice(None), slice(None, 1000)), @@ -1336,7 +1340,6 @@ def test_custom_dtype_slicing(): 0.5, [0.5], {"potato": "kartoffel"}, - ([0, 1],) * 2, ([[0, 1]],), ], )