Skip to content

Commit bc33fda

Browse files
keewisdcherian
andauthored
explicitly cast the dtype of where's condition parameter to bool (#10087)
* explicitly cast the dtype of `condition` to `bool` * cast `condition` to bool in every case for `where` * don't pass a `DataArray` to `where` * use strings to specify the dtype for backwards compat * revert the strings and instead ignore the warning * typo * restrict to just numpy * unrestrict * fall back to `xp.bool_` if `xp.bool` doesn't exist * unskip the `where` test * reverse to avoid warnings * remove the outdated ignore --------- Co-authored-by: Deepak Cherian <[email protected]>
1 parent 1e1938f commit bc33fda

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

xarray/core/duck_array_ops.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,17 @@ def empty_like(a, **kwargs):
224224
return xp.empty_like(a, **kwargs)
225225

226226

227-
def astype(data, dtype, **kwargs):
228-
if hasattr(data, "__array_namespace__"):
227+
def astype(data, dtype, *, xp=None, **kwargs):
228+
if not hasattr(data, "__array_namespace__") and xp is None:
229+
return data.astype(dtype, **kwargs)
230+
231+
if xp is None:
229232
xp = get_array_namespace(data)
230-
if xp == np:
231-
# numpy currently doesn't have a astype:
232-
return data.astype(dtype, **kwargs)
233-
return xp.astype(data, dtype, **kwargs)
234-
return data.astype(dtype, **kwargs)
233+
234+
if xp == np:
235+
# numpy currently doesn't have a astype:
236+
return data.astype(dtype, **kwargs)
237+
return xp.astype(data, dtype, **kwargs)
235238

236239

237240
def asarray(data, xp=np, dtype=None):
@@ -373,6 +376,13 @@ def sum_where(data, axis=None, dtype=None, where=None):
373376
def where(condition, x, y):
374377
"""Three argument where() with better dtype promotion rules."""
375378
xp = get_array_namespace(condition, x, y)
379+
380+
dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool
381+
if not is_duck_array(condition):
382+
condition = asarray(condition, dtype=dtype, xp=xp)
383+
else:
384+
condition = astype(condition, dtype=dtype, xp=xp)
385+
376386
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
377387

378388

xarray/core/groupby.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def factorize(self) -> EncodedGroups:
524524
# Restore these after the raveling
525525
broadcasted_masks = broadcast(*masks)
526526
mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type]
527-
_flatcodes = where(mask, -1, _flatcodes)
527+
_flatcodes = where(mask.data, -1, _flatcodes)
528528

529529
full_index = pd.MultiIndex.from_product(
530530
(grouper.full_index.values for grouper in groupers),

xarray/tests/test_array_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
139139
assert_equal(actual, expected)
140140

141141

142-
@pytest.mark.skip
143142
def test_where() -> None:
144143
np_arr = xr.DataArray(np.array([1, 0]), dims="x")
145144
xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")

0 commit comments

Comments
 (0)