Skip to content

Commit 3e2f096

Browse files
committed
PERF: add _simple_new method to masked arrays
1 parent 3d0d0fa commit 3e2f096

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

pandas/core/arrays/boolean.py

+8 -1
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from pandas._typing import (
3131
Dtype,
3232
DtypeObj,
33+
Self,
3334
npt,
3435
type_t,
3536
)
@@ -296,6 +297,12 @@ class BooleanArray(BaseMaskedArray):
296297
_TRUE_VALUES = {"True", "TRUE", "true", "1", "1.0"}
297298
_FALSE_VALUES = {"False", "FALSE", "false", "0", "0.0"}
298299

300+
@classmethod
301+
def _simple_new(cls, values: np.ndarray, mask: npt.NDArray[np.bool_]) -> Self:
302+
result = super()._simple_new(values, mask)
303+
result._dtype = BooleanDtype()
304+
return result
305+
299306
def __init__(
300307
self, values: np.ndarray, mask: np.ndarray, copy: bool = False
301308
) -> None:
@@ -390,7 +397,7 @@ def _accumulate(
390397
if name in ("cummin", "cummax"):
391398
op = getattr(masked_accumulations, name)
392399
data, mask = op(data, mask, skipna=skipna, **kwargs)
393-
return type(self)(data, mask, copy=False)
400+
return self._simple_new(data, mask)
394401
else:
395402
from pandas.core.arrays import IntegerArray
396403

pandas/core/arrays/masked.py

+23 -17
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,13 @@ class BaseMaskedArray(OpsMixin, ExtensionArray):
108108
_truthy_value = Scalar # bool(_truthy_value) = True
109109
_falsey_value = Scalar # bool(_falsey_value) = False
110110

111+
@classmethod
112+
def _simple_new(cls, values: np.ndarray, mask: npt.NDArray[np.bool_]) -> Self:
113+
result = BaseMaskedArray.__new__(cls)
114+
result._data = values
115+
result._mask = mask
116+
return result
117+
111118
def __init__(
112119
self, values: np.ndarray, mask: npt.NDArray[np.bool_], copy: bool = False
113120
) -> None:
@@ -154,7 +161,7 @@ def __getitem__(self, item: PositionalIndexer) -> Self | Any:
154161
return self.dtype.na_value
155162
return self._data[item]
156163

157-
return type(self)(self._data[item], newmask)
164+
return self._simple_new(self._data[item], newmask)
158165

159166
@doc(ExtensionArray.fillna)
160167
@doc(ExtensionArray.fillna)
@@ -177,7 +184,7 @@ def fillna(self, value=None, method=None, limit: int | None = None) -> Self:
177184
npvalues = self._data.copy().T
178185
new_mask = mask.copy().T
179186
func(npvalues, limit=limit, mask=new_mask)
180-
return type(self)(npvalues.T, new_mask.T)
187+
return self._simple_new(npvalues.T, new_mask.T)
181188
else:
182189
# fill with value
183190
new_values = self.copy()
@@ -266,17 +273,17 @@ def ndim(self) -> int:
266273
def swapaxes(self, axis1, axis2) -> Self:
267274
data = self._data.swapaxes(axis1, axis2)
268275
mask = self._mask.swapaxes(axis1, axis2)
269-
return type(self)(data, mask)
276+
return self._simple_new(data, mask)
270277

271278
def delete(self, loc, axis: AxisInt = 0) -> Self:
272279
data = np.delete(self._data, loc, axis=axis)
273280
mask = np.delete(self._mask, loc, axis=axis)
274-
return type(self)(data, mask)
281+
return self._simple_new(data, mask)
275282

276283
def reshape(self, *args, **kwargs) -> Self:
277284
data = self._data.reshape(*args, **kwargs)
278285
mask = self._mask.reshape(*args, **kwargs)
279-
return type(self)(data, mask)
286+
return self._simple_new(data, mask)
280287

281288
def ravel(self, *args, **kwargs) -> Self:
282289
# TODO: need to make sure we have the same order for data/mask
@@ -286,7 +293,7 @@ def ravel(self, *args, **kwargs) -> Self:
286293

287294
@property
288295
def T(self) -> Self:
289-
return type(self)(self._data.T, self._mask.T)
296+
return self._simple_new(self._data.T, self._mask.T)
290297

291298
def round(self, decimals: int = 0, *args, **kwargs):
292299
"""
@@ -322,16 +329,16 @@ def round(self, decimals: int = 0, *args, **kwargs):
322329
# Unary Methods
323330

324331
def __invert__(self) -> Self:
325-
return type(self)(~self._data, self._mask.copy())
332+
return self._simple_new(~self._data, self._mask.copy())
326333

327334
def __neg__(self) -> Self:
328-
return type(self)(-self._data, self._mask.copy())
335+
return self._simple_new(-self._data, self._mask.copy())
329336

330337
def __pos__(self) -> Self:
331338
return self.copy()
332339

333340
def __abs__(self) -> Self:
334-
return type(self)(abs(self._data), self._mask.copy())
341+
return self._simple_new(abs(self._data), self._mask.copy())
335342

336343
# ------------------------------------------------------------------
337344

@@ -864,7 +871,7 @@ def take(
864871
result[fill_mask] = fill_value
865872
mask = mask ^ fill_mask
866873

867-
return type(self)(result, mask, copy=False)
874+
return self._simple_new(result, mask)
868875

869876
# error: Return type "BooleanArray" of "isin" incompatible with return type
870877
# "ndarray" in supertype "ExtensionArray"
@@ -889,10 +896,9 @@ def isin(self, values) -> BooleanArray: # type: ignore[override]
889896
return BooleanArray(result, mask, copy=False)
890897

891898
def copy(self) -> Self:
892-
data, mask = self._data, self._mask
893-
data = data.copy()
894-
mask = mask.copy()
895-
return type(self)(data, mask, copy=False)
899+
data = self._data.copy()
900+
mask = self._mask.copy()
901+
return self._simple_new(data, mask)
896902

897903
def unique(self) -> Self:
898904
"""
@@ -903,7 +909,7 @@ def unique(self) -> Self:
903909
uniques : BaseMaskedArray
904910
"""
905911
uniques, mask = algos.unique_with_mask(self._data, self._mask)
906-
return type(self)(uniques, mask, copy=False)
912+
return self._simple_new(uniques, mask)
907913

908914
@doc(ExtensionArray.searchsorted)
909915
def searchsorted(
@@ -955,7 +961,7 @@ def factorize(
955961
# dummy value for uniques; not used since uniques_mask will be True
956962
uniques = np.insert(uniques, na_code, 0)
957963
uniques_mask[na_code] = True
958-
uniques_ea = type(self)(uniques, uniques_mask)
964+
uniques_ea = self._simple_new(uniques, uniques_mask)
959965

960966
return codes, uniques_ea
961967

@@ -1410,7 +1416,7 @@ def _accumulate(
14101416
op = getattr(masked_accumulations, name)
14111417
data, mask = op(data, mask, skipna=skipna, **kwargs)
14121418

1413-
return type(self)(data, mask, copy=False)
1419+
return self._simple_new(data, mask)
14141420

14151421
# ------------------------------------------------------------------
14161422
# GroupBy Methods

0 commit comments

Comments (0)