Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,9 +1315,11 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:

units = encoding.pop("units", None)
calendar = encoding.pop("calendar", None)
dtype = encoding.get("dtype", None)
dtype = encoding.pop("dtype", None)
(data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)

# if no dtype is provided, preserve data.dtype in encoding
if dtype is None:
safe_setitem(encoding, "dtype", data.dtype, name=name)
safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)

Expand Down Expand Up @@ -1369,8 +1371,16 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)

# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
dtype = None
if "add_offset" in encoding or "scale_factor" in encoding:
encoding.pop("dtype")
dtype = data.dtype if data.dtype.kind == "f" else "float64"

data, units = encode_cf_timedelta(
data, encoding.pop("units", None), encoding.get("dtype", None)
data, encoding.pop("units", None), encoding.get("dtype", dtype)
)
safe_setitem(attrs, "units", units, name=name)

Expand Down
63 changes: 55 additions & 8 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from xarray.coding.times import CFDatetimeCoder, CFTimedeltaCoder  # FIXME(review): was importing from the stale `build.lib` artifact; also note this risks a circular import — consider a `TYPE_CHECKING` guard since the names are only used in annotations

from xarray.core import dtypes, duck_array_ops, indexing
from xarray.core.variable import Variable
Expand Down Expand Up @@ -234,6 +235,8 @@ def _apply_mask(

def _is_time_like(units):
# test for time-like
# return "datetime" for datetime-like
# return "timedelta" for timedelta-like
if units is None:
return False
time_strings = [
Expand All @@ -255,9 +258,9 @@ def _is_time_like(units):
_unpack_netcdf_time_units(units)
except ValueError:
return False
return True
return "datetime"
else:
return any(tstr == units for tstr in time_strings)
return "timedelta" if any(tstr == units for tstr in time_strings) else False


def _check_fill_values(attrs, name, dtype):
Expand Down Expand Up @@ -367,6 +370,14 @@ def _encode_unsigned_fill_value(
class CFMaskCoder(VariableCoder):
"""Mask or unmask fill values according to CF conventions."""

def __init__(
    self,
    decode_times: bool | CFDatetimeCoder = False,
    decode_timedelta: bool | CFTimedeltaCoder = False,
) -> None:
    """Configure the mask coder.

    Parameters
    ----------
    decode_times : bool or CFDatetimeCoder, default: False
        Whether datetime-like values will be decoded further downstream;
        used by ``decode`` to pick a NaT-compatible (int64) fill value
        for integer time-like data.
    decode_timedelta : bool or CFTimedeltaCoder, default: False
        Whether timedelta-like values will be decoded further downstream;
        used the same way as ``decode_times``.
    """
    self.decode_timedelta = decode_timedelta
    self.decode_times = decode_times

def encode(self, variable: Variable, name: T_Name = None):
dims, data, attrs, encoding = unpack_for_encoding(variable)

Expand All @@ -393,33 +404,50 @@ def encode(self, variable: Variable, name: T_Name = None):

if fv_exists:
# Ensure _FillValue is cast to same dtype as data's
# but not for packed data
encoding["_FillValue"] = (
_encode_unsigned_fill_value(name, fv, dtype)
if has_unsigned
else dtype.type(fv)
if "add_offset" not in encoding and "scale_factor" not in encoding
else fv
)
fill_value = pop_to(encoding, attrs, "_FillValue", name=name)

if mv_exists:
# try to use _FillValue, if it exists to align both values
# or use missing_value and ensure it's cast to same dtype as data's
# but not for packed data
encoding["missing_value"] = attrs.get(
"_FillValue",
(
_encode_unsigned_fill_value(name, mv, dtype)
if has_unsigned
else dtype.type(mv)
if "add_offset" not in encoding and "scale_factor" not in encoding
else mv
),
)
fill_value = pop_to(encoding, attrs, "missing_value", name=name)

# apply fillna
if fill_value is not None and not pd.isnull(fill_value):
# special case DateTime to properly handle NaT
if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
data = duck_array_ops.where(
data != np.iinfo(np.int64).min, data, fill_value
)
if _is_time_like(attrs.get("units")):
if data.dtype.kind in "iu":
data = duck_array_ops.where(
data != np.iinfo(np.int64).min, data, fill_value
)
else:
# if we have float data (data was packed prior masking)
# we just fillna
data = duck_array_ops.fillna(data, fill_value)
# but if the fill_value is of integer type
# we need to round and cast
if np.array(fill_value).dtype.kind in "iu":
data = duck_array_ops.astype(
duck_array_ops.around(data), type(fill_value)
)
else:
data = duck_array_ops.fillna(data, fill_value)

Expand Down Expand Up @@ -458,9 +486,15 @@ def decode(self, variable: Variable, name: T_Name = None):

if encoded_fill_values:
# special case DateTime to properly handle NaT
# we need to check if time-like will be decoded or not
# in further processing
dtype: np.typing.DTypeLike
decoded_fill_value: Any
if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
is_time_like = _is_time_like(attrs.get("units"))
if (
(is_time_like == "datetime" and self.decode_times)
or (is_time_like == "timedelta" and self.decode_timedelta)
) and data.dtype.kind in "iu":
dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min
else:
if "scale_factor" not in attrs and "add_offset" not in attrs:
Expand Down Expand Up @@ -549,6 +583,14 @@ class CFScaleOffsetCoder(VariableCoder):
decode_values = encoded_values * scale_factor + add_offset
"""

def __init__(
    self,
    decode_times: bool | CFDatetimeCoder = False,
    decode_timedelta: bool | CFTimedeltaCoder = False,
) -> None:
    """Configure the scale/offset coder.

    Parameters
    ----------
    decode_times : bool or CFDatetimeCoder, default: False
        Whether datetime-like values will be decoded further downstream;
        ``decode`` consults this before choosing a float dtype for
        time-like units.
    decode_timedelta : bool or CFTimedeltaCoder, default: False
        Whether timedelta-like values will be decoded further downstream;
        consulted the same way as ``decode_times``.
    """
    self.decode_timedelta = decode_timedelta
    self.decode_times = decode_times

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
dims, data, attrs, encoding = unpack_for_encoding(variable)

Expand Down Expand Up @@ -580,8 +622,13 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
add_offset = np.asarray(add_offset).item()
# if we have a _FillValue/masked_value we already have the wanted
# floating point dtype here (via CFMaskCoder), so no check is necessary
# only check in other cases
# only check in other cases and for time-like
dtype = data.dtype
is_time_like = _is_time_like(attrs.get("units"))
if (is_time_like == "datetime" and self.decode_times) or (
is_time_like == "timedelta" and self.decode_timedelta
):
dtype = _choose_float_dtype(dtype, encoding)
if "_FillValue" not in encoding and "missing_value" not in encoding:
dtype = _choose_float_dtype(dtype, encoding)

Expand Down
8 changes: 6 additions & 2 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,12 @@ def decode_cf_variable(

if mask_and_scale:
for coder in [
variables.CFMaskCoder(),
variables.CFScaleOffsetCoder(),
variables.CFMaskCoder(
decode_times=decode_times, decode_timedelta=decode_timedelta
),
variables.CFScaleOffsetCoder(
decode_times=decode_times, decode_timedelta=decode_timedelta
),
]:
var = coder.decode(var, name=name)

Expand Down
36 changes: 28 additions & 8 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,9 @@ def test_cf_timedelta_2d() -> None:


@pytest.mark.parametrize("encoding_unit", FREQUENCIES_TO_ENCODING_UNITS.values())
def test_decode_cf_timedelta_time_unit(time_unit, encoding_unit) -> None:
def test_decode_cf_timedelta_time_unit(
time_unit: PDDatetimeUnitOptions, encoding_unit
) -> None:
encoded = 1
encoding_unit_as_numpy = _netcdf_to_numpy_timeunit(encoding_unit)
if np.timedelta64(1, time_unit) > np.timedelta64(1, encoding_unit_as_numpy):
Expand All @@ -652,7 +654,9 @@ def test_decode_cf_timedelta_time_unit(time_unit, encoding_unit) -> None:
assert result.dtype == expected.dtype


def test_decode_cf_timedelta_time_unit_out_of_bounds(time_unit) -> None:
def test_decode_cf_timedelta_time_unit_out_of_bounds(
time_unit: PDDatetimeUnitOptions,
) -> None:
# Define a scale factor that will guarantee overflow with the given
# time_unit.
scale_factor = np.timedelta64(1, time_unit) // np.timedelta64(1, "ns")
Expand All @@ -661,7 +665,7 @@ def test_decode_cf_timedelta_time_unit_out_of_bounds(time_unit) -> None:
decode_cf_timedelta(encoded, "days", time_unit)


def test_cf_timedelta_roundtrip_large_value(time_unit) -> None:
def test_cf_timedelta_roundtrip_large_value(time_unit: PDDatetimeUnitOptions) -> None:
value = np.timedelta64(np.iinfo(np.int64).max, time_unit)
encoded, units = encode_cf_timedelta(value)
decoded = decode_cf_timedelta(encoded, units, time_unit=time_unit)
Expand Down Expand Up @@ -983,7 +987,7 @@ def test_use_cftime_default_standard_calendar_out_of_range(
@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS)
@pytest.mark.parametrize("units_year", [1500, 2000, 2500])
def test_use_cftime_default_non_standard_calendar(
calendar, units_year, time_unit
calendar, units_year, time_unit: PDDatetimeUnitOptions
) -> None:
from cftime import num2date

Expand Down Expand Up @@ -1429,9 +1433,9 @@ def test_roundtrip_datetime64_nanosecond_precision_warning(
) -> None:
# test warning if times can't be serialized faithfully
times = [
np.datetime64("1970-01-01T00:01:00", "ns"),
np.datetime64("NaT"),
np.datetime64("1970-01-02T00:01:00", "ns"),
np.datetime64("1970-01-01T00:01:00", time_unit),
np.datetime64("NaT", time_unit),
np.datetime64("1970-01-02T00:01:00", time_unit),
]
units = "days since 1970-01-10T01:01:00"
needed_units = "hours"
Expand Down Expand Up @@ -1620,7 +1624,9 @@ def test_roundtrip_float_times(fill_value, times, units, encoded_values) -> None
_ENCODE_DATETIME64_VIA_DASK_TESTS.values(),
ids=_ENCODE_DATETIME64_VIA_DASK_TESTS.keys(),
)
def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype, time_unit) -> None:
def test_encode_cf_datetime_datetime64_via_dask(
freq, units, dtype, time_unit: PDDatetimeUnitOptions
) -> None:
import dask.array

times_pd = pd.date_range(start="1700", freq=freq, periods=3, unit=time_unit)
Expand Down Expand Up @@ -1901,3 +1907,17 @@ def test_lazy_decode_timedelta_error() -> None:
)
with pytest.raises(OutOfBoundsTimedelta, match="overflow"):
decoded.load()


@pytest.mark.parametrize("decode_timedelta", [True, False])
@pytest.mark.parametrize("mask_and_scale", [True, False])
def test_decode_timedelta_mask_and_scale(
    decode_timedelta: bool, mask_and_scale: bool
) -> None:
    # Round-trip a timedelta-like variable that is simultaneously packed
    # (add_offset) and masked (_FillValue): decoding followed by re-encoding
    # must reproduce the original int16 data and attributes for every
    # combination of the mask_and_scale / decode_timedelta switches.
    attrs = {"units": "days", "_FillValue": np.int16(-1), "add_offset": 100.0}
    encoded = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=attrs)
    decoded = conventions.decode_cf_variable(
        "foo", encoded, mask_and_scale=mask_and_scale, decode_timedelta=decode_timedelta
    )
    # re-encode and compare against the originally encoded variable
    result = conventions.encode_cf_variable(decoded, name="foo")
    assert_equal(encoded, result)
5 changes: 1 addition & 4 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,16 +511,13 @@ def test_decode_dask_times(self) -> None:

@pytest.mark.parametrize("time_unit", ["s", "ms", "us", "ns"])
def test_decode_cf_time_kwargs(self, time_unit) -> None:
# todo: if we set timedelta attrs "units": "days"
# this errors on the last decode_cf wrt to the lazy_elemwise_func
# trying to convert twice
ds = Dataset.from_dict(
{
"coords": {
"timedelta": {
"data": np.array([1, 2, 3], dtype="int64"),
"dims": "timedelta",
"attrs": {"units": "seconds"},
"attrs": {"units": "days"},
},
"time": {
"data": np.array([1, 2, 3], dtype="int64"),
Expand Down
Loading