Skip to content

Commit 6c5d272

Browse files
committed
Merge pull request #251 from xray/encoding-improvements
Encoding improvements
2 parents 8839723 + c748e0f commit 6c5d272

File tree

5 files changed

+132
-118
lines changed

5 files changed

+132
-118
lines changed

xray/backends/netCDF4_.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,24 +86,19 @@ class NetCDF4DataStore(AbstractWritableDataStore):
8686
This store supports NetCDF3, NetCDF4 and OpenDAP datasets.
8787
"""
8888
def __init__(self, filename, mode='r', clobber=True, diskless=False,
89-
persist=False, format='NETCDF4', group=None,
90-
*args, **kwdargs):
89+
persist=False, format='NETCDF4', group=None):
9190
import netCDF4 as nc4
9291
ds = nc4.Dataset(filename, mode=mode, clobber=clobber,
9392
diskless=diskless, persist=persist,
9493
format=format)
9594
self.ds = _nc4_group(ds, group)
9695
self.format = format
9796
self._filename = filename
98-
self._encoder_args = args
99-
self._encoder_kwdargs = kwdargs
10097

10198
def store(self, variables, attributes):
10299
# All NetCDF files get CF encoded by default, without this attempting
103100
# to write times, for example, would fail.
104-
cf_variables, cf_attrs = cf_encoder(variables, attributes,
105-
*self._encoder_args,
106-
**self._encoder_kwdargs)
101+
cf_variables, cf_attrs = cf_encoder(variables, attributes)
107102
AbstractWritableDataStore.store(self, cf_variables, cf_attrs)
108103

109104
def open_store_variable(self, var):

xray/backends/scipy_.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ class ScipyDataStore(AbstractWritableDataStore):
3434
3535
It only supports the NetCDF3 file-format.
3636
"""
37-
def __init__(self, filename_or_obj, mode='r', mmap=None,
38-
version=1, *args, **kwdargs):
37+
def __init__(self, filename_or_obj, mode='r', mmap=None, version=1):
3938
import scipy
4039
if mode != 'r' and scipy.__version__ < '0.13':
4140
warnings.warn('scipy %s detected; '
@@ -53,15 +52,11 @@ def __init__(self, filename_or_obj, mode='r', mmap=None,
5352
filename_or_obj = BytesIO(filename_or_obj)
5453
self.ds = scipy.io.netcdf.netcdf_file(
5554
filename_or_obj, mode=mode, mmap=mmap, version=version)
56-
self._encoder_args = args
57-
self._encoder_kwdargs = kwdargs
5855

5956
def store(self, variables, attributes):
6057
# All Scipy objects get CF encoded by default, without this attempting
6158
# to write times, for example, would fail.
62-
cf_variables, cf_attrs = cf_encoder(variables, attributes,
63-
*self._encoder_args,
64-
**self._encoder_kwdargs)
59+
cf_variables, cf_attrs = cf_encoder(variables, attributes)
6560
AbstractWritableDataStore.store(self, cf_variables, cf_attrs)
6661

6762
def open_store_variable(self, var):

xray/conventions.py

Lines changed: 112 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import numpy as np
23
import pandas as pd
34
import warnings
@@ -372,6 +373,73 @@ def pop_to(source, dest, key, default=None):
372373
return value
373374

374375

376+
def _var_as_tuple(var):
377+
return var.dims, var.values, var.attrs.copy(), var.encoding.copy()
378+
379+
380+
def maybe_encode_datetime(var):
381+
if (np.issubdtype(var.dtype, np.datetime64)
382+
or (var.dtype.kind == 'O'
383+
and isinstance(var.values.flat[0], datetime))):
384+
385+
dims, values, attrs, encoding = _var_as_tuple(var)
386+
if 'units' in attrs or 'calendar' in attrs:
387+
raise ValueError(
388+
"Failed hard to prevent overwriting 'units' or 'calendar'")
389+
390+
(values, units, calendar) = encode_cf_datetime(
391+
values, encoding.pop('units', None), encoding.pop('calendar', None))
392+
attrs['units'] = units
393+
attrs['calendar'] = calendar
394+
var = Variable(dims, values, attrs, encoding)
395+
return var
396+
397+
398+
def maybe_encode_offset_and_scale(var, needs_copy=True):
399+
if any(k in var.encoding for k in ['add_offset', 'scale_factor']):
400+
dims, values, attrs, encoding = _var_as_tuple(var)
401+
values = np.array(values, dtype=float, copy=needs_copy)
402+
needs_copy = False
403+
if 'add_offset' in encoding:
404+
values -= pop_to(encoding, attrs, 'add_offset')
405+
if 'scale_factor' in encoding:
406+
values /= pop_to(encoding, attrs, 'scale_factor')
407+
var = Variable(dims, values, attrs, encoding)
408+
return var, needs_copy
409+
410+
411+
def maybe_encode_fill_value(var, needs_copy=True):
412+
# replace NaN with the fill value
413+
if '_FillValue' in var.encoding:
414+
dims, values, attrs, encoding = _var_as_tuple(var)
415+
fill_value = pop_to(encoding, attrs, '_FillValue')
416+
if not pd.isnull(fill_value):
417+
missing = pd.isnull(values)
418+
if missing.any():
419+
if needs_copy:
420+
values = values.copy()
421+
needs_copy = False
422+
values[missing] = fill_value
423+
var = Variable(dims, values, attrs, encoding)
424+
return var, needs_copy
425+
426+
427+
def maybe_encode_dtype(var, needs_copy=True):
428+
if 'dtype' in var.encoding:
429+
dims, values, attrs, encoding = _var_as_tuple(var)
430+
dtype = np.dtype(encoding.pop('dtype'))
431+
if dtype.kind != 'O':
432+
if np.issubdtype(dtype, int):
433+
out = np.empty_like(values) if needs_copy else values
434+
np.around(values, out=out)
435+
if dtype == 'S1' and values.dtype != 'S1':
436+
values = string_to_char(np.asarray(values, 'S'))
437+
dims = dims + ('string%s' % values.shape[-1],)
438+
values = np.asarray(values, dtype=dtype)
439+
var = Variable(dims, values, attrs, encoding)
440+
return var
441+
442+
375443
def _infer_dtype(array):
376444
"""Given an object array with no missing values, infer its dtype from its
377445
first element
@@ -390,7 +458,36 @@ def _infer_dtype(array):
390458
return dtype
391459

392460

393-
def encode_cf_variable(var):
461+
def ensure_dtype_not_object(var):
462+
# TODO: move this from conventions to backends? (it's not CF related)
463+
if var.dtype.kind == 'O':
464+
dims, values, attrs, encoding = _var_as_tuple(var)
465+
missing = pd.isnull(values)
466+
if missing.any():
467+
non_missing_values = values[~missing]
468+
inferred_dtype = _infer_dtype(non_missing_values)
469+
470+
if inferred_dtype.kind in ['S', 'U']:
471+
# There is no safe bit-pattern for NA in typical binary string
472+
# formats, we so can't set a fill_value. Unfortunately, this
473+
# means we won't be able to restore string arrays with missing
474+
# values.
475+
fill_value = ''
476+
else:
477+
# insist on using float for numeric values
478+
if not np.issubdtype(inferred_dtype, float):
479+
inferred_dtype = np.dtype(float)
480+
fill_value = np.nan
481+
482+
values = np.array(values, dtype=inferred_dtype, copy=True)
483+
values[missing] = fill_value
484+
else:
485+
values = np.asarray(values, dtype=_infer_dtype(values))
486+
var = Variable(dims, values, attrs, encoding)
487+
return var
488+
489+
490+
def encode_cf_variable(var, needs_copy=True):
394491
"""
395492
Converts an Variable into an Variable which follows some
396493
of the CF conventions:
@@ -410,86 +507,12 @@ def encode_cf_variable(var):
410507
out : xray.Variable
411508
A variable which has been encoded as described above.
412509
"""
413-
dimensions = var.dims
414-
data = var.values
415-
attributes = var.attrs.copy()
416-
encoding = var.encoding.copy()
417-
418-
# convert datetimes into numbers
419-
if (np.issubdtype(data.dtype, np.datetime64)
420-
or (data.dtype.kind == 'O'
421-
and isinstance(data.reshape(-1)[0], datetime))):
422-
if 'units' in attributes or 'calendar' in attributes:
423-
raise ValueError(
424-
"Failed hard to prevent overwriting 'units' or 'calendar'")
425-
(data, units, calendar) = encode_cf_datetime(
426-
data, encoding.pop('units', None), encoding.pop('calendar', None))
427-
attributes['units'] = units
428-
attributes['calendar'] = calendar
429-
430-
# unscale/mask
431-
if any(k in encoding for k in ['add_offset', 'scale_factor']):
432-
data = np.array(data, dtype=float, copy=True)
433-
if 'add_offset' in encoding:
434-
data -= pop_to(encoding, attributes, 'add_offset')
435-
if 'scale_factor' in encoding:
436-
data /= pop_to(encoding, attributes, 'scale_factor')
437-
438-
# replace NaN with the fill value
439-
if '_FillValue' in encoding:
440-
fill_value = pop_to(encoding, attributes, '_FillValue')
441-
if not pd.isnull(fill_value):
442-
missing = pd.isnull(data)
443-
if missing.any():
444-
data = data.copy()
445-
data[missing] = fill_value
446-
447-
# replace NaN with the missing_value
448-
if 'missing_value' in encoding:
449-
missing_value = pop_to(encoding, attributes, 'missing_value')
450-
if not pd.isnull(missing_value):
451-
missing = pd.isnull(data)
452-
if missing.any():
453-
data = data.copy()
454-
data[missing] = missing_value
455-
456-
# cast to encoded dtype
457-
if 'dtype' in encoding:
458-
dtype = np.dtype(encoding.pop('dtype'))
459-
if dtype.kind != 'O':
460-
if np.issubdtype(dtype, int):
461-
data = data.round()
462-
if dtype == 'S1' and data.dtype != 'S1':
463-
data = string_to_char(np.asarray(data, 'S'))
464-
dimensions = dimensions + ('string%s' % data.shape[-1],)
465-
data = np.asarray(data, dtype=dtype)
466-
467-
# infer a valid dtype if necessary
468-
# TODO: move this from conventions to backends (it's not CF related)
469-
if data.dtype.kind == 'O':
470-
missing = pd.isnull(data)
471-
if missing.any():
472-
non_missing_data = data[~missing]
473-
inferred_dtype = _infer_dtype(non_missing_data)
474-
475-
if inferred_dtype.kind in ['S', 'U']:
476-
# There is no safe bit-pattern for NA in typical binary string
477-
# formats, we so can't set a fill_value. Unfortunately, this
478-
# means we won't be able to restore string arrays with missing
479-
# values.
480-
fill_value = ''
481-
else:
482-
# insist on using float for numeric data
483-
if not np.issubdtype(inferred_dtype, float):
484-
inferred_dtype = np.dtype(float)
485-
fill_value = np.nan
486-
487-
data = np.array(data, dtype=inferred_dtype, copy=True)
488-
data[missing] = fill_value
489-
else:
490-
data = np.asarray(data, dtype=_infer_dtype(data))
491-
492-
return Variable(dimensions, data, attributes, encoding=encoding)
510+
var = maybe_encode_datetime(var)
511+
var, needs_copy = maybe_encode_offset_and_scale(var, needs_copy)
512+
var, needs_copy = maybe_encode_fill_value(var, needs_copy)
513+
var = maybe_encode_dtype(var, needs_copy)
514+
var = ensure_dtype_not_object(var)
515+
return var
493516

494517

495518
def decode_cf_variable(var, concat_characters=True, mask_and_scale=True,
@@ -539,15 +562,15 @@ def decode_cf_variable(var, concat_characters=True, mask_and_scale=True,
539562
data = CharToStringArray(data)
540563

541564
if mask_and_scale:
542-
# missing_value is deprecated, but we still want to support it.
543-
missing_value = pop_to(attributes, encoding, 'missing_value')
565+
if 'missing_value' in attributes:
566+
# missing_value is deprecated, but we still want to support it as
567+
# an alias for _FillValue.
568+
assert ('_FillValue' not in attributes
569+
or utils.equivalent(attributes['_FillValue'],
570+
attributes['missing_value']))
571+
attributes['_FillValue'] = attributes.pop('missing_value')
572+
544573
fill_value = pop_to(attributes, encoding, '_FillValue')
545-
# if missing_value is given but not fill_value we use missing_value
546-
if fill_value is None and missing_value is not None:
547-
fill_value = missing_value
548-
# if both were given we make sure they are the same.
549-
if fill_value is not None and missing_value is not None:
550-
assert fill_value == missing_value
551574
scale_factor = pop_to(attributes, encoding, 'scale_factor')
552575
add_offset = pop_to(attributes, encoding, 'add_offset')
553576
if ((fill_value is not None and not pd.isnull(fill_value))

xray/core/dataset.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,14 @@ def open_dataset(nc, decode_cf=True, mask_and_scale=True, decode_times=True,
8484
# If nc is a file-like object we read it using
8585
# the scipy.io.netcdf package
8686
store = backends.ScipyDataStore(nc, *args, **kwargs)
87-
decoder = conventions.cf_decoder if decode_cf else None
88-
return Dataset.load_store(store, decoder=decoder,
89-
mask_and_scale=mask_and_scale,
90-
decode_times=decode_times,
91-
concat_characters=concat_characters)
87+
if decode_cf:
88+
decoder = functools.partial(conventions.cf_decoder,
89+
mask_and_scale=mask_and_scale,
90+
decode_times=decode_times,
91+
concat_characters=concat_characters)
92+
else:
93+
decoder = None
94+
return Dataset.load_store(store, decoder=decoder)
9295

9396

9497
# list of attributes of pd.DatetimeIndex that are ndarrays of time info
@@ -399,14 +402,13 @@ def _set_init_vars_and_dims(self, vars, coords):
399402
check_coord_names=False)
400403

401404
@classmethod
402-
def load_store(cls, store, decoder=None, *args, **kwdargs):
405+
def load_store(cls, store, decoder=None):
403406
"""Create a new dataset from the contents of a backends.*DataStore
404407
object
405408
"""
406409
variables, attributes = store.load()
407410
if decoder:
408-
variables, attributes = decoder(variables, attributes,
409-
*args, **kwdargs)
411+
variables, attributes = decoder(variables, attributes)
410412
obj = cls(variables, attrs=attributes)
411413
obj._file_obj = store
412414
return obj
@@ -785,13 +787,11 @@ def reset_coords(self, names=None, drop=False, inplace=False):
785787
del obj._arrays[name]
786788
return obj
787789

788-
def dump_to_store(self, store, encoder=None,
789-
*args, **kwdargs):
790+
def dump_to_store(self, store, encoder=None):
790791
"""Store dataset contents to a backends.*DataStore object."""
791792
variables, attributes = self, self.attrs
792793
if encoder:
793-
variables, attributes = encoder(variables, attributes,
794-
*args, **kwdargs)
794+
variables, attributes = encoder(variables, attributes)
795795
store.store(variables, attributes)
796796
store.sync()
797797

xray/test/test_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import numpy as np
99
import pandas as pd
1010

11-
from xray import align, concat, backends, Dataset, DataArray, Variable
11+
from xray import (align, concat, conventions, backends, Dataset, DataArray,
12+
Variable)
1213
from xray.core import indexing, utils
1314
from xray.core.pycompat import iteritems, OrderedDict
1415

@@ -1020,8 +1021,8 @@ def test_lazy_load(self):
10201021
store = InaccessibleVariableDataStore()
10211022
create_test_data().dump_to_store(store)
10221023

1023-
for decode_cf in [False, True]:
1024-
ds = Dataset.load_store(store, decode_cf=decode_cf)
1024+
for decoder in [None, conventions.cf_decoder]:
1025+
ds = Dataset.load_store(store, decoder=decoder)
10251026
with self.assertRaises(UnexpectedDataAccess):
10261027
ds.load_data()
10271028
with self.assertRaises(UnexpectedDataAccess):

0 commit comments

Comments
 (0)