Skip to content

Commit c814c97

Browse files
committed
add support for fields with selection, resolves #112
1 parent 983355d commit c814c97

File tree

3 files changed

+337
-200
lines changed

3 files changed

+337
-200
lines changed

zarr/core.py

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,15 @@
1717
from zarr.compat import reduce
1818
from zarr.codecs import AsType, get_codec
1919
from zarr.indexing import OIndex, OrthogonalIndexer, BasicIndexer, VIndex, CoordinateIndexer, \
20-
MaskIndexer
20+
MaskIndexer, check_fields, pop_fields, ensure_tuple
21+
22+
23+
def is_scalar(value, dtype):
24+
if np.isscalar(value):
25+
return True
26+
if isinstance(value, tuple) and dtype.names and len(value) == len(dtype.names):
27+
return True
28+
return False
2129

2230

2331
class Array(object):
@@ -465,19 +473,10 @@ def __getitem__(self, selection):
465473
466474
"""
467475

468-
if len(self._shape) == 0:
469-
return self._get_basic_selection_zd(selection)
470-
471-
elif len(self._shape) == 1:
472-
# safe to do "fancy" indexing, no ambiguity
473-
return self.get_orthogonal_selection(selection)
474-
475-
else:
476-
# "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
477-
# force people to go through explicit methods
478-
return self.get_basic_selection(selection)
476+
fields, selection = pop_fields(selection)
477+
return self.get_basic_selection(selection, fields=fields)
479478

480-
def get_basic_selection(self, selection, out=None):
479+
def get_basic_selection(self, selection, out=None, fields=None):
481480
"""TODO"""
482481

483482
# refresh metadata
@@ -486,15 +485,16 @@ def get_basic_selection(self, selection, out=None):
486485

487486
# handle zero-dimensional arrays
488487
if self._shape == ():
489-
return self._get_basic_selection_zd(selection, out=out)
488+
return self._get_basic_selection_zd(selection=selection, out=out, fields=fields)
490489
else:
491-
return self._get_basic_selection_nd(selection, out=out)
490+
return self._get_basic_selection_nd(selection=selection, out=out, fields=fields)
492491

493-
def _get_basic_selection_zd(self, selection, out=None):
492+
def _get_basic_selection_zd(self, selection, out=None, fields=None):
494493
# special case basic selection for zero-dimensional array
495494

496495
# check selection is valid
497-
if selection not in ((), Ellipsis):
496+
selection = ensure_tuple(selection)
497+
if selection not in ((), (Ellipsis,)):
498498
raise IndexError('too many indices for array')
499499

500500
try:
@@ -519,17 +519,21 @@ def _get_basic_selection_zd(self, selection, out=None):
519519
else:
520520
out[selection] = chunk[selection]
521521

522+
# handle fields
523+
if fields:
524+
out = out[fields]
525+
522526
return out
523527

524-
def _get_basic_selection_nd(self, selection, out=None):
528+
def _get_basic_selection_nd(self, selection, out=None, fields=None):
525529
# implementation of basic selection for array with at least one dimension
526530

527531
# setup indexer
528532
indexer = BasicIndexer(selection, self)
529533

530-
return self._get_selection(indexer, out=out)
534+
return self._get_selection(indexer=indexer, out=out, fields=fields)
531535

532-
def get_orthogonal_selection(self, selection, out=None):
536+
def get_orthogonal_selection(self, selection, out=None, fields=None):
533537
"""TODO"""
534538

535539
# refresh metadata
@@ -539,9 +543,9 @@ def get_orthogonal_selection(self, selection, out=None):
539543
# setup indexer
540544
indexer = OrthogonalIndexer(selection, self)
541545

542-
return self._get_selection(indexer, out=out)
546+
return self._get_selection(indexer=indexer, out=out, fields=fields)
543547

544-
def get_coordinate_selection(self, selection, out=None):
548+
def get_coordinate_selection(self, selection, out=None, fields=None):
545549
"""TODO"""
546550

547551
# refresh metadata
@@ -551,9 +555,9 @@ def get_coordinate_selection(self, selection, out=None):
551555
# setup indexer
552556
indexer = CoordinateIndexer(selection, self)
553557

554-
return self._get_selection(indexer, out=out)
558+
return self._get_selection(indexer=indexer, out=out, fields=fields)
555559

556-
def get_mask_selection(self, selection, out=None):
560+
def get_mask_selection(self, selection, out=None, fields=None):
557561
"""TODO"""
558562

559563
# refresh metadata
@@ -563,9 +567,9 @@ def get_mask_selection(self, selection, out=None):
563567
# setup indexer
564568
indexer = MaskIndexer(selection, self)
565569

566-
return self._get_selection(indexer, out=out)
570+
return self._get_selection(indexer=indexer, out=out, fields=fields)
567571

568-
def _get_selection(self, indexer, out=None):
572+
def _get_selection(self, indexer, out=None, fields=None):
569573

570574
# We iterate over all chunks which overlap the selection and thus contain data that needs
571575
# to be extracted. Each chunk is processed in turn, extracting the necessary data and
@@ -574,25 +578,28 @@ def _get_selection(self, indexer, out=None):
574578
# N.B., it is an important optimisation that we only visit chunks which overlap the
575579
# selection. This minimises the nuimber of iterations in the main for loop.
576580

581+
# check fields are sensible
582+
out_dtype = check_fields(fields, self._dtype)
583+
577584
# determine output shape
578-
sel_shape = indexer.shape
585+
out_shape = indexer.shape
579586

580587
# setup output array
581588
if out is None:
582-
out = np.empty(sel_shape, dtype=self._dtype, order=self._order)
589+
out = np.empty(out_shape, dtype=out_dtype, order=self._order)
583590
else:
584591
# validate 'out' parameter
585592
if not hasattr(out, 'shape'):
586593
raise TypeError('out must be an array-like object')
587-
if out.shape != sel_shape:
594+
if out.shape != out_shape:
588595
raise ValueError('out has wrong shape for selection')
589596

590597
# iterate over chunks
591598
for chunk_coords, chunk_selection, out_selection in indexer:
592599

593600
# load chunk selection into output array
594601
self._chunk_getitem(chunk_coords, chunk_selection, out, out_selection,
595-
drop_axes=indexer.drop_axes)
602+
drop_axes=indexer.drop_axes, fields=fields)
596603

597604
if out.shape:
598605
return out
@@ -658,19 +665,10 @@ def __setitem__(self, selection, value):
658665
659666
"""
660667

661-
if len(self._shape) == 0:
662-
self._set_basic_selection_zd(selection, value)
668+
fields, selection = pop_fields(selection)
669+
self.set_basic_selection(selection, value, fields=fields)
663670

664-
elif len(self._shape) == 1:
665-
# safe to do "fancy" indexing, no ambiguity
666-
self.set_orthogonal_selection(selection, value)
667-
668-
else:
669-
# "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
670-
# force people to go through explicit methods
671-
self.set_basic_selection(selection, value)
672-
673-
def set_basic_selection(self, selection, value):
671+
def set_basic_selection(self, selection, value, fields=None):
674672
"""TODO"""
675673

676674
# guard conditions
@@ -683,11 +681,11 @@ def set_basic_selection(self, selection, value):
683681

684682
# handle zero-dimensional arrays
685683
if self._shape == ():
686-
return self._set_basic_selection_zd(selection, value)
684+
return self._set_basic_selection_zd(selection, value, fields=fields)
687685
else:
688-
return self._set_basic_selection_nd(selection, value)
686+
return self._set_basic_selection_nd(selection, value, fields=fields)
689687

690-
def set_orthogonal_selection(self, selection, value):
688+
def set_orthogonal_selection(self, selection, value, fields=None):
691689
"""TODO"""
692690

693691
# guard conditions
@@ -701,9 +699,9 @@ def set_orthogonal_selection(self, selection, value):
701699
# setup indexer
702700
indexer = OrthogonalIndexer(selection, self)
703701

704-
self._set_selection(indexer, value)
702+
self._set_selection(indexer, value, fields=fields)
705703

706-
def set_coordinate_selection(self, selection, value):
704+
def set_coordinate_selection(self, selection, value, fields=None):
707705
"""TODO"""
708706

709707
# guard conditions
@@ -717,9 +715,9 @@ def set_coordinate_selection(self, selection, value):
717715
# setup indexer
718716
indexer = CoordinateIndexer(selection, self)
719717

720-
self._set_selection(indexer, value)
718+
self._set_selection(indexer, value, fields=fields)
721719

722-
def set_mask_selection(self, selection, value):
720+
def set_mask_selection(self, selection, value, fields=None):
723721
"""TODO"""
724722

725723
# guard conditions
@@ -733,13 +731,17 @@ def set_mask_selection(self, selection, value):
733731
# setup indexer
734732
indexer = MaskIndexer(selection, self)
735733

736-
self._set_selection(indexer, value)
734+
self._set_selection(indexer, value, fields=fields)
737735

738-
def _set_basic_selection_zd(self, selection, value):
736+
def _set_basic_selection_zd(self, selection, value, fields=None):
739737
# special case __setitem__ for zero-dimensional array
740738

739+
if fields:
740+
raise IndexError('fields not supported for 0d array')
741+
741742
# check item is valid
742-
if selection not in ((), Ellipsis):
743+
selection = ensure_tuple(selection)
744+
if selection not in ((), (Ellipsis,)):
743745
raise IndexError('too many indices for array')
744746

745747
# setup data to store
@@ -756,15 +758,15 @@ def _set_basic_selection_zd(self, selection, value):
756758
cdata = self._encode_chunk(arr)
757759
self.chunk_store[ckey] = cdata
758760

759-
def _set_basic_selection_nd(self, selection, value):
761+
def _set_basic_selection_nd(self, selection, value, fields=None):
760762
# implementation of __setitem__ for array with at least one dimension
761763

762764
# setup indexer
763765
indexer = BasicIndexer(selection, self)
764766

765-
self._set_selection(indexer, value)
767+
self._set_selection(indexer, value, fields=fields)
766768

767-
def _set_selection(self, indexer, value):
769+
def _set_selection(self, indexer, value, fields=None):
768770

769771
# We iterate over all chunks which overlap the selection and thus contain data that needs
770772
# to be replaced. Each chunk is processed in turn, extracting the necessary data from the
@@ -773,15 +775,20 @@ def _set_selection(self, indexer, value):
773775
# N.B., it is an important optimisation that we only visit chunks which overlap the
774776
# selection. This minimises the nuimber of iterations in the main for loop.
775777

778+
# check fields are sensible
779+
check_fields(fields, self._dtype)
780+
if fields and isinstance(fields, list):
781+
raise ValueError('multi-field assignment is not supported')
782+
776783
# determine indices of chunks overlapping the selection
777784
sel_shape = indexer.shape
778785

779786
# check value shape
780-
if np.isscalar(value):
787+
if is_scalar(value, self._dtype):
781788
pass
782789
else:
783790
if not hasattr(value, 'shape'):
784-
raise TypeError('value must be an array-like object')
791+
value = np.asarray(value)
785792
if value.shape != sel_shape:
786793
raise ValueError('value has wrong shape for selection; expected {}, got {}'
787794
.format(sel_shape, value.shape))
@@ -790,7 +797,7 @@ def _set_selection(self, indexer, value):
790797
for chunk_coords, chunk_selection, out_selection in indexer:
791798

792799
# extract data to store
793-
if np.isscalar(value):
800+
if is_scalar(value, self._dtype):
794801
chunk_value = value
795802
else:
796803
chunk_value = value[out_selection]
@@ -802,9 +809,10 @@ def _set_selection(self, indexer, value):
802809
chunk_value = chunk_value[item]
803810

804811
# put data
805-
self._chunk_setitem(chunk_coords, chunk_selection, chunk_value)
812+
self._chunk_setitem(chunk_coords, chunk_selection, chunk_value, fields=fields)
806813

807-
def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop_axes=None):
814+
def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop_axes=None,
815+
fields=None):
808816
"""Obtain part or whole of a chunk.
809817
810818
Parameters
@@ -819,6 +827,8 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
819827
Location of region within output array to store results in.
820828
drop_axes : tuple of ints
821829
Axes to squeeze out of the chunk.
830+
fields
831+
TODO
822832
823833
"""
824834

@@ -838,10 +848,11 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
838848

839849
else:
840850

841-
if isinstance(out, np.ndarray) and \
842-
isinstance(out_selection, slice) and \
843-
is_total_slice(chunk_selection, self._chunks) and \
844-
not self._filters:
851+
if (isinstance(out, np.ndarray) and
852+
not fields and
853+
isinstance(out_selection, slice) and
854+
is_total_slice(chunk_selection, self._chunks) and
855+
not self._filters):
845856

846857
dest = out[out_selection]
847858
contiguous = ((self._order == 'C' and dest.flags.c_contiguous) or
@@ -864,13 +875,17 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
864875
# decode chunk
865876
chunk = self._decode_chunk(cdata)
866877

867-
# set data in output array
878+
# select data from chunk
879+
if fields:
880+
chunk = chunk[fields]
868881
tmp = chunk[chunk_selection]
869882
if drop_axes:
870883
tmp = np.squeeze(tmp, axis=drop_axes)
884+
885+
# store selected data in output
871886
out[out_selection] = tmp
872887

873-
def _chunk_setitem(self, chunk_coords, chunk_selection, value):
888+
def _chunk_setitem(self, chunk_coords, chunk_selection, value, fields=None):
874889
"""Replace part or whole of a chunk.
875890
876891
Parameters
@@ -886,25 +901,25 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
886901

887902
# synchronization
888903
if self._synchronizer is None:
889-
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value)
904+
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value, fields=fields)
890905
else:
891906
# synchronize on the chunk
892907
ckey = self._chunk_key(chunk_coords)
893908
with self._synchronizer[ckey]:
894-
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value)
909+
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value, fields=fields)
895910

896-
def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value):
911+
def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value, fields=None):
897912

898913
# obtain key for chunk storage
899914
ckey = self._chunk_key(chunk_coords)
900915

901-
if is_total_slice(chunk_selection, self._chunks):
916+
if is_total_slice(chunk_selection, self._chunks) and not fields:
902917
# totally replace chunk
903918

904919
# optimization: we are completely replacing the chunk, so no need
905920
# to access the existing chunk data
906921

907-
if np.isscalar(value):
922+
if is_scalar(value, self._dtype):
908923

909924
# setup array filled with value
910925
chunk = np.empty(self._chunks, dtype=self._dtype, order=self._order)
@@ -955,7 +970,12 @@ def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value):
955970
chunk = chunk.copy(order='K')
956971

957972
# modify
958-
chunk[chunk_selection] = value
973+
if fields:
974+
# N.B., currently multi-field assignment is not supported in numpy, so this only
975+
# works for a single field
976+
chunk[fields][chunk_selection] = value
977+
else:
978+
chunk[chunk_selection] = value
959979

960980
# encode chunk
961981
cdata = self._encode_chunk(chunk)

0 commit comments

Comments
 (0)