Skip to content

Commit 0a960de

Browse files
committed
Hande min_count=0
1 parent 24dc7fd commit 0a960de

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

flox/aggregations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,15 @@ def _initialize_aggregation(
557557
assert isinstance(finalize_kwargs, dict)
558558
agg.finalize_kwargs = finalize_kwargs
559559

560+
if min_count is None:
561+
min_count = 0
562+
560563
# This is needed for the dask pathway.
561564
# Because we use intermediate fill_value since a group could be
562565
# absent in one block, but present in another block
563566
# We set it for numpy to get nansum, nanprod tests to pass
564567
# where the identity element is 0, 1
565-
if min_count is not None:
568+
if min_count > 0:
566569
agg.min_count = min_count
567570
agg.chunk += ("nanlen",)
568571
agg.numpy += ("nanlen",)

flox/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def _finalize_results(
849849
"""
850850
squeezed = _squeeze_results(results, axis)
851851

852-
if agg.min_count is not None:
852+
if agg.min_count > 0:
853853
counts = squeezed["intermediates"][-1]
854854
squeezed["intermediates"] = squeezed["intermediates"][:-1]
855855

@@ -860,7 +860,7 @@ def _finalize_results(
860860
else:
861861
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
862862

863-
if agg.min_count is not None:
863+
if agg.min_count > 0:
864864
count_mask = counts < agg.min_count
865865
if count_mask.any():
866866
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
@@ -1898,7 +1898,12 @@ def groupby_reduce(
18981898
min_count = 1
18991899

19001900
# TODO: set in xarray?
1901-
if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
1901+
if (
1902+
min_count is not None
1903+
and min_count > 0
1904+
and func in ["nansum", "nanprod"]
1905+
and fill_value is None
1906+
):
19021907
# nansum, nanprod have fill_value=0, 1
19031908
# overwrite than when min_count is set
19041909
fill_value = np.nan

0 commit comments

Comments
 (0)