55 changes: 29 additions & 26 deletions intake_esm/core.py
@@ -439,6 +439,7 @@ def to_dataset_dict(
storage_options: typing.Dict[pydantic.StrictStr, typing.Any] = None,
progressbar: pydantic.StrictBool = None,
aggregate: pydantic.StrictBool = None,
skip_on_error: pydantic.StrictBool = False,
**kwargs,
) -> typing.Dict[str, xr.Dataset]:
"""
@@ -460,6 +461,8 @@ def to_dataset_dict(
when loading assets into :py:class:`~xarray.Dataset`.
aggregate : bool, optional
If False, no aggregation will be done.
skip_on_error : bool, optional
If True, skip datasets that cannot be loaded and/or variables we are unable to derive.

Returns
-------
@@ -516,9 +519,7 @@ def to_dataset_dict(

xarray_open_kwargs = xarray_open_kwargs or {}
xarray_combine_by_coords_kwargs = xarray_combine_by_coords_kwargs or {}

cdf_kwargs = kwargs.get('cdf_kwargs')
zarr_kwargs = kwargs.get('zarr_kwargs')
cdf_kwargs, zarr_kwargs = kwargs.get('cdf_kwargs'), kwargs.get('zarr_kwargs')

if cdf_kwargs or zarr_kwargs:
warnings.warn(
@@ -527,10 +528,10 @@ def to_dataset_dict(
DeprecationWarning,
stacklevel=2,
)
if cdf_kwargs:
xarray_open_kwargs.update(cdf_kwargs)
if zarr_kwargs:
xarray_open_kwargs.update(zarr_kwargs)
if cdf_kwargs:
xarray_open_kwargs.update(cdf_kwargs)
if zarr_kwargs:
xarray_open_kwargs.update(zarr_kwargs)

source_kwargs = dict(
xarray_open_kwargs=xarray_open_kwargs,
@@ -543,41 +544,43 @@ def to_dataset_dict(
if aggregate is not None and not aggregate:
self = deepcopy(self)
self.esmcat.aggregation_control.groupby_attrs = []

if progressbar is not None:
self.progressbar = progressbar

if self.progressbar:
print(
f"""\n--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{self.key_template}'"""
)

def _load_source(key, source):
return key, source.to_dask()

sources = {key: source(**source_kwargs) for key, source in self.items()}
progress, total = None, None
if self.progressbar:
total = len(sources)
progress = progress_bar(range(total))

datasets = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=dask.system.CPU_COUNT) as executor:
future_tasks = [
executor.submit(_load_source, key, source) for key, source in sources.items()
]
for i, task in enumerate(concurrent.futures.as_completed(future_tasks)):
key, ds = task.result()
datasets[key] = ds
if self.progressbar:
progress.update(i)
if self.progressbar:
progress.update(total)
gen = progress_bar(
concurrent.futures.as_completed(future_tasks), total=len(sources)
)
else:
gen = concurrent.futures.as_completed(future_tasks)
for task in gen:
try:
key, ds = task.result()
datasets[key] = ds
except Exception as exc:
if not skip_on_error:
raise exc
self.datasets = self._create_derived_variables(datasets, skip_on_error)
return self.datasets

def _create_derived_variables(self, datasets, skip_on_error):
if len(self.derivedcat) > 0:
datasets = self.derivedcat.update_datasets(
datasets=datasets,
variable_key_name=self.esmcat.aggregation_control.variable_column_name,
skip_on_error=skip_on_error,
)
self.datasets = datasets
return self.datasets
return datasets


def _load_source(key, source):
return key, source.to_dask()
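
A minimal usage sketch of the new skip_on_error flag on to_dataset_dict; the catalog path below is a placeholder, not part of this change:

import intake

cat = intake.open_esm_datastore('catalog.json')  # hypothetical catalog file

# With skip_on_error=True, keys whose assets fail to load (and derived
# variables that cannot be computed) are skipped instead of raising.
dsets = cat.to_dataset_dict(skip_on_error=True)

# Default behaviour (skip_on_error=False): the underlying error is re-raised.
dsets = cat.to_dataset_dict(skip_on_error=False)
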
35 changes: 22 additions & 13 deletions intake_esm/derived.py
@@ -6,8 +6,6 @@
import tlz
import xarray as xr

from .utils import INTAKE_ESM_ATTRS_PREFIX


class DerivedVariableError(Exception):
pass
@@ -30,17 +28,16 @@ def dependent_variables(self, variable_key_name: str) -> typing.List[pydantic.St
"""Return a list of dependent variables for a given variable"""
return self.query[variable_key_name]

def __call__(self, *args, **kwargs) -> xr.Dataset:
def __call__(self, *args, variable_key_name: str = None, **kwargs) -> xr.Dataset:
"""Call the function and return the result"""
try:
ds = self.func(*args, **kwargs)
ds[self.variable].attrs[
f'{INTAKE_ESM_ATTRS_PREFIX}/derivation'
] = f'dependent_variables: {self.dependent_variables}'
return ds
return self.func(*args, **kwargs)
except Exception as exc:
dependent_variables = (
self.dependent_variables(variable_key_name) if variable_key_name else []
)
raise DerivedVariableError(
f'Unable to derived variable: {self.variable} with dependent: {self.dependent_variables} using args:{args} and kwargs:{kwargs}'
f'Unable to derive variable: {self.variable} with dependent: {dependent_variables} using args:{args} and kwargs:{kwargs}'
) from exc


@@ -158,7 +155,11 @@ def search(self, variable: typing.Union[str, typing.List[str]]) -> 'DerivedVaria
return reg

def update_datasets(
self, *, datasets: typing.Dict[str, xr.Dataset], variable_key_name: str
self,
*,
datasets: typing.Dict[str, xr.Dataset],
variable_key_name: str,
skip_on_error: bool = False,
) -> typing.Dict[str, xr.Dataset]:
"""Given a dictionary of datasets, return a dictionary of datasets with the derived variables

@@ -168,6 +169,8 @@ def update_datasets(
A dictionary of datasets to apply the derived variables to.
variable_key_name : str
The name of the variable key used in the derived variable query
skip_on_error : bool, optional
If True, skip variables that fail variable derivation.

Returns
-------
@@ -180,9 +183,15 @@ def update_datasets(
if set(dataset.variables).intersection(
derived_variable.dependent_variables(variable_key_name)
):
# Assumes all dependent variables are in the same dataset
# TODO: Make this more robust to support datasets with variables from different datasets
datasets[dset_key] = derived_variable(dataset)
try:
# Assumes all dependent variables are in the same dataset
# TODO: Make this more robust to support datasets with variables from different datasets
datasets[dset_key] = derived_variable(
dataset, variable_key_name=variable_key_name
)
except Exception as exc:
if not skip_on_error:
raise exc
return datasets


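A minimal sketch of update_datasets with the new skip_on_error argument, mirroring the tests below; the variable and dataset names are illustrative only:

import xarray as xr
from intake_esm import DerivedVariableRegistry

dvr = DerivedVariableRegistry()

@dvr.register(variable='FOO', query={'variable': 'air'})
def compute_foo(ds):
    # Derive FOO from the dependent variable 'air'
    ds['FOO'] = ds.air // 100
    return ds

ds = xr.tutorial.open_dataset('air_temperature')  # sample dataset
# With skip_on_error=True a failing derivation leaves the dataset unchanged
# instead of raising DerivedVariableError.
dsets = dvr.update_datasets(
    datasets={'test': ds}, variable_key_name='variable', skip_on_error=True
)
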
90 changes: 53 additions & 37 deletions intake_esm/source.py
@@ -11,6 +11,10 @@
from .utils import INTAKE_ESM_ATTRS_PREFIX, INTAKE_ESM_DATASET_KEY, INTAKE_ESM_VARS_KEY


class ESMDataSourceError(Exception):
pass


def _get_xarray_open_kwargs(data_format, xarray_open_kwargs=None):
xarray_open_kwargs = (xarray_open_kwargs or {}).copy()
_default_open_kwargs = {
@@ -155,43 +159,55 @@ def _get_schema(self) -> Schema:
def _open_dataset(self):
"""Open dataset with xarray"""

datasets = [
_open_dataset(
record[self.path_column_name],
record[self.variable_column_name],
xarray_open_kwargs=self.xarray_open_kwargs,
preprocess=self.preprocess,
expand_dims={
agg.attribute_name: [record[agg.attribute_name]]
for agg in self.aggregations
if agg.type.value == 'join_new'
},
requested_variables=self.requested_variables,
additional_attrs=record.to_dict(),
)
for _, record in self.df.iterrows()
]

datasets = dask.compute(*datasets)
if len(datasets) == 1:
self._ds = datasets[0]
else:
datasets = sorted(
datasets,
key=lambda ds: tuple(
f'{INTAKE_ESM_ATTRS_PREFIX}/{agg.attribute_name}' for agg in self.aggregations
),
)
with dask.config.set(
{'scheduler': 'single-threaded', 'array.slicing.split_large_chunks': True}
): # Use single-threaded scheduler
datasets = [
ds.set_coords(set(ds.variables) - set(ds.attrs[INTAKE_ESM_VARS_KEY]))
for ds in datasets
]
self._ds = xr.combine_by_coords(datasets, **self.xarray_combine_by_coords_kwargs)

self._ds.attrs[INTAKE_ESM_DATASET_KEY] = self.key
try:

datasets = [
_open_dataset(
record[self.path_column_name],
record[self.variable_column_name],
xarray_open_kwargs=self.xarray_open_kwargs,
preprocess=self.preprocess,
expand_dims={
agg.attribute_name: [record[agg.attribute_name]]
for agg in self.aggregations
if agg.type.value == 'join_new'
},
requested_variables=self.requested_variables,
additional_attrs=record.to_dict(),
)
for _, record in self.df.iterrows()
]

datasets = dask.compute(*datasets)
if len(datasets) == 1:
self._ds = datasets[0]
else:
datasets = sorted(
datasets,
key=lambda ds: tuple(
f'{INTAKE_ESM_ATTRS_PREFIX}/{agg.attribute_name}'
for agg in self.aggregations
),
)
with dask.config.set(
{'scheduler': 'single-threaded', 'array.slicing.split_large_chunks': True}
): # Use single-threaded scheduler
datasets = [
ds.set_coords(set(ds.variables) - set(ds.attrs[INTAKE_ESM_VARS_KEY]))
for ds in datasets
]
self._ds = xr.combine_by_coords(
datasets, **self.xarray_combine_by_coords_kwargs
)

self._ds.attrs[INTAKE_ESM_DATASET_KEY] = self.key

except Exception as exc:
raise ESMDataSourceError(
f"""Failed to load dataset with key='{self.key}'
You can use `cat['{self.key}'].df` to inspect the assets/files for this key.
"""
) from exc

def to_dask(self):
"""Return xarray object (which will have chunks)"""
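A minimal sketch of how the new ESMDataSourceError surfaces to a caller; the catalog path is a placeholder:

import intake
import intake_esm

cat = intake.open_esm_datastore('catalog.json')  # hypothetical catalog file
try:
    dsets = cat.to_dataset_dict()
except intake_esm.source.ESMDataSourceError as exc:
    # The message names the failing key; cat[key].df lists the assets behind it.
    print(exc)
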
16 changes: 16 additions & 0 deletions tests/test_core.py
@@ -297,6 +297,22 @@ def test_to_dataset_dict_w_preprocess_error():
cat.to_dataset_dict(preprocess='foo')


def test_to_dataset_dict_skip_error():
cat = intake.open_esm_datastore(catalog_dict_records)
with pytest.raises(intake_esm.source.ESMDataSourceError):
dsets = cat.to_dataset_dict(
xarray_open_kwargs={'backend_kwargsd': {'storage_options': {'anon': True}}},
skip_on_error=False,
)

dsets = cat.to_dataset_dict(
xarray_open_kwargs={'backend_kwargsd': {'storage_options': {'anon': True}}},
skip_on_error=True,
)

assert len(dsets.keys()) == 0


def test_to_dataset_dict_with_registry():

registry = intake_esm.DerivedVariableRegistry()
28 changes: 27 additions & 1 deletion tests/test_derived.py
@@ -71,11 +71,37 @@ def func(ds):
ds['FOO'] = ds.air // 100
return ds

dsets = dvr.update_datasets(datasets={'test': ds}, variable_key_name='variable')
dsets = dvr.update_datasets(datasets={'test': ds.copy()}, variable_key_name='variable')
assert 'test' in dsets
assert 'FOO' in dsets['test']
assert isinstance(dsets['test']['FOO'], xr.DataArray)


def test_registry_derive_variables_error():
ds = xr.tutorial.open_dataset('air_temperature')
dvr = DerivedVariableRegistry()

@dvr.register(variable='FOO', query={'variable': 'air'})
def func(ds):
ds['FOO'] = ds.air // 100
return ds

# Test for errors / invalid inputs (wrong input type)
with pytest.raises(DerivedVariableError):
dvr['FOO']({})

@dvr.register(variable='FOO', query={'variable': 'air'})
def funcb(ds):
ds['FOO'] = 1 / 0
return ds

dsets = dvr.update_datasets(
datasets={'test': ds.copy()}, variable_key_name='variable', skip_on_error=True
)

assert 'FOO' not in dsets['test']

with pytest.raises(DerivedVariableError):
dsets = dvr.update_datasets(
datasets={'test': ds.copy()}, variable_key_name='variable', skip_on_error=False
)