-
Notifications
You must be signed in to change notification settings - Fork 18
Add Kvikio backend entrypoint #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
9deadb7
aa2dc91
7fb4b94
743fe7d
5d501e4
facf5f7
f3f5189
9c98d19
dd8bc57
d2da1e4
b87c3c2
87cb74e
d7394ef
1b23fef
ca0cf45
97260d6
5d27b26
85491d7
c470b97
95efa18
d684dad
ae2a7f1
15fbafd
f3df115
4e1857a
7345b61
e2b410e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import os | ||
|
||
import cupy as cp | ||
import numpy as np | ||
import zarr | ||
from xarray import Variable | ||
from xarray.backends import zarr as zarr_backend | ||
from xarray.backends.common import _normalize_path # TODO: can this be public | ||
from xarray.backends.store import StoreBackendEntrypoint | ||
from xarray.backends.zarr import ZarrArrayWrapper, ZarrBackendEntrypoint, ZarrStore | ||
from xarray.core import indexing | ||
from xarray.core.utils import close_on_error # TODO: can this be public. | ||
|
||
# Optional dependency probe: the Kvikio backend is only advertised as
# available when kvikio (GPU Direct Storage support for zarr) imports.
try:
    import kvikio.zarr
except ImportError:
    has_kvikio = False
else:
    has_kvikio = True
|
||
|
||
class CupyZarrArrayWrapper(ZarrArrayWrapper):
    """Array wrapper that keeps data on-device.

    ``__array__`` hands back whatever ``get_array()`` yields without any
    host conversion (presumably a cupy-backed zarr array — the device
    transfer, if any, is left to the caller).
    """

    def __array__(self):
        # No coercion here: return the wrapped array object as-is.
        return self.get_array()
|
||
|
||
class EagerCupyZarrArrayWrapper(ZarrArrayWrapper):
    """Used to wrap dimension coordinates.

    Reads the whole variable from the zarr group and copies it from
    device to host via ``.get()``, so coordinate data ends up as a plain
    numpy array.
    """

    def __array__(self):
        # Load the full variable, then move it off the GPU.
        device_array = self.datastore.zarr_group[self.variable_name][:]
        return device_array.get()

    def get_array(self):
        # np.asarray routes through __array__ above, i.e. an eager host load.
        return np.asarray(self)
|
||
|
||
class GDSZarrStore(ZarrStore):
    """ZarrStore variant that reads through ``kvikio.zarr.GDSStore``.

    The zarr group is opened with ``meta_array=cp.empty(())`` so arrays
    come back as cupy (device) arrays instead of numpy arrays.
    """

    @classmethod
    def open_group(
        cls,
        store,
        mode="r",
        synchronizer=None,
        group=None,
        consolidated=False,
        consolidate_on_close=False,
        chunk_store=None,
        storage_options=None,
        append_dim=None,
        write_region=None,
        safe_chunks=True,
        stacklevel=2,
    ):
        """Open a zarr group via GDSStore and wrap it in this store class.

        Parameters mirror ``ZarrStore.open_group``.  ``consolidated`` must
        be falsy: consolidated metadata is not supported yet.
        """
        # zarr doesn't support pathlib.Path objects yet. zarr-python#601
        if isinstance(store, os.PathLike):
            store = os.fspath(store)

        open_kwargs = dict(
            mode=mode,
            synchronizer=synchronizer,
            path=group,
            # Request device (cupy) arrays from zarr instead of numpy.
            meta_array=cp.empty(()),
        )
        open_kwargs["storage_options"] = storage_options

        # TODO: handle consolidated metadata.
        assert not consolidated

        if chunk_store:
            open_kwargs["chunk_store"] = chunk_store

        store = kvikio.zarr.GDSStore(store)

        # NOTE(review): the consolidated / open_consolidated fallback
        # branches that used to live here were unreachable (``consolidated``
        # is guaranteed falsy by the assert above) and the fallback path
        # referenced an un-imported ``warnings`` module.  Open the group
        # directly until consolidated metadata is actually supported.
        zarr_group = zarr.open_group(store, **open_kwargs)

        return cls(
            zarr_group,
            mode,
            consolidate_on_close,
            append_dim,
            write_region,
            safe_chunks,
        )

    def open_store_variable(self, name, zarr_array):
        """Build an xarray ``Variable`` for one zarr array.

        Dimension coordinates are wrapped so they load eagerly onto the
        host; all other variables stay lazily on the device.
        """
        try_nczarr = self._mode == "r"
        dimensions, attributes = zarr_backend._get_zarr_dims_and_attrs(
            zarr_array, zarr_backend.DIMENSION_KEY, try_nczarr
        )

        # Changed from the upstream zarr array wrapper:
        if name in dimensions:
            # we want indexed dimensions to be loaded eagerly
            # Right now we load in to device and then transfer to host
            # But these should be small-ish arrays
            # TODO: can we tell GDSStore to load as numpy array directly
            # not cupy array?
            array_wrapper = EagerCupyZarrArrayWrapper
        else:
            array_wrapper = CupyZarrArrayWrapper
        data = indexing.LazilyIndexedArray(array_wrapper(name, self))

        attributes = dict(attributes)
        encoding = {
            "chunks": zarr_array.chunks,
            "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)),
            "compressor": zarr_array.compressor,
            "filters": zarr_array.filters,
        }
        # _FillValue needs to be in attributes, not encoding, so it will get
        # picked up by decode_cf
        if zarr_array.fill_value is not None:
            attributes["_FillValue"] = zarr_array.fill_value

        return Variable(dimensions, data, attributes, encoding)
|
||
|
||
class KvikioBackendEntrypoint(ZarrBackendEntrypoint):
    """Backend entrypoint that opens zarr stores through kvikio (GDS)."""

    available = has_kvikio

    # disabled by default
    # We need to provide this because of the subclassing from
    # ZarrBackendEntrypoint
    def guess_can_open(self, filename_or_obj):
        """Never auto-select this backend; it must be requested explicitly."""
        return False

    def open_dataset(
        self,
        filename_or_obj,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables=None,
        use_cftime=None,
        decode_timedelta=None,
        group=None,
        mode="r",
        synchronizer=None,
        consolidated=None,
        chunk_store=None,
        storage_options=None,
        stacklevel=3,
    ):
        """Open a dataset whose variables are backed by device arrays.

        Delegates storage to ``GDSZarrStore`` and CF decoding to the
        generic ``StoreBackendEntrypoint``.
        """
        path = _normalize_path(filename_or_obj)
        store = GDSZarrStore.open_group(
            path,
            group=group,
            mode=mode,
            synchronizer=synchronizer,
            consolidated=consolidated,
            consolidate_on_close=False,
            chunk_store=chunk_store,
            storage_options=storage_options,
            stacklevel=stacklevel + 1,
        )

        entrypoint = StoreBackendEntrypoint()
        with close_on_error(store):
            return entrypoint.open_dataset(
                store,
                mask_and_scale=mask_and_scale,
                decode_times=decode_times,
                concat_characters=concat_characters,
                decode_coords=decode_coords,
                drop_variables=drop_variables,
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
Uh oh!
There was an error while loading. Please reload this page.