Skip to content

Commit dfb968f

Browse files
committed
Add first, last property test
Closes #29
1 parent f38dd19 commit dfb968f

File tree

4 files changed

+160
-33
lines changed

4 files changed

+160
-33
lines changed

flox/aggregations.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def generic_aggregate(
6969
if func == "identity":
7070
return array
7171

72+
if func in ["nanfirst", "nanlast"] and array.dtype.kind in "US":
73+
func = func[3:]
74+
7275
if engine == "flox":
7376
try:
7477
method = getattr(aggregate_flox, func)
@@ -144,6 +147,8 @@ def _maybe_promote_int(dtype) -> np.dtype:
144147

145148
def _get_fill_value(dtype, fill_value):
146149
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150+
if fill_value in [None, dtypes.NA] and dtype.kind in "US":
151+
return ""
147152
if fill_value == dtypes.INF or fill_value is None:
148153
return dtypes.get_pos_infinity(dtype, max_for_int=True)
149154
if fill_value == dtypes.NINF:
@@ -516,10 +521,10 @@ def _pick_second(*x):
516521
final_dtype=np.intp,
517522
)
518523

519-
first = Aggregation("first", chunk=None, combine=None, fill_value=0)
520-
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
521-
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
522-
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
524+
first = Aggregation("first", chunk=None, combine=None, fill_value=None)
525+
last = Aggregation("last", chunk=None, combine=None, fill_value=None)
526+
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA)
527+
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA)
523528

524529
all_ = Aggregation(
525530
"all",
@@ -808,7 +813,7 @@ def _initialize_aggregation(
808813
)
809814

810815
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
811-
if agg.name not in ["min", "max", "nanmin", "nanmax"]:
816+
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
812817
final_dtype = _maybe_promote_int(final_dtype)
813818
agg.dtype = {
814819
"user": dtype, # Save to automatically choose an engine

flox/xrdtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def get_pos_infinity(dtype, max_for_int=False):
100100

101101
if issubclass(dtype.type, np.integer):
102102
if max_for_int:
103+
dtype = np.int64 if dtype.kind in "Mm" else dtype
103104
return np.iinfo(dtype).max
104105
else:
105106
return np.inf

tests/__init__.py

Lines changed: 105 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,32 @@ def assert_equal(a, b, tolerance=None):
101101
else:
102102
tolerance = {}
103103

104-
if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
105-
# sometimes it's nice to see values and shapes
106-
# rather than being dropped into some file in dask
107-
np.testing.assert_allclose(a, b, **tolerance)
108-
# does some validation of the dask graph
109-
da.utils.assert_eq(a, b, equal_nan=True)
104+
# Always run the numpy comparison first, so that we get nice error messages with dask.
105+
# sometimes it's nice to see values and shapes
106+
# rather than being dropped into some file in dask
107+
if a.dtype != b.dtype:
108+
raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
109+
110+
if has_dask:
111+
a_eager = a.compute() if isinstance(a, dask_array_type) else a
112+
b_eager = b.compute() if isinstance(b, dask_array_type) else b
113+
114+
if a.dtype.kind in "SUMm":
115+
np.testing.assert_equal(a_eager, b_eager)
110116
else:
111-
if a.dtype != b.dtype:
112-
raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
117+
np.testing.assert_allclose(a_eager, b_eager, equal_nan=True, **tolerance)
113118

114-
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
119+
if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
120+
# does some validation of the dask graph
121+
dask_assert_eq(a, b, equal_nan=True)
115122

116123

117124
def assert_equal_tuple(a, b):
118125
"""assert_equal for .blocks indexing tuples"""
119126
assert len(a) == len(b)
120127

121128
for a_, b_ in zip(a, b):
122-
assert type(a_) == type(b_)
129+
assert type(a_) is type(b_)
123130
if isinstance(a_, np.ndarray):
124131
np.testing.assert_array_equal(a_, b_)
125132
else:
@@ -156,3 +163,91 @@ def assert_equal_tuple(a, b):
156163
"quantile",
157164
"nanquantile",
158165
) + tuple(SCIPY_STATS_FUNCS)
166+
167+
168+
def dask_assert_eq(
    a,
    b,
    check_shape=True,
    check_graph=True,
    check_meta=True,
    check_chunks=True,
    check_ndim=True,
    check_type=True,
    check_dtype=True,
    equal_nan=True,
    scheduler="sync",
    **kwargs,
):
    """dask.array.utils.assert_eq modified to skip value checks. Their code is buggy for some dtypes.
    We just check values through numpy and care about validating the graph in this function."""
    # NOTE(review): check_dtype, equal_nan and **kwargs are accepted but never used in this
    # body — presumably kept so the signature mirrors dask.array.utils.assert_eq; confirm.
    # _get_dt_meta_computed is a private dask helper and may change between dask versions.
    from dask.array.utils import _get_dt_meta_computed

    # Keep references to the raw inputs before they are normalized/computed below,
    # so the meta checks can compare "before compute" against "after compute".
    a_original = a
    b_original = b

    # Promote plain Python scalars/lists to ndarrays so `.shape` / `.item()` work below.
    if isinstance(a, (list, int, float)):
        a = np.array(a)
    if isinstance(b, (list, int, float)):
        b = np.array(b)

    # _get_dt_meta_computed performs the shape/graph/chunks/ndim validation and
    # returns the computed array plus its dtype, meta, and computed counterpart.
    a, adt, a_meta, a_computed = _get_dt_meta_computed(
        a,
        check_shape=check_shape,
        check_graph=check_graph,
        check_chunks=check_chunks,
        check_ndim=check_ndim,
        scheduler=scheduler,
    )
    b, bdt, b_meta, b_computed = _get_dt_meta_computed(
        b,
        check_shape=check_shape,
        check_graph=check_graph,
        check_chunks=check_chunks,
        check_ndim=check_ndim,
        scheduler=scheduler,
    )

    if check_type:
        # Compare scalars as scalars: 0-d arrays are unwrapped with .item().
        _a = a if a.shape else a.item()
        _b = b if b.shape else b.item()
        assert type(_a) is type(_b), f"a and b have different types (a: {type(_a)}, b: {type(_b)})"
    if check_meta:
        # Recursively validate the dask `_meta` arrays against each other.
        if hasattr(a, "_meta") and hasattr(b, "_meta"):
            dask_assert_eq(a._meta, b._meta)
        if hasattr(a_original, "_meta"):
            msg = (
                f"compute()-ing 'a' changes its number of dimensions "
                f"(before: {a_original._meta.ndim}, after: {a.ndim})"
            )
            assert a_original._meta.ndim == a.ndim, msg
            if a_meta is not None:
                msg = (
                    f"compute()-ing 'a' changes its type "
                    f"(before: {type(a_original._meta)}, after: {type(a_meta)})"
                )
                assert type(a_original._meta) is type(a_meta), msg
                if not (np.isscalar(a_meta) or np.isscalar(a_computed)):
                    msg = (
                        f"compute()-ing 'a' results in a different type than implied by its metadata "
                        f"(meta: {type(a_meta)}, computed: {type(a_computed)})"
                    )
                    assert type(a_meta) is type(a_computed), msg
        if hasattr(b_original, "_meta"):
            msg = (
                f"compute()-ing 'b' changes its number of dimensions "
                f"(before: {b_original._meta.ndim}, after: {b.ndim})"
            )
            assert b_original._meta.ndim == b.ndim, msg
            if b_meta is not None:
                msg = (
                    f"compute()-ing 'b' changes its type "
                    f"(before: {type(b_original._meta)}, after: {type(b_meta)})"
                )
                assert type(b_original._meta) is type(b_meta), msg
                if not (np.isscalar(b_meta) or np.isscalar(b_computed)):
                    msg = (
                        f"compute()-ing 'b' results in a different type than implied by its metadata "
                        f"(meta: {type(b_meta)}, computed: {type(b_computed)})"
                    )
                    assert type(b_meta) is type(b_computed), msg

tests/test_properties.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
5959
func_st = st.sampled_from(
6060
[f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
6161
)
62+
numeric_arrays = npst.arrays(
63+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
64+
)
65+
all_arrays = npst.arrays(
66+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
67+
)
6268

6369

6470
def by_arrays(shape):
@@ -81,13 +87,7 @@ def not_overflowing_array(array) -> bool:
8187
return result
8288

8389

84-
@given(
85-
array=npst.arrays(
86-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
87-
),
88-
dtype=by_dtype_st,
89-
func=func_st,
90-
)
90+
@given(array=numeric_arrays, dtype=by_dtype_st, func=func_st)
9191
def test_groupby_reduce(array, dtype, func):
9292
# overflow behaviour differs between bincount and sum (for example)
9393
assume(not_overflowing_array(array))
@@ -149,17 +149,7 @@ def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
149149

150150

151151
@st.composite
152-
def chunked_arrays(
153-
draw,
154-
*,
155-
chunks=chunks,
156-
arrays=npst.arrays(
157-
elements={"allow_subnormal": False},
158-
shape=npst.array_shapes(max_side=10),
159-
dtype=array_dtype_st,
160-
),
161-
from_array=dask.array.from_array,
162-
):
152+
def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
163153
array = draw(arrays)
164154
chunks = draw(chunks(shape=array.shape))
165155

@@ -216,6 +206,7 @@ def test_scans(data, array, func):
216206

217207
@given(data=st.data(), array=chunked_arrays())
218208
def test_ffill_bfill_reverse(data, array):
209+
# TODO: test NaT and timedelta, datetime
219210
assume(not_overflowing_array(np.asarray(array)))
220211
by = data.draw(by_arrays(shape=(array.shape[-1],)))
221212

@@ -230,3 +221,38 @@ def reverse(arr):
230221
backward = groupby_scan(a, by, func="bfill")
231222
forward_reversed = reverse(groupby_scan(reverse(a), reverse(by), func="ffill"))
232223
assert_equal(forward_reversed, backward)
224+
225+
226+
@given(
    data=st.data(),
    array=chunked_arrays(arrays=all_arrays),
    func=st.sampled_from(["first", "last", "nanfirst", "nanlast"]),
)
def test_first_last(data, array, func):
    """Property test for the first/last/nanfirst/nanlast groupby reductions.

    Checks two invariants:
      1. ``func`` on the array equals its inverse on the reversed array
         (e.g. ``first(x) == last(x[::-1])``), with identical groups, for both
         the dask array and its computed numpy counterpart.
      2. For float arrays containing no NaNs, ``func`` agrees with its
         NaN-skipping mate (e.g. ``first == nanfirst``).
    """
    by = data.draw(by_arrays(shape=(array.shape[-1],)))

    # Reversing the array along the reduced axis maps each func to its inverse;
    # MATES pairs each func with its NaN-skipping (or plain) counterpart.
    INVERSES = {"first": "last", "last": "first", "nanfirst": "nanlast", "nanlast": "nanfirst"}
    MATES = {"first": "nanfirst", "last": "nanlast", "nanfirst": "first", "nanlast": "last"}
    inverse = INVERSES[func]
    mate = MATES[func]

    # NOTE(review): "first"/"last" appear to require a single chunk along the
    # reduced (last) axis — the rechunk collapses it to one chunk; confirm.
    if func in ["first", "last"]:
        array = array.rechunk((*array.chunks[:-1], -1))

    # Exercise both the dask-backed and the eager (numpy) code paths.
    for arr in [array, array.compute()]:
        forward, fg = groupby_reduce(arr, by, func=func, engine="flox")
        # Applying the inverse func to the array and `by` reversed along the
        # last axis must give the same result (and the same groups).
        reverse, rg = groupby_reduce(arr[..., ::-1], by[..., ::-1], func=inverse, engine="flox")

        # first/last must preserve the input dtype exactly.
        assert forward.dtype == reverse.dtype
        assert forward.dtype == arr.dtype

        assert_equal(fg, rg)
        assert_equal(forward, reverse)

    # With no NaNs present, the NaN-skipping variant must agree with the plain one.
    if arr.dtype.kind == "f" and not np.isnan(array.compute()).any():
        if mate in ["first", "last"]:
            array = array.rechunk((*array.chunks[:-1], -1))

        first, _ = groupby_reduce(array, by, func=func, engine="flox")
        second, _ = groupby_reduce(array, by, func=mate, engine="flox")
        assert_equal(first, second)

0 commit comments

Comments
 (0)