Commit f3ffab7

mathause and max-sixty authored
Fix bool weights (#4075)
* add tests
* weights: bool -> int
* whats new
* Apply suggestions from code review
* avoid unnecessary copy

Co-authored-by: Maximilian Roos <[email protected]>
1 parent 19b0886 commit f3ffab7

3 files changed: +34 -2 lines changed

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
@@ -119,6 +119,8 @@ Bug fixes
 - Fix bug in time parsing failing to fall back to cftime. This was causing time
   variables with a time unit of `'msecs'` to fail to parse. (:pull:`3998`)
   By `Ryan May <https://github.com/dopplershift>`_.
+- Fix weighted mean when passing boolean weights (:issue:`4074`).
+  By `Mathias Hauser <https://github.com/mathause>`_.
 - Fix html repr in untrusted notebooks: fallback to plain text repr. (:pull:`4053`)
   By `Benoit Bovy <https://github.com/benbovy>`_.
 
@@ -186,7 +188,7 @@ New Features
 
 - Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted`
   and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`).
-  By `Mathias Hauser <https://github.com/mathause>`_
+  By `Mathias Hauser <https://github.com/mathause>`_.
 - The new jupyter notebook repr (``Dataset._repr_html_`` and
   ``DataArray._repr_html_``) (introduced in 0.14.1) is now on by default. To
   disable, use ``xarray.set_options(display_style="text")``.

xarray/core/weighted.py

Lines changed: 8 additions & 1 deletion
@@ -142,7 +142,14 @@ def _sum_of_weights(
         # we need to mask data values that are nan; else the weights are wrong
         mask = da.notnull()
 
-        sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False)
+        # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True
+        # (and not 2); GH4074
+        if self.weights.dtype == bool:
+            sum_of_weights = self._reduce(
+                mask, self.weights.astype(int), dim=dim, skipna=False
+            )
+        else:
+            sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False)
 
         # 0-weights are not valid
         valid_weights = sum_of_weights != 0.0
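
The comment in this hunk says that a dot product of two boolean arrays yields `True` rather than a count. The same behavior can be reproduced with plain NumPy, as in the minimal sketch below. It assumes that `np.dot` is a fair stand-in for the reduction performed by `xr.dot`; the variable names are illustrative and not taken from xarray.

```python
import numpy as np

# Illustrative inputs: ``mask`` stands in for ``da.notnull()``.
mask = np.array([True, True])
weights = np.array([True, True])

# Both operands are boolean, so the result keeps the bool dtype and the
# "sum" cannot exceed True.
bool_sum = np.dot(mask, weights)

# Casting the weights to int first makes the reduction count them.
int_sum = np.dot(mask, weights.astype(int))

print(bool_sum, bool_sum.dtype)
print(int_sum, int_sum.dtype)
```

Only the weights need the cast: once one operand is an integer array, the product is computed in integer arithmetic.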

xarray/tests/test_weighted.py

Lines changed: 23 additions & 0 deletions
@@ -59,6 +59,18 @@ def test_weighted_sum_of_weights_nan(weights, expected):
     assert_equal(expected, result)
 
 
+def test_weighted_sum_of_weights_bool():
+    # https://github.com/pydata/xarray/issues/4074
+
+    da = DataArray([1, 2])
+    weights = DataArray([True, True])
+    result = da.weighted(weights).sum_of_weights()
+
+    expected = DataArray(2)
+
+    assert_equal(expected, result)
+
+
 @pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
 @pytest.mark.parametrize("factor", [0, 1, 3.14])
 @pytest.mark.parametrize("skipna", (True, False))
@@ -158,6 +170,17 @@ def test_weighted_mean_nan(weights, expected, skipna):
     assert_equal(expected, result)
 
 
+def test_weighted_mean_bool():
+    # https://github.com/pydata/xarray/issues/4074
+    da = DataArray([1, 1])
+    weights = DataArray([True, True])
+    expected = DataArray(1)
+
+    result = da.weighted(weights).mean()
+
+    assert_equal(expected, result)
+
+
 def expected_weighted(da, weights, dim, skipna, operation):
     """
     Generate expected result using ``*`` and ``sum``. This is checked against
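
To see why `test_weighted_mean_bool` expects `1`, it can help to work through the arithmetic by hand. The sketch below is a simplified NumPy model of a weighted mean (weighted sum divided by sum of weights), not xarray's implementation. The helper `weighted_mean` and its `cast_bool_weights` flag are hypothetical names introduced only for this illustration, and zero-weight handling is omitted.

```python
import numpy as np


def weighted_mean(data, weights, cast_bool_weights=True):
    """Weighted mean of 1-D data, ignoring NaNs (hypothetical helper)."""
    data = np.asarray(data, dtype=float)
    weights = np.asarray(weights)

    # Mirrors the idea of the fix: count boolean weights as integers.
    if cast_bool_weights and weights.dtype == bool:
        weights = weights.astype(int)

    mask = ~np.isnan(data)  # True where data is valid
    weighted_sum = np.dot(np.where(mask, data, 0.0), weights)
    sum_of_weights = np.dot(mask, weights)  # bool . bool stays bool
    return weighted_sum / sum_of_weights


fixed = weighted_mean([1, 1], [True, True])
unfixed = weighted_mean([1, 1], [True, True], cast_bool_weights=False)

print(fixed)    # weighted sum 2 divided by sum of weights 2
print(unfixed)  # weighted sum 2 divided by True, which acts as 1
```

Without the cast, the denominator collapses to `True`, so the mean of `[1, 1]` comes out as `2` instead of `1`, which is the symptom the test guards against.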
