17
17
from zarr .compat import reduce
18
18
from zarr .codecs import AsType , get_codec
19
19
from zarr .indexing import OIndex , OrthogonalIndexer , BasicIndexer , VIndex , CoordinateIndexer , \
20
- MaskIndexer
20
+ MaskIndexer , check_fields , pop_fields , ensure_tuple
21
+
22
+
23
+ def is_scalar (value , dtype ):
24
+ if np .isscalar (value ):
25
+ return True
26
+ if isinstance (value , tuple ) and dtype .names and len (value ) == len (dtype .names ):
27
+ return True
28
+ return False
21
29
22
30
23
31
class Array (object ):
@@ -465,19 +473,10 @@ def __getitem__(self, selection):
465
473
466
474
"""
467
475
468
- if len (self ._shape ) == 0 :
469
- return self ._get_basic_selection_zd (selection )
470
-
471
- elif len (self ._shape ) == 1 :
472
- # safe to do "fancy" indexing, no ambiguity
473
- return self .get_orthogonal_selection (selection )
474
-
475
- else :
476
- # "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
477
- # force people to go through explicit methods
478
- return self .get_basic_selection (selection )
476
+ fields , selection = pop_fields (selection )
477
+ return self .get_basic_selection (selection , fields = fields )
479
478
480
- def get_basic_selection (self , selection , out = None ):
479
+ def get_basic_selection (self , selection , out = None , fields = None ):
481
480
"""TODO"""
482
481
483
482
# refresh metadata
@@ -486,15 +485,16 @@ def get_basic_selection(self, selection, out=None):
486
485
487
486
# handle zero-dimensional arrays
488
487
if self ._shape == ():
489
- return self ._get_basic_selection_zd (selection , out = out )
488
+ return self ._get_basic_selection_zd (selection = selection , out = out , fields = fields )
490
489
else :
491
- return self ._get_basic_selection_nd (selection , out = out )
490
+ return self ._get_basic_selection_nd (selection = selection , out = out , fields = fields )
492
491
493
- def _get_basic_selection_zd (self , selection , out = None ):
492
+ def _get_basic_selection_zd (self , selection , out = None , fields = None ):
494
493
# special case basic selection for zero-dimensional array
495
494
496
495
# check selection is valid
497
- if selection not in ((), Ellipsis ):
496
+ selection = ensure_tuple (selection )
497
+ if selection not in ((), (Ellipsis ,)):
498
498
raise IndexError ('too many indices for array' )
499
499
500
500
try :
@@ -519,17 +519,21 @@ def _get_basic_selection_zd(self, selection, out=None):
519
519
else :
520
520
out [selection ] = chunk [selection ]
521
521
522
+ # handle fields
523
+ if fields :
524
+ out = out [fields ]
525
+
522
526
return out
523
527
524
- def _get_basic_selection_nd (self , selection , out = None ):
528
+ def _get_basic_selection_nd (self , selection , out = None , fields = None ):
525
529
# implementation of basic selection for array with at least one dimension
526
530
527
531
# setup indexer
528
532
indexer = BasicIndexer (selection , self )
529
533
530
- return self ._get_selection (indexer , out = out )
534
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
531
535
532
- def get_orthogonal_selection (self , selection , out = None ):
536
+ def get_orthogonal_selection (self , selection , out = None , fields = None ):
533
537
"""TODO"""
534
538
535
539
# refresh metadata
@@ -539,9 +543,9 @@ def get_orthogonal_selection(self, selection, out=None):
539
543
# setup indexer
540
544
indexer = OrthogonalIndexer (selection , self )
541
545
542
- return self ._get_selection (indexer , out = out )
546
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
543
547
544
- def get_coordinate_selection (self , selection , out = None ):
548
+ def get_coordinate_selection (self , selection , out = None , fields = None ):
545
549
"""TODO"""
546
550
547
551
# refresh metadata
@@ -551,9 +555,9 @@ def get_coordinate_selection(self, selection, out=None):
551
555
# setup indexer
552
556
indexer = CoordinateIndexer (selection , self )
553
557
554
- return self ._get_selection (indexer , out = out )
558
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
555
559
556
- def get_mask_selection (self , selection , out = None ):
560
+ def get_mask_selection (self , selection , out = None , fields = None ):
557
561
"""TODO"""
558
562
559
563
# refresh metadata
@@ -563,9 +567,9 @@ def get_mask_selection(self, selection, out=None):
563
567
# setup indexer
564
568
indexer = MaskIndexer (selection , self )
565
569
566
- return self ._get_selection (indexer , out = out )
570
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
567
571
568
- def _get_selection (self , indexer , out = None ):
572
+ def _get_selection (self , indexer , out = None , fields = None ):
569
573
570
574
# We iterate over all chunks which overlap the selection and thus contain data that needs
571
575
# to be extracted. Each chunk is processed in turn, extracting the necessary data and
@@ -574,25 +578,28 @@ def _get_selection(self, indexer, out=None):
574
578
# N.B., it is an important optimisation that we only visit chunks which overlap the
575
579
# selection. This minimises the number of iterations in the main for loop.
576
580
581
+ # check fields are sensible
582
+ out_dtype = check_fields (fields , self ._dtype )
583
+
577
584
# determine output shape
578
- sel_shape = indexer .shape
585
+ out_shape = indexer .shape
579
586
580
587
# setup output array
581
588
if out is None :
582
- out = np .empty (sel_shape , dtype = self . _dtype , order = self ._order )
589
+ out = np .empty (out_shape , dtype = out_dtype , order = self ._order )
583
590
else :
584
591
# validate 'out' parameter
585
592
if not hasattr (out , 'shape' ):
586
593
raise TypeError ('out must be an array-like object' )
587
- if out .shape != sel_shape :
594
+ if out .shape != out_shape :
588
595
raise ValueError ('out has wrong shape for selection' )
589
596
590
597
# iterate over chunks
591
598
for chunk_coords , chunk_selection , out_selection in indexer :
592
599
593
600
# load chunk selection into output array
594
601
self ._chunk_getitem (chunk_coords , chunk_selection , out , out_selection ,
595
- drop_axes = indexer .drop_axes )
602
+ drop_axes = indexer .drop_axes , fields = fields )
596
603
597
604
if out .shape :
598
605
return out
@@ -658,19 +665,10 @@ def __setitem__(self, selection, value):
658
665
659
666
"""
660
667
661
- if len ( self . _shape ) == 0 :
662
- self ._set_basic_selection_zd (selection , value )
668
+ fields , selection = pop_fields ( selection )
669
+ self .set_basic_selection (selection , value , fields = fields )
663
670
664
- elif len (self ._shape ) == 1 :
665
- # safe to do "fancy" indexing, no ambiguity
666
- self .set_orthogonal_selection (selection , value )
667
-
668
- else :
669
- # "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
670
- # force people to go through explicit methods
671
- self .set_basic_selection (selection , value )
672
-
673
- def set_basic_selection (self , selection , value ):
671
+ def set_basic_selection (self , selection , value , fields = None ):
674
672
"""TODO"""
675
673
676
674
# guard conditions
@@ -683,11 +681,11 @@ def set_basic_selection(self, selection, value):
683
681
684
682
# handle zero-dimensional arrays
685
683
if self ._shape == ():
686
- return self ._set_basic_selection_zd (selection , value )
684
+ return self ._set_basic_selection_zd (selection , value , fields = fields )
687
685
else :
688
- return self ._set_basic_selection_nd (selection , value )
686
+ return self ._set_basic_selection_nd (selection , value , fields = fields )
689
687
690
- def set_orthogonal_selection (self , selection , value ):
688
+ def set_orthogonal_selection (self , selection , value , fields = None ):
691
689
"""TODO"""
692
690
693
691
# guard conditions
@@ -701,9 +699,9 @@ def set_orthogonal_selection(self, selection, value):
701
699
# setup indexer
702
700
indexer = OrthogonalIndexer (selection , self )
703
701
704
- self ._set_selection (indexer , value )
702
+ self ._set_selection (indexer , value , fields = fields )
705
703
706
- def set_coordinate_selection (self , selection , value ):
704
+ def set_coordinate_selection (self , selection , value , fields = None ):
707
705
"""TODO"""
708
706
709
707
# guard conditions
@@ -717,9 +715,9 @@ def set_coordinate_selection(self, selection, value):
717
715
# setup indexer
718
716
indexer = CoordinateIndexer (selection , self )
719
717
720
- self ._set_selection (indexer , value )
718
+ self ._set_selection (indexer , value , fields = fields )
721
719
722
- def set_mask_selection (self , selection , value ):
720
+ def set_mask_selection (self , selection , value , fields = None ):
723
721
"""TODO"""
724
722
725
723
# guard conditions
@@ -733,13 +731,17 @@ def set_mask_selection(self, selection, value):
733
731
# setup indexer
734
732
indexer = MaskIndexer (selection , self )
735
733
736
- self ._set_selection (indexer , value )
734
+ self ._set_selection (indexer , value , fields = fields )
737
735
738
- def _set_basic_selection_zd (self , selection , value ):
736
+ def _set_basic_selection_zd (self , selection , value , fields = None ):
739
737
# special case __setitem__ for zero-dimensional array
740
738
739
+ if fields :
740
+ raise IndexError ('fields not supported for 0d array' )
741
+
741
742
# check item is valid
742
- if selection not in ((), Ellipsis ):
743
+ selection = ensure_tuple (selection )
744
+ if selection not in ((), (Ellipsis ,)):
743
745
raise IndexError ('too many indices for array' )
744
746
745
747
# setup data to store
@@ -756,15 +758,15 @@ def _set_basic_selection_zd(self, selection, value):
756
758
cdata = self ._encode_chunk (arr )
757
759
self .chunk_store [ckey ] = cdata
758
760
759
- def _set_basic_selection_nd (self , selection , value ):
761
+ def _set_basic_selection_nd (self , selection , value , fields = None ):
760
762
# implementation of __setitem__ for array with at least one dimension
761
763
762
764
# setup indexer
763
765
indexer = BasicIndexer (selection , self )
764
766
765
- self ._set_selection (indexer , value )
767
+ self ._set_selection (indexer , value , fields = fields )
766
768
767
- def _set_selection (self , indexer , value ):
769
+ def _set_selection (self , indexer , value , fields = None ):
768
770
769
771
# We iterate over all chunks which overlap the selection and thus contain data that needs
770
772
# to be replaced. Each chunk is processed in turn, extracting the necessary data from the
@@ -773,15 +775,20 @@ def _set_selection(self, indexer, value):
773
775
# N.B., it is an important optimisation that we only visit chunks which overlap the
774
776
# selection. This minimises the number of iterations in the main for loop.
775
777
778
+ # check fields are sensible
779
+ check_fields (fields , self ._dtype )
780
+ if fields and isinstance (fields , list ):
781
+ raise ValueError ('multi-field assignment is not supported' )
782
+
776
783
# determine indices of chunks overlapping the selection
777
784
sel_shape = indexer .shape
778
785
779
786
# check value shape
780
- if np . isscalar (value ):
787
+ if is_scalar (value , self . _dtype ):
781
788
pass
782
789
else :
783
790
if not hasattr (value , 'shape' ):
784
- raise TypeError ( ' value must be an array-like object' )
791
+ value = np . asarray ( value )
785
792
if value .shape != sel_shape :
786
793
raise ValueError ('value has wrong shape for selection; expected {}, got {}'
787
794
.format (sel_shape , value .shape ))
@@ -790,7 +797,7 @@ def _set_selection(self, indexer, value):
790
797
for chunk_coords , chunk_selection , out_selection in indexer :
791
798
792
799
# extract data to store
793
- if np . isscalar (value ):
800
+ if is_scalar (value , self . _dtype ):
794
801
chunk_value = value
795
802
else :
796
803
chunk_value = value [out_selection ]
@@ -802,9 +809,10 @@ def _set_selection(self, indexer, value):
802
809
chunk_value = chunk_value [item ]
803
810
804
811
# put data
805
- self ._chunk_setitem (chunk_coords , chunk_selection , chunk_value )
812
+ self ._chunk_setitem (chunk_coords , chunk_selection , chunk_value , fields = fields )
806
813
807
- def _chunk_getitem (self , chunk_coords , chunk_selection , out , out_selection , drop_axes = None ):
814
+ def _chunk_getitem (self , chunk_coords , chunk_selection , out , out_selection , drop_axes = None ,
815
+ fields = None ):
808
816
"""Obtain part or whole of a chunk.
809
817
810
818
Parameters
@@ -819,6 +827,8 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
819
827
Location of region within output array to store results in.
820
828
drop_axes : tuple of ints
821
829
Axes to squeeze out of the chunk.
830
+ fields
831
+ TODO
822
832
823
833
"""
824
834
@@ -838,10 +848,11 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
838
848
839
849
else :
840
850
841
- if isinstance (out , np .ndarray ) and \
842
- isinstance (out_selection , slice ) and \
843
- is_total_slice (chunk_selection , self ._chunks ) and \
844
- not self ._filters :
851
+ if (isinstance (out , np .ndarray ) and
852
+ not fields and
853
+ isinstance (out_selection , slice ) and
854
+ is_total_slice (chunk_selection , self ._chunks ) and
855
+ not self ._filters ):
845
856
846
857
dest = out [out_selection ]
847
858
contiguous = ((self ._order == 'C' and dest .flags .c_contiguous ) or
@@ -864,13 +875,17 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
864
875
# decode chunk
865
876
chunk = self ._decode_chunk (cdata )
866
877
867
- # set data in output array
878
+ # select data from chunk
879
+ if fields :
880
+ chunk = chunk [fields ]
868
881
tmp = chunk [chunk_selection ]
869
882
if drop_axes :
870
883
tmp = np .squeeze (tmp , axis = drop_axes )
884
+
885
+ # store selected data in output
871
886
out [out_selection ] = tmp
872
887
873
- def _chunk_setitem (self , chunk_coords , chunk_selection , value ):
888
+ def _chunk_setitem (self , chunk_coords , chunk_selection , value , fields = None ):
874
889
"""Replace part or whole of a chunk.
875
890
876
891
Parameters
@@ -886,25 +901,25 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
886
901
887
902
# synchronization
888
903
if self ._synchronizer is None :
889
- self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value )
904
+ self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value , fields = fields )
890
905
else :
891
906
# synchronize on the chunk
892
907
ckey = self ._chunk_key (chunk_coords )
893
908
with self ._synchronizer [ckey ]:
894
- self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value )
909
+ self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value , fields = fields )
895
910
896
- def _chunk_setitem_nosync (self , chunk_coords , chunk_selection , value ):
911
+ def _chunk_setitem_nosync (self , chunk_coords , chunk_selection , value , fields = None ):
897
912
898
913
# obtain key for chunk storage
899
914
ckey = self ._chunk_key (chunk_coords )
900
915
901
- if is_total_slice (chunk_selection , self ._chunks ):
916
+ if is_total_slice (chunk_selection , self ._chunks ) and not fields :
902
917
# totally replace chunk
903
918
904
919
# optimization: we are completely replacing the chunk, so no need
905
920
# to access the existing chunk data
906
921
907
- if np . isscalar (value ):
922
+ if is_scalar (value , self . _dtype ):
908
923
909
924
# setup array filled with value
910
925
chunk = np .empty (self ._chunks , dtype = self ._dtype , order = self ._order )
@@ -955,7 +970,12 @@ def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value):
955
970
chunk = chunk .copy (order = 'K' )
956
971
957
972
# modify
958
- chunk [chunk_selection ] = value
973
+ if fields :
974
+ # N.B., currently multi-field assignment is not supported in numpy, so this only
975
+ # works for a single field
976
+ chunk [fields ][chunk_selection ] = value
977
+ else :
978
+ chunk [chunk_selection ] = value
959
979
960
980
# encode chunk
961
981
cdata = self ._encode_chunk (chunk )
0 commit comments