
Commit 8f0d9e5

crusaderky authored and shoyer committed
WIP: more annotations (pydata#3090)
* Typing hints for Dataset.to_netcdf
* type annotations
* poke codecov
1 parent ae69079 commit 8f0d9e5

File tree

3 files changed: +109 -49 lines changed


xarray/backends/api.py

Lines changed: 42 additions & 18 deletions
```diff
@@ -4,21 +4,29 @@
 from io import BytesIO
 from numbers import Number
 from pathlib import Path
-import re
+from typing import Callable, Dict, Hashable, Iterable, Mapping, Tuple, Union
 
 import numpy as np
-import pandas as pd
 
 from .. import Dataset, DataArray, backends, conventions, coding
 from ..core import indexing
 from .. import auto_combine
-from ..core.combine import (combine_by_coords, _nested_combine,
-                            _infer_concat_order_from_positions)
+from ..core.combine import (
+    combine_by_coords,
+    _nested_combine,
+    _infer_concat_order_from_positions
+)
+from ..core.pycompat import TYPE_CHECKING
 from ..core.utils import close_on_error, is_grib_path, is_remote_uri
-from ..core.variable import Variable
-from .common import ArrayWriter
+from .common import ArrayWriter, AbstractDataStore
 from .locks import _get_scheduler
-from ..coding.variables import safe_setitem, unpack_for_encoding
+
+if TYPE_CHECKING:
+    try:
+        from dask.delayed import Delayed
+    except ImportError:
+        Delayed = None
+
 
 DATAARRAY_NAME = '__xarray_dataarray_name__'
 DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'
```
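The `if TYPE_CHECKING:` block makes `Delayed` visible to the type checker without turning dask into a hard runtime dependency; the annotations below then name it as the string `'Delayed'`. A minimal sketch of the same pattern, using `typing.TYPE_CHECKING` directly instead of the `..core.pycompat` shim, with a hypothetical `maybe_lazy` function:

```python
# Sketch of the TYPE_CHECKING guard used above. The optional dependency
# is imported only while a type checker runs; at runtime the name does
# not exist, so annotations must spell it as a string.
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from dask.delayed import Delayed  # evaluated by mypy only


def maybe_lazy(compute: bool) -> Union[int, 'Delayed']:
    """Return a value eagerly, or a dask Delayed stand-in for it."""
    if compute:
        return 42
    import dask  # imported lazily, keeping dask optional at runtime
    return dask.delayed(lambda: 42)()
```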
```diff
@@ -406,7 +414,7 @@ def maybe_decode_store(store, lock=False):
     if isinstance(filename_or_obj, Path):
         filename_or_obj = str(filename_or_obj)
 
-    if isinstance(filename_or_obj, backends.AbstractDataStore):
+    if isinstance(filename_or_obj, AbstractDataStore):
         store = filename_or_obj
 
     elif isinstance(filename_or_obj, str):
@@ -805,14 +813,25 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied',
     return combined
 
 
-WRITEABLE_STORES = {'netcdf4': backends.NetCDF4DataStore.open,
-                    'scipy': backends.ScipyDataStore,
-                    'h5netcdf': backends.H5NetCDFStore}
-
-
-def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
-              engine=None, encoding=None, unlimited_dims=None, compute=True,
-              multifile=False):
+WRITEABLE_STORES = {
+    'netcdf4': backends.NetCDF4DataStore.open,
+    'scipy': backends.ScipyDataStore,
+    'h5netcdf': backends.H5NetCDFStore
+}  # type: Dict[str, Callable]
+
+
+def to_netcdf(
+    dataset: Dataset,
+    path_or_file=None,
+    mode: str = 'w',
+    format: str = None,
+    group: str = None,
+    engine: str = None,
+    encoding: Mapping = None,
+    unlimited_dims: Iterable[Hashable] = None,
+    compute: bool = True,
+    multifile: bool = False
+) -> Union[Tuple[ArrayWriter, AbstractDataStore], bytes, 'Delayed', None]:
     """This function creates an appropriate datastore for writing a dataset to
     disk as a netCDF file
 
```
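The four-way `Union` in the new return annotation tracks the four call modes of `to_netcdf`: `multifile=True` returns the `(ArrayWriter, AbstractDataStore)` pair used by `save_mfdataset`, a missing target serializes to in-memory `bytes`, `compute=False` returns a dask `'Delayed'`, and an ordinary write returns `None`. A hedged sketch of how this surfaces through the public `Dataset.to_netcdf` wrapper (assuming a netCDF backend such as scipy is installed):

```python
import xarray as xr

ds = xr.Dataset({'a': ('x', [1.0, 2.0, 3.0])})

ds.to_netcdf('out.nc')   # ordinary write to disk         -> None
raw = ds.to_netcdf()     # no target: serialize in memory -> bytes
assert isinstance(raw, bytes)
# compute=False yields a dask 'Delayed' instead (see the last hunk of
# this file); the (ArrayWriter, AbstractDataStore) pair is internal and
# only returned for multifile=True, on behalf of save_mfdataset.
```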
```diff
@@ -872,8 +891,12 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
 
     if unlimited_dims is None:
         unlimited_dims = dataset.encoding.get('unlimited_dims', None)
-    if isinstance(unlimited_dims, str):
-        unlimited_dims = [unlimited_dims]
+    if unlimited_dims is not None:
+        if (isinstance(unlimited_dims, str)
+                or not isinstance(unlimited_dims, Iterable)):
+            unlimited_dims = [unlimited_dims]
+        else:
+            unlimited_dims = list(unlimited_dims)
 
     writer = ArrayWriter()
 
```
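The widened branch normalizes `unlimited_dims` to a list: a `str` is iterable but names a single dimension, and a bare non-iterable hashable gets wrapped the same way. A standalone sketch of the same logic (the helper name is illustrative, not part of xarray; note that `typing.Iterable` supports `isinstance` checks, which the new code relies on):

```python
from typing import Hashable, Iterable, List


def normalize_unlimited_dims(unlimited_dims) -> List[Hashable]:
    # A str is iterable but denotes one dimension name; a non-iterable
    # hashable (e.g. an int used as a dimension name) is wrapped too.
    if (isinstance(unlimited_dims, str)
            or not isinstance(unlimited_dims, Iterable)):
        return [unlimited_dims]
    return list(unlimited_dims)


assert normalize_unlimited_dims('time') == ['time']
assert normalize_unlimited_dims(0) == [0]
assert normalize_unlimited_dims(('x', 'y')) == ['x', 'y']
```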
```diff
@@ -902,6 +925,7 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
     if not compute:
         import dask
         return dask.delayed(_finalize_store)(writes, store)
+    return None
 
 
 def dump_to_store(dataset, store, writer=None, encoder=None,
```
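Now that the function is annotated, the explicit trailing `return None` makes the final code path visible instead of falling off the end. The `compute=False` branch wraps both the array writes and the store finalization in a single `Delayed`; a hedged usage sketch (assuming dask and a netCDF backend are installed):

```python
import xarray as xr

ds = xr.Dataset({'a': ('x', [1.0, 2.0, 3.0])})

# Nothing is written yet: the writes plus the final flush/close of the
# store are bundled into one dask Delayed.
delayed = ds.to_netcdf('lazy.nc', compute=False)

# ... schedule other deferred work here ...

delayed.compute()  # the file is actually written and closed here
```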

xarray/core/dataarray.py

Lines changed: 16 additions & 10 deletions
```diff
@@ -2,7 +2,7 @@
 import sys
 import warnings
 from collections import OrderedDict
-from distutils.version import LooseVersion
+from numbers import Number
 from typing import (Any, Callable, Dict, Hashable, Iterable, List, Mapping,
                     Optional, Sequence, Tuple, Union, cast)
 
@@ -871,13 +871,19 @@ def chunks(self) -> Optional[Tuple[Tuple[int, ...], ...]]:
         """
         return self.variable.chunks
 
-    def chunk(self, chunks: Union[
-            None, int, Tuple[int, ...], Tuple[Tuple[int, ...], ...],
-            Mapping[Hashable, Union[None, int, Tuple[int, ...]]],
-    ] = None,
-            name_prefix: str = 'xarray-',
-            token: Optional[str] = None,
-            lock: bool = False) -> 'DataArray':
+    def chunk(
+        self,
+        chunks: Union[
+            None,
+            Number,
+            Tuple[Number, ...],
+            Tuple[Tuple[Number, ...], ...],
+            Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]],
+        ] = None,
+        name_prefix: str = 'xarray-',
+        token: Optional[str] = None,
+        lock: bool = False
+    ) -> 'DataArray':
         """Coerce this array's data into a dask arrays with the given chunks.
 
         If this variable is a non-dask array, it will be converted to dask
```
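The rewritten signature enumerates every accepted `chunks` form, now typed with `numbers.Number` instead of plain `int`. A short sketch of the accepted forms (dask must be installed; the array is illustrative):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((4, 6)), dims=('x', 'y'))

da.chunk()                  # None: one chunk spanning the array
da.chunk(2)                 # a single Number for every dimension
da.chunk((2, 3))            # a tuple aligned with da.dims
da.chunk(((2, 2), (3, 3)))  # explicit block sizes per dimension
da.chunk({'x': 2, 'y': 3})  # a mapping from dimension name to size
```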
```diff
@@ -890,7 +896,7 @@ def chunk(self, chunks: Union[
 
         Parameters
         ----------
-        chunks : int, tuple or dict, optional
+        chunks : int, tuple or mapping, optional
             Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or
             ``{'x': 5, 'y': 5}``.
         name_prefix : str, optional
@@ -905,7 +911,7 @@ def chunk(self, chunks: Union[
         -------
         chunked : xarray.DataArray
         """
-        if isinstance(chunks, (list, tuple)):
+        if isinstance(chunks, (tuple, list)):
             chunks = dict(zip(self.dims, chunks))
 
         ds = self._to_temp_dataset().chunk(chunks, name_prefix=name_prefix,
```
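Swapping `(list, tuple)` for `(tuple, list)` is purely cosmetic; the branch itself pairs positional chunk sizes with dimension names before delegating to `Dataset.chunk`. What the conversion does, in isolation:

```python
# Positional chunks are zipped onto the dimension names, so a tuple or
# list argument becomes the equivalent mapping form.
dims = ('x', 'y')
chunks = (2, 3)
assert dict(zip(dims, chunks)) == {'x': 2, 'y': 3}
```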

xarray/core/dataset.py

Lines changed: 51 additions & 21 deletions
```diff
@@ -5,8 +5,10 @@
 from collections import OrderedDict, defaultdict
 from distutils.version import LooseVersion
 from numbers import Number
+from pathlib import Path
 from typing import (Any, Dict, Hashable, Iterable, Iterator, List,
-                    Mapping, Optional, Sequence, Set, Tuple, Union, cast)
+                    Mapping, MutableMapping, Optional, Sequence, Set, Tuple,
+                    Union, cast)
 
 import numpy as np
 import pandas as pd
@@ -45,8 +47,12 @@
     pass
 
 if TYPE_CHECKING:
-    from ..backends import AbstractDataStore
+    from ..backends import AbstractDataStore, ZarrStore
     from .dataarray import DataArray
+    try:
+        from dask.delayed import Delayed
+    except ImportError:
+        Delayed = None
 
 
 # list of attributes of pd.DatetimeIndex that are ndarrays of time info
```
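`ZarrStore` and `Delayed` exist only while the type checker runs, so the annotations added further down must quote them as forward references, and the `Delayed = None` fallback keeps checking working when dask is not installed. A minimal sketch of the pattern with a hypothetical stub function:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Never executed at runtime, so the import cost is zero and the
    # dependency stays optional.
    from xarray.backends import ZarrStore


def open_store(path: str) -> 'ZarrStore':
    # The quoted annotation is a forward reference: ZarrStore is
    # undefined at runtime, but mypy resolves the string lazily.
    raise NotImplementedError  # illustrative stub only
```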
```diff
@@ -1309,9 +1315,17 @@ def dump_to_store(self, store: 'AbstractDataStore', **kwargs) -> None:
         # with to_netcdf()
         dump_to_store(self, store, **kwargs)
 
-    def to_netcdf(self, path=None, mode='w', format=None, group=None,
-                  engine=None, encoding=None, unlimited_dims=None,
-                  compute=True):
+    def to_netcdf(
+        self,
+        path=None,
+        mode: str = 'w',
+        format: str = None,
+        group: str = None,
+        engine: str = None,
+        encoding: Mapping = None,
+        unlimited_dims: Iterable[Hashable] = None,
+        compute: bool = True,
+    ) -> Union[bytes, 'Delayed', None]:
         """Write dataset contents to a netCDF file.
 
         Parameters
```
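`Dataset.to_netcdf` does not expose `multifile`, so its return type narrows to `Union[bytes, 'Delayed', None]`. A hedged round-trip sketch for the `bytes` mode (assuming the scipy backend, which handles in-memory netCDF files):

```python
from io import BytesIO

import xarray as xr

ds = xr.Dataset({'a': ('x', [1, 2, 3])})

raw = ds.to_netcdf()  # no path given: the whole file comes back as bytes
restored = xr.open_dataset(BytesIO(raw))
assert (restored['a'].values == [1, 2, 3]).all()
```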
```diff
@@ -1366,7 +1380,7 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
             This allows using any compression plugin installed in the HDF5
             library, e.g. LZF.
 
-        unlimited_dims : sequence of str, optional
+        unlimited_dims : iterable of hashable, optional
             Dimension(s) that should be serialized as unlimited dimensions.
             By default, no dimensions are treated as unlimited dimensions.
             Note that unlimited_dims may also be set via
@@ -1383,9 +1397,17 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
                          unlimited_dims=unlimited_dims,
                          compute=compute)
 
-    def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None,
-                encoding=None, compute=True, consolidated=False,
-                append_dim=None):
+    def to_zarr(
+        self,
+        store: Union[MutableMapping, str, Path] = None,
+        mode: str = 'w-',
+        synchronizer=None,
+        group: str = None,
+        encoding: Mapping = None,
+        compute: bool = True,
+        consolidated: bool = False,
+        append_dim: Hashable = None
+    ) -> 'ZarrStore':
         """Write dataset contents to a zarr group.
 
         .. note:: Experimental
```
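The new `store` annotation documents that a `MutableMapping`, a `str` path, or a `pathlib.Path` are all accepted, matching the docstring fix in the next hunk. A hedged sketch (the zarr package must be installed; paths are illustrative):

```python
from pathlib import Path

import xarray as xr

ds = xr.Dataset({'a': ('x', [1, 2, 3])})

ds.to_zarr({})                               # any MutableMapping, e.g. a dict
ds.to_zarr('example.zarr', mode='w')         # a directory path as str...
ds.to_zarr(Path('example2.zarr'), mode='w')  # ...or as pathlib.Path
```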
```diff
@@ -1394,15 +1416,15 @@ def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None,
 
         Parameters
         ----------
-        store : MutableMapping or str, optional
+        store : MutableMapping, str or Path, optional
             Store or path to directory in file system.
         mode : {'w', 'w-', 'a'}
             Persistence mode: 'w' means create (overwrite if exists);
             'w-' means create (fail if exists);
             'a' means append (create if does not exist).
         synchronizer : object, optional
             Array synchronizer
-        group : str, obtional
+        group : str, optional
             Group path. (a.k.a. `path` in zarr terminology.)
         encoding : dict, optional
             Nested dictionary with variable names as keys and dictionaries of
@@ -1414,7 +1436,7 @@ def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None,
         consolidated: bool, optional
             If True, apply zarr's `consolidate_metadata` function to the store
             after writing.
-        append_dim: str, optional
+        append_dim: hashable, optional
             If mode='a', the dimension on which the data will be appended.
 
         References
@@ -1432,10 +1454,10 @@ def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None,
                        group=group, encoding=encoding, compute=compute,
                        consolidated=consolidated, append_dim=append_dim)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return formatting.dataset_repr(self)
 
-    def info(self, buf=None):
+    def info(self, buf=None) -> None:
         """
         Concise summary of a Dataset variables and attributes.
 
@@ -1448,7 +1470,6 @@ def info(self, buf=None):
         pandas.DataFrame.assign
         ncdump: netCDF's ncdump
         """
-
         if buf is None:  # pragma: no cover
             buf = sys.stdout
 
@@ -1473,11 +1494,11 @@ def info(self, buf=None):
         buf.write('\n'.join(lines))
 
     @property
-    def chunks(self):
+    def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]:
         """Block dimensions for this dataset's data or None if it's not a dask
         array.
         """
-        chunks = {}
+        chunks = {}  # type: Dict[Hashable, Tuple[int, ...]]
         for v in self.variables.values():
             if v.chunks is not None:
                 for dim, c in zip(v.dims, v.chunks):
```
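The `# type:` comment on `chunks` is the pre-PEP 526 spelling of a variable annotation, presumably kept because xarray still supported Python 3.5, which lacks the inline syntax. Both forms mean the same thing to a type checker:

```python
from typing import Dict, Hashable, Tuple

# Comment-style annotation: the only option on Python 3.5.
chunks = {}  # type: Dict[Hashable, Tuple[int, ...]]

# Inline PEP 526 annotation, available from Python 3.6 onwards.
chunks_new: Dict[Hashable, Tuple[int, ...]] = {}
```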
```diff
@@ -1486,8 +1507,17 @@ def chunks(self):
                     chunks[dim] = c
         return Frozen(SortedKeysDict(chunks))
 
-    def chunk(self, chunks=None, name_prefix='xarray-', token=None,
-              lock=False):
+    def chunk(
+        self,
+        chunks: Union[
+            None,
+            Number,
+            Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]]
+        ] = None,
+        name_prefix: str = 'xarray-',
+        token: str = None,
+        lock: bool = False
+    ) -> 'Dataset':
         """Coerce all arrays in this dataset into dask arrays with the given
         chunks.
 
@@ -1500,7 +1530,7 @@ def chunk(self, chunks=None, name_prefix='xarray-', token=None,
 
         Parameters
         ----------
-        chunks : int or dict, optional
+        chunks : int or mapping, optional
             Chunk sizes along each dimension, e.g., ``5`` or
             ``{'x': 5, 'y': 5}``.
         name_prefix : str, optional
@@ -1526,7 +1556,7 @@ def chunk(self, chunks=None, name_prefix='xarray-', token=None,
             chunks = dict.fromkeys(self.dims, chunks)
 
         if chunks is not None:
-            bad_dims = [d for d in chunks if d not in self.dims]
+            bad_dims = chunks.keys() - self.dims.keys()
             if bad_dims:
                 raise ValueError('some chunks keys are not dimensions on this '
                                  'object: %s' % bad_dims)
```
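`dict.keys()` returns a set-like view, so the rewritten line computes the unknown dimension names as a single set difference rather than a list comprehension; the result is a `set`, which still works for the truthiness test and the `%s` formatting that follow. In isolation:

```python
chunks = {'x': 2, 'y': 3, 'typo': 5}
dims = {'x': 4, 'y': 6}

# Set algebra on the key views: "keys that are not dimensions".
bad_dims = chunks.keys() - dims.keys()
assert bad_dims == {'typo'}
```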
