Skip to content

ENH: switch Dataset and DataArray to use explicit indexes #2639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 1 addition & 38 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def _update_coords(self, coords):
self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dict(dims)
self._data._indexes = None

def __delitem__(self, key):
if key in self:
Expand Down Expand Up @@ -276,44 +277,6 @@ def __iter__(self):
return iter(self._data._level_coords)


class Indexes(Mapping, formatting.ReprMixin):
    """Ordered Mapping[str, pandas.Index] for xarray objects.

    Keys are the names in ``sizes`` that also have an entry in
    ``variables``; each value is produced on access by calling
    ``to_index()`` on the matching variable.
    """

    def __init__(self, variables, sizes):
        """Not for public consumption.

        Parameters
        ----------
        variables : OrderedDict[Any, Variable]
            Reference to OrderedDict holding variable objects. Should be the
            same dictionary used by the source object.
        sizes : OrderedDict[Any, int]
            Map from dimension names to sizes.
        """
        self._variables = variables
        self._sizes = sizes

    def __iter__(self):
        # Follow the ordering of ``sizes``; names without a variable are
        # skipped.
        for key in self._sizes:
            if key in self._variables:
                yield key

    def __len__(self):
        # Count of names present in both ``sizes`` and ``variables``, so the
        # length always agrees with ``__iter__``.
        return sum(key in self._variables for key in self._sizes)

    def __contains__(self, key):
        return key in self._sizes and key in self._variables

    def __getitem__(self, key):
        # Only names listed in ``sizes`` are valid keys, even when a variable
        # with that name exists.
        if key not in self._sizes:
            raise KeyError(key)
        return self._variables[key].to_index()

    def __unicode__(self):
        return formatting.indexes_repr(self)


def assert_coordinate_consistent(obj, coords):
""" Maeke sure the dimension coordinate of obj is
consistent with coords.
Expand Down
15 changes: 11 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from .alignment import align, reindex_like_indexers
from .common import AbstractArray, DataWithCoords
from .coordinates import (
DataArrayCoordinates, Indexes, LevelCoordinatesSource,
DataArrayCoordinates, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
from .dataset import Dataset, merge_indexes, split_indexes
from .formatting import format_item
from .indexes import default_indexes, Indexes
from .options import OPTIONS
from .pycompat import OrderedDict, basestring, iteritems, range, zip
from .utils import (
Expand Down Expand Up @@ -165,7 +166,7 @@ class DataArray(AbstractArray, DataWithCoords):
dt = property(DatetimeAccessor)

def __init__(self, data, coords=None, dims=None, name=None,
attrs=None, encoding=None, fastpath=False):
attrs=None, encoding=None, indexes=None, fastpath=False):
"""
Parameters
----------
Expand Down Expand Up @@ -237,6 +238,10 @@ def __init__(self, data, coords=None, dims=None, name=None,
self._coords = coords
self._name = name

# TODO(shoyer): document this argument, once it becomes part of the
# public interface.
self._indexes = indexes

self._file_obj = None

self._initialized = True
Expand Down Expand Up @@ -534,9 +539,11 @@ def encoding(self, value):

@property
def indexes(self):
"""OrderedDict of pandas.Index objects used for label based indexing
"""Mapping of pandas.Index objects used for label based indexing
"""
return Indexes(self._coords, self.sizes)
if self._indexes is None:
self._indexes = default_indexes(self._coords, self.dims)
return Indexes(self._indexes)

@property
def coords(self):
Expand Down
27 changes: 19 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
import xarray as xr

from . import (
alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops,
pdcompat, resample, rolling, utils)
alignment, dtypes, duck_array_ops, formatting, groupby,
indexing, ops, pdcompat, resample, rolling, utils)
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from .alignment import align
from .common import (
ALL_DIMS, DataWithCoords, ImplementsDatasetReduce,
_contains_datetime_like_objects)
from .coordinates import (
DatasetCoordinates, Indexes, LevelCoordinatesSource,
DatasetCoordinates, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
from .indexes import Indexes, default_indexes
from .merge import (
dataset_merge_method, dataset_update_method, merge_data_and_coords,
merge_variables)
Expand Down Expand Up @@ -364,6 +365,10 @@ def __init__(self, data_vars=None, coords=None, attrs=None,
coords = {}
if data_vars is not None or coords is not None:
self._set_init_vars_and_dims(data_vars, coords, compat)

# TODO(shoyer): expose indexes as a public argument in __init__
self._indexes = None

if attrs is not None:
self.attrs = attrs
self._encoding = None
Expand Down Expand Up @@ -642,14 +647,15 @@ def persist(self, **kwargs):

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
file_obj=None, encoding=None):
indexes=None, file_obj=None, encoding=None):
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
obj = object.__new__(cls)
obj._variables = variables
obj._coord_names = coord_names
obj._dims = dims
obj._indexes = indexes
obj._attrs = attrs
obj._file_obj = file_obj
obj._encoding = encoding
Expand All @@ -664,7 +670,8 @@ def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
return cls._construct_direct(variables, coord_names, dims, attrs)

def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
attrs=__default_attrs, inplace=False):
attrs=__default_attrs, indexes=None,
inplace=False):
"""Fastpath constructor for internal use.

Preserves coord names and attributes. If not provided explicitly,
Expand Down Expand Up @@ -693,13 +700,15 @@ def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
self._coord_names = coord_names
if attrs is not self.__default_attrs:
self._attrs = attrs
self._indexes = indexes
obj = self
else:
if coord_names is None:
coord_names = self._coord_names.copy()
if attrs is self.__default_attrs:
attrs = self._attrs_copy()
obj = self._construct_direct(variables, coord_names, dims, attrs)
obj = self._construct_direct(
variables, coord_names, dims, attrs, indexes)
return obj

def _replace_indexes(self, indexes):
Expand Down Expand Up @@ -1064,9 +1073,11 @@ def identical(self, other):

@property
def indexes(self):
"""OrderedDict of pandas.Index objects used for label based indexing
"""Mapping of pandas.Index objects used for label based indexing
"""
return Indexes(self._variables, self._dims)
if self._indexes is None:
self._indexes = default_indexes(self._variables, self._dims)
return Indexes(self._indexes)

@property
def coords(self):
Expand Down
55 changes: 55 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import absolute_import, division, print_function
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
from collections import OrderedDict

from . import formatting


class Indexes(Mapping, formatting.ReprMixin):
    """Immutable proxy for Dataset or DataArray indexes.

    Implements only the read-only ``Mapping`` protocol; every lookup is
    delegated to the wrapped ``indexes`` mapping.
    """

    def __init__(self, indexes):
        """Not for public consumption.

        Parameters
        ----------
        indexes : Dict[Any, pandas.Index]
            Indexes held by this object.
        """
        # The mapping is stored by reference, not copied.
        self._indexes = indexes

    def __iter__(self):
        return iter(self._indexes)

    def __len__(self):
        return len(self._indexes)

    def __contains__(self, key):
        return key in self._indexes

    def __getitem__(self, key):
        return self._indexes[key]

    def __unicode__(self):
        return formatting.indexes_repr(self)


def default_indexes(coords, dims):
    """Default indexes for a Dataset/DataArray.

    Parameters
    ----------
    coords : Mapping[Any, xarray.Variable]
        Coordinate variables from which to draw default indexes.
    dims : iterable
        Iterable of dimension names.

    Returns
    -------
    indexes : OrderedDict[Any, pandas.Index]
        Mapping from indexing keys (levels/dimension names) to the indexes
        used for indexing along that dimension. Entries follow the order
        of ``dims``; names in ``dims`` with no entry in ``coords`` are
        omitted.
    """
    return OrderedDict((key, coords[key].to_index())
                       for key in dims if key in coords)