|
1 |
| -import numbers |
2 | 1 | from collections import Iterable, defaultdict, deque
|
3 | 2 |
|
4 | 3 | import numpy as np
|
5 | 4 | import scipy.sparse
|
6 | 5 | from numpy.lib.mixins import NDArrayOperatorsMixin
|
7 | 6 |
|
8 | 7 | from .common import dot
|
| 8 | +from .indexing import getitem |
9 | 9 | from .umath import elemwise, broadcast_to
|
10 | 10 | from ..compatibility import int, range
|
11 |
| -from ..slicing import normalize_index |
12 | 11 | from ..sparse_array import SparseArray
|
13 | 12 | from ..utils import _zero_of_dtype
|
14 | 13 |
|
@@ -487,93 +486,7 @@ def __len__(self):
|
487 | 486 | def __sizeof__(self):
|
488 | 487 | return self.nbytes
|
489 | 488 |
|
490 |
| - def __getitem__(self, index): |
491 |
| - if not isinstance(index, tuple): |
492 |
| - if isinstance(index, str): |
493 |
| - data = self.data[index] |
494 |
| - idx = np.where(data) |
495 |
| - coords = list(self.coords[:, idx[0]]) |
496 |
| - coords.extend(idx[1:]) |
497 |
| - |
498 |
| - return COO(coords, data[idx].flatten(), |
499 |
| - shape=self.shape + self.data.dtype[index].shape, |
500 |
| - has_duplicates=self.has_duplicates, |
501 |
| - sorted=self.sorted) |
502 |
| - else: |
503 |
| - index = (index,) |
504 |
| - |
505 |
| - last_ellipsis = len(index) > 0 and index[-1] is Ellipsis |
506 |
| - index = normalize_index(index, self.shape) |
507 |
| - if len(index) != 0 and all(not isinstance(ind, Iterable) and ind == slice(None) for ind in index): |
508 |
| - return self |
509 |
| - mask = np.ones(self.nnz, dtype=np.bool) |
510 |
| - for i, ind in enumerate([i for i in index if i is not None]): |
511 |
| - if not isinstance(ind, Iterable) and ind == slice(None): |
512 |
| - continue |
513 |
| - mask &= _mask(self.coords[i], ind, self.shape[i]) |
514 |
| - |
515 |
| - n = mask.sum() |
516 |
| - coords = [] |
517 |
| - shape = [] |
518 |
| - i = 0 |
519 |
| - for ind in index: |
520 |
| - if isinstance(ind, numbers.Integral): |
521 |
| - i += 1 |
522 |
| - continue |
523 |
| - elif isinstance(ind, slice): |
524 |
| - step = ind.step if ind.step is not None else 1 |
525 |
| - if step > 0: |
526 |
| - start = ind.start if ind.start is not None else 0 |
527 |
| - start = max(start, 0) |
528 |
| - stop = ind.stop if ind.stop is not None else self.shape[i] |
529 |
| - stop = min(stop, self.shape[i]) |
530 |
| - if start > stop: |
531 |
| - start = stop |
532 |
| - shape.append((stop - start + step - 1) // step) |
533 |
| - else: |
534 |
| - start = ind.start or self.shape[i] - 1 |
535 |
| - stop = ind.stop if ind.stop is not None else -1 |
536 |
| - start = min(start, self.shape[i] - 1) |
537 |
| - stop = max(stop, -1) |
538 |
| - if start < stop: |
539 |
| - start = stop |
540 |
| - shape.append((start - stop - step - 1) // (-step)) |
541 |
| - |
542 |
| - dt = np.min_scalar_type(min(-(dim - 1) if dim != 0 else -1 for dim in shape)) |
543 |
| - coords.append((self.coords[i, mask].astype(dt) - start) // step) |
544 |
| - i += 1 |
545 |
| - elif isinstance(ind, Iterable): |
546 |
| - old = self.coords[i][mask] |
547 |
| - new = np.empty(shape=old.shape, dtype=old.dtype) |
548 |
| - for j, item in enumerate(ind): |
549 |
| - new[old == item] = j |
550 |
| - coords.append(new) |
551 |
| - shape.append(len(ind)) |
552 |
| - i += 1 |
553 |
| - elif ind is None: |
554 |
| - coords.append(np.zeros(n)) |
555 |
| - shape.append(1) |
556 |
| - |
557 |
| - for j in range(i, self.ndim): |
558 |
| - coords.append(self.coords[j][mask]) |
559 |
| - shape.append(self.shape[j]) |
560 |
| - |
561 |
| - if coords: |
562 |
| - coords = np.stack(coords, axis=0) |
563 |
| - else: |
564 |
| - if last_ellipsis: |
565 |
| - coords = np.empty((0, np.sum(mask)), dtype=np.uint8) |
566 |
| - else: |
567 |
| - if np.sum(mask) != 0: |
568 |
| - return self.data[mask][0] |
569 |
| - else: |
570 |
| - return _zero_of_dtype(self.dtype)[()] |
571 |
| - shape = tuple(shape) |
572 |
| - data = self.data[mask] |
573 |
| - |
574 |
| - return COO(coords, data, shape=shape, |
575 |
| - has_duplicates=self.has_duplicates, |
576 |
| - sorted=self.sorted) |
| 489 | + __getitem__ = getitem |
577 | 490 |
|
578 | 491 | def __str__(self):
|
579 | 492 | return "<COO: shape=%s, dtype=%s, nnz=%d, sorted=%s, duplicates=%s>" % (
|
@@ -1572,28 +1485,6 @@ def _keepdims(original, new, axis):
|
1572 | 1485 | return new.reshape(shape)
|
1573 | 1486 |
|
1574 | 1487 |
|
1575 |
| -def _mask(coords, idx, shape): |
1576 |
| - if isinstance(idx, numbers.Integral): |
1577 |
| - return coords == idx |
1578 |
| - elif isinstance(idx, slice): |
1579 |
| - step = idx.step if idx.step is not None else 1 |
1580 |
| - if step > 0: |
1581 |
| - start = idx.start if idx.start is not None else 0 |
1582 |
| - stop = idx.stop if idx.stop is not None else shape |
1583 |
| - return (coords >= start) & (coords < stop) & \ |
1584 |
| - (coords % step == start % step) |
1585 |
| - else: |
1586 |
| - start = idx.start if idx.start is not None else (shape - 1) |
1587 |
| - stop = idx.stop if idx.stop is not None else -1 |
1588 |
| - return (coords <= start) & (coords > stop) & \ |
1589 |
| - (coords % step == start % step) |
1590 |
| - elif isinstance(idx, Iterable): |
1591 |
| - mask = np.zeros(len(coords), dtype=np.bool) |
1592 |
| - for item in idx: |
1593 |
| - mask |= _mask(coords, item, shape) |
1594 |
| - return mask |
1595 |
| - |
1596 |
| - |
1597 | 1488 | def _grouped_reduce(x, groups, method, **kwargs):
|
1598 | 1489 | """
|
1599 | 1490 | Performs a :code:`ufunc` grouped reduce.
|
|
0 commit comments