Skip to content

Commit 52c81a9

Browse files
authored
REF: implement putmask for CI/DTI/TDI/PI (#36400)
1 parent d4947a9 commit 52c81a9

File tree

5 files changed

+32
-7
lines changed

5 files changed

+32
-7
lines changed

pandas/core/arrays/categorical.py

+5
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,11 @@ def map(self, mapper):
11711171
# -------------------------------------------------------------
11721172
# Validators; ideally these can be de-duplicated
11731173

1174+
def _validate_where_value(self, value):
1175+
if is_scalar(value):
1176+
return self._validate_fill_value(value)
1177+
return self._validate_listlike(value)
1178+
11741179
def _validate_insert_value(self, value) -> int:
11751180
code = self.categories.get_indexer([value])
11761181
if (code == -1) and not (is_scalar(value) and isna(value)):

pandas/core/indexes/base.py

-3
Original file line numberDiff line numberDiff line change
@@ -4232,9 +4232,6 @@ def putmask(self, mask, value):
42324232
try:
42334233
converted = self._validate_fill_value(value)
42344234
np.putmask(values, mask, converted)
4235-
if is_period_dtype(self.dtype):
4236-
# .values cast to object, so we need to cast back
4237-
values = type(self)(values)._data
42384235
return self._shallow_copy(values)
42394236
except (ValueError, TypeError) as err:
42404237
if is_object_dtype(self):

pandas/core/indexes/category.py

+11
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,17 @@ def where(self, cond, other=None):
422422
cat = Categorical(values, dtype=self.dtype)
423423
return type(self)._simple_new(cat, name=self.name)
424424

425+
def putmask(self, mask, value):
426+
try:
427+
code_value = self._data._validate_where_value(value)
428+
except (TypeError, ValueError):
429+
return self.astype(object).putmask(mask, value)
430+
431+
codes = self._data._ndarray.copy()
432+
np.putmask(codes, mask, code_value)
433+
cat = self._data._from_backing_data(codes)
434+
return type(self)._simple_new(cat, name=self.name)
435+
425436
def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
426437
"""
427438
Create index with target's values (move/add/delete values as necessary)

pandas/core/indexes/datetimelike.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,18 @@ def where(self, cond, other=None):
474474
raise TypeError(f"Where requires matching dtype, not {oth}") from err
475475

476476
result = np.where(cond, values, other).astype("i8")
477-
arr = type(self._data)._simple_new(result, dtype=self.dtype)
477+
arr = self._data._from_backing_data(result)
478+
return type(self)._simple_new(arr, name=self.name)
479+
480+
def putmask(self, mask, value):
481+
try:
482+
value = self._data._validate_where_value(value)
483+
except (TypeError, ValueError):
484+
return self.astype(object).putmask(mask, value)
485+
486+
result = self._data._ndarray.copy()
487+
np.putmask(result, mask, value)
488+
arr = self._data._from_backing_data(result)
478489
return type(self)._simple_new(arr, name=self.name)
479490

480491
def _summary(self, name=None) -> str:

pandas/tests/indexes/common.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -846,16 +846,17 @@ def test_map_str(self):
846846
def test_putmask_with_wrong_mask(self):
847847
# GH18368
848848
index = self.create_index()
849+
fill = index[0]
849850

850851
msg = "putmask: mask and data must be the same size"
851852
with pytest.raises(ValueError, match=msg):
852-
index.putmask(np.ones(len(index) + 1, np.bool_), 1)
853+
index.putmask(np.ones(len(index) + 1, np.bool_), fill)
853854

854855
with pytest.raises(ValueError, match=msg):
855-
index.putmask(np.ones(len(index) - 1, np.bool_), 1)
856+
index.putmask(np.ones(len(index) - 1, np.bool_), fill)
856857

857858
with pytest.raises(ValueError, match=msg):
858-
index.putmask("foo", 1)
859+
index.putmask("foo", fill)
859860

860861
@pytest.mark.parametrize("copy", [True, False])
861862
@pytest.mark.parametrize("name", [None, "foo"])

0 commit comments

Comments
 (0)