diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 46428dcf462ea..d603573d1f6a7 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -185,6 +185,7 @@ Other enhancements - Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`) - Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`) - :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`) +- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string())`` (:issue:`50325`) - :meth:`Series.drop_duplicates` has gained ``ignore_index`` keyword to reset index (:issue:`48304`) - :meth:`Series.dropna` and :meth:`DataFrame.dropna` has gained ``ignore_index`` keyword to reset index (:issue:`31725`) - Improved error message in :func:`to_datetime` for non-ISO8601 formats, informing users about the position of the first error (:issue:`50361`) diff --git a/pandas/conftest.py b/pandas/conftest.py index 64a8f0f9efc1d..c7d905a30d203 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -70,6 +70,7 @@ Index, MultiIndex, ) +from pandas.core.strings.accessor import StringMethods try: import pyarrow as pa @@ -1947,3 +1948,124 @@ def warsaw(request): tzinfo for Europe/Warsaw using pytz, dateutil, or zoneinfo. """ return request.param + + +_any_string_method = [ + ("cat", (), {"sep": ","}), + ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}), + ("center", (10,), {}), + ("contains", ("a",), {}), + ("count", ("a",), {}), + ("decode", ("UTF-8",), {}), + ("encode", ("UTF-8",), {}), + ("endswith", ("a",), {}), + ("endswith", ("a",), {"na": True}), + ("endswith", ("a",), {"na": False}), + ("extract", ("([a-z]*)",), {"expand": False}), + ("extract", ("([a-z]*)",), {"expand": True}), + ("extractall", ("([a-z]*)",), {}), + ("find", ("a",), {}), + ("findall", ("a",), {}), + ("get", (0,), {}), + # because "index" (and "rindex") fail intentionally + # if the string is not found, search only for empty string + ("index", ("",), {}), + ("join", (",",), {}), + ("ljust", (10,), {}), + ("match", ("a",), {}), + ("fullmatch", ("a",), {}), + ("normalize", ("NFC",), {}), + ("pad", (10,), {}), + ("partition", (" ",), {"expand": False}), + ("partition", (" ",), {"expand": True}), + ("repeat", (3,), {}), + ("replace", ("a", "z"), {}), + ("rfind", ("a",), {}), + ("rindex", ("",), {}), + ("rjust", (10,), {}), + ("rpartition", (" ",), {"expand": False}), + ("rpartition", (" ",), {"expand": True}), + ("slice", (0, 1), {}), + ("slice_replace", (0, 1, "z"), {}), + ("split", (" ",), {"expand": False}), + # ("split", (" ",), {"expand": True}), + ("startswith", ("a",), {}), + ("startswith", ("a",), {"na": True}), + ("startswith", ("a",), {"na": False}), + ("removeprefix", ("a",), {}), + ("removesuffix", ("a",), {}), + # translating unicode points of "a" to "d" + ("translate", ({97: 100},), {}), + ("wrap", (2,), {}), + ("zfill", (10,), {}), +] + list( + zip( + [ + # methods without positional arguments: zip with empty tuple and empty dict + "capitalize", + "cat", + "get_dummies", + "isalnum", + "isalpha", + "isdecimal", + "isdigit", + "islower", + "isnumeric", + "isspace", + "istitle", + "isupper", + "len", + "lower", + "lstrip", + "partition", + "rpartition", + "rsplit", + "rstrip", + "slice", + "slice_replace", + "split", + "strip", + "swapcase", + "title", + "upper", + "casefold", + ], + [()] * 100, + [{}] * 100, + ) +) +ids, _, _ = zip(*_any_string_method) # use method name as fixture-id +missing_methods = {f for f in dir(StringMethods) if not f.startswith("_")} - set(ids) + +# test that the above list captures all methods of StringMethods +assert not missing_methods + + +@pytest.fixture(params=_any_string_method, ids=ids) +def any_string_method(request): + """ + Fixture for all public methods of `StringMethods` + + This fixture returns a tuple of the method name and sample arguments + necessary to call the method. + + Returns + ------- + method_name : str + The name of the method in `StringMethods` + args : tuple + Sample values for the positional arguments + kwargs : dict + Sample values for the keyword arguments + + Examples + -------- + >>> def test_something(any_string_method): + ... s = Series(['a', 'b', np.nan, 'd']) + ... + ... method_name, args, kwargs = any_string_method + ... method = getattr(s.str, method_name) + ... # will not raise + ... method(*args, **kwargs) + """ + return request.param diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 0f81ab5f7b424..9c9a49f458810 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1,17 +1,23 @@ from __future__ import annotations from copy import deepcopy +import re from typing import ( TYPE_CHECKING, Any, + Callable, Literal, + Sequence, TypeVar, cast, ) import numpy as np -from pandas._libs import lib +from pandas._libs import ( + lib, + missing as libmissing, +) from pandas._typing import ( ArrayLike, AxisInt, @@ -53,6 +59,7 @@ unpack_tuple_and_ellipses, validate_indices, ) +from pandas.core.strings.object_array import ObjectStringArrayMixin if not pa_version_under6p0: import pyarrow as pa @@ -151,7 +158,7 @@ def to_pyarrow_type( return None -class ArrowExtensionArray(OpsMixin, ExtensionArray): +class ArrowExtensionArray(OpsMixin, ObjectStringArrayMixin, ExtensionArray): """ Pandas ExtensionArray backed by a PyArrow ChunkedArray. @@ -211,6 +218,7 @@ def __init__(self, values: pa.Array | pa.ChunkedArray) -> None: f"Unsupported type '{type(values)}' for ArrowExtensionArray" ) self._dtype = ArrowDtype(self._data.type) + assert self._dtype.na_value is self._str_na_value # type: ignore[has-type] @classmethod def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): @@ -1385,17 +1393,14 @@ def _replace_with_mask( ): """ Replace items selected with a mask. - Analogous to pyarrow.compute.replace_with_mask, with logic to fallback to numpy for unsupported types. - Parameters ---------- values : pa.Array or pa.ChunkedArray mask : npt.NDArray[np.bool_] or bool replacements : ArrayLike or Scalar Replacement value(s) - Returns ------- pa.Array or pa.ChunkedArray @@ -1424,3 +1429,286 @@ def _replace_with_mask( result = np.array(values, dtype=object) result[mask] = replacements return pa.array(result, type=values.type, from_pandas=True) + + # ------------------------------------------------------------------------ + # String methods interface (Series.str.) + + # asserted to match self._dtype.na_value + _str_na_value = libmissing.NA + + def _str_count(self, pat: str, flags: int = 0): + if flags: + fallback_performancewarning() + return super()._str_count(pat, flags) + return type(self)(pc.count_substring_regex(self._data, pat)) + + def _str_pad( + self, + width: int, + side: Literal["left", "right", "both"] = "left", + fillchar: str = " ", + ): + if side == "left": + pa_pad = pc.utf8_lpad + elif side == "right": + pa_pad = pc.utf8_rpad + elif side == "both": + # https://github.com/apache/arrow/issues/15053 + # pa_pad = pc.utf8_center + return super()._str_pad(width, side, fillchar) + else: + raise ValueError( + f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" + ) + return type(self)(pa_pad(self._data, width=width, padding=fillchar)) + + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): + if na_value is None: + na_value = self.dtype.na_value + + mask = self.isna() + arr = np.asarray(self) + + np_result = lib.map_infer_mask( + arr, f, mask.view("uint8"), convert=False, na_value=na_value + ) + try: + return type(self)(pa.array(np_result, mask=mask, from_pandas=True)) + except pa.ArrowInvalid: + return np_result + + def _str_contains( + self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True + ): + if flags: + fallback_performancewarning() + return super()._str_contains(pat, case, flags, na, regex) + + if regex: + pa_contains = pc.match_substring_regex + else: + pa_contains = pc.match_substring + result = pa_contains(self._data, pat, ignore_case=not case) + if not isna(na): + result = result.fill_null(na) + return type(self)(result) + + def _str_startswith(self, pat: str, na=None): + result = pc.starts_with(self._data, pattern=pat) + if na is not None: + result = result.fill_null(na) + return type(self)(result) + + def _str_endswith(self, pat: str, na=None): + result = pc.ends_with(self._data, pattern=pat) + if na is not None: + result = result.fill_null(na) + return type(self)(result) + + def _str_replace( + self, + pat: str | re.Pattern, + repl: str | Callable, + n: int = -1, + case: bool = True, + flags: int = 0, + regex: bool = True, + ): + if isinstance(pat, re.Pattern) or callable(repl) or not case or flags: + fallback_performancewarning() + return super()._str_replace(pat, repl, n, case, flags, regex) + + func = pc.replace_substring_regex if regex else pc.replace_substring + result = func(self._data, pattern=pat, replacement=repl, max_replacements=n) + return type(self)(result) + + def _str_repeat(self, repeats: int | Sequence[int]): + if not isinstance(repeats, int): + fallback_performancewarning() + return super()._str_repeat(repeats) + elif pa_version_under7p0: + fallback_performancewarning("7") + return super()._str_repeat(repeats) + else: + return type(self)(pc.binary_repeat(self._data, repeats)) + + def _str_match( + self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.startswith("^"): + pat = f"^{pat}" + return self._str_contains(pat, case, flags, na, regex=True) + + def _str_fullmatch( + self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.endswith("$") or pat.endswith("//$"): + pat = f"{pat}$" + return self._str_match(pat, case, flags, na) + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if start != 0 and end is not None: + slices = pc.utf8_slice_codeunits(self._data, start, stop=end) + result = pc.find_substring(slices, sub) + not_found = pc.equal(result, -1) + offset_result = pc.add(result, end - start) + result = pc.if_else(not_found, result, offset_result) + elif start == 0 and end is None: + slices = self._data + result = pc.find_substring(slices, sub) + else: + fallback_performancewarning() + return super()._str_find(sub, start, end) + return type(self)(result) + + def _str_get(self, i: int): + lengths = pc.utf8_length(self._data) + if i >= 0: + out_of_bounds = pc.greater_equal(i, lengths) + start = i + stop = i + 1 + step = 1 + else: + out_of_bounds = pc.greater(-i, lengths) + start = i + stop = i - 1 + step = -1 + not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True)) + selected = pc.utf8_slice_codeunits( + self._data, start=start, stop=stop, step=step + ) + result = pa.array([None] * self._data.length(), type=self._data.type) + result = pc.if_else(not_out_of_bounds, selected, result) + return type(self)(result) + + def _str_join(self, sep: str): + if isinstance(self._dtype, ArrowDtype) and pa.types.is_list( + self.dtype.pyarrow_dtype + ): + # Check ArrowDtype as ArrowString inherits and uses StringDtype + return type(self)(pc.binary_join(self._data, sep)) + else: + return super()._str_join(sep) + + def _str_partition(self, sep: str, expand: bool): + result = super()._str_partition(sep, expand) + if expand and isinstance(result, type(self)): + # StringMethods._wrap_result needs numpy-like nested object + result = result.to_numpy(dtype=object) + return result + + def _str_rpartition(self, sep: str, expand: bool): + result = super()._str_rpartition(sep, expand) + if expand and isinstance(result, type(self)): + # StringMethods._wrap_result needs numpy-like nested object + result = result.to_numpy(dtype=object) + return result + + def _str_slice( + self, start: int | None = None, stop: int | None = None, step: int | None = None + ): + if stop is None: + # TODO: TODO: Should work once https://github.com/apache/arrow/issues/14991 + # is fixed + return super()._str_slice(start, stop, step) + if start is None: + start = 0 + if step is None: + step = 1 + return type(self)( + pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step) + ) + + def _str_slice_replace( + self, start: int | None = None, stop: int | None = None, repl: str | None = None + ): + if stop is None: + # Not implemented as of pyarrow=10 + fallback_performancewarning() + return super()._str_slice_replace(start, stop, repl) + if repl is None: + repl = "" + if start is None: + start = 0 + return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl)) + + def _str_isalnum(self): + return type(self)(pc.utf8_is_alnum(self._data)) + + def _str_isalpha(self): + return type(self)(pc.utf8_is_alpha(self._data)) + + def _str_isdecimal(self): + return type(self)(pc.utf8_is_decimal(self._data)) + + def _str_isdigit(self): + return type(self)(pc.utf8_is_digit(self._data)) + + def _str_islower(self): + return type(self)(pc.utf8_is_lower(self._data)) + + def _str_isnumeric(self): + return type(self)(pc.utf8_is_numeric(self._data)) + + def _str_isspace(self): + return type(self)(pc.utf8_is_space(self._data)) + + def _str_istitle(self): + return type(self)(pc.utf8_is_title(self._data)) + + def _str_capitalize(self): + return type(self)(pc.utf8_capitalize(self._data)) + + def _str_title(self): + return type(self)(pc.utf8_title(self._data)) + + def _str_isupper(self): + return type(self)(pc.utf8_is_upper(self._data)) + + def _str_swapcase(self): + return type(self)(pc.utf8_swapcase(self._data)) + + def _str_len(self): + return type(self)(pc.utf8_length(self._data)) + + def _str_lower(self): + return type(self)(pc.utf8_lower(self._data)) + + def _str_upper(self): + return type(self)(pc.utf8_upper(self._data)) + + def _str_strip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_trim_whitespace(self._data) + else: + result = pc.utf8_trim(self._data, characters=to_strip) + return type(self)(result) + + def _str_lstrip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_ltrim_whitespace(self._data) + else: + result = pc.utf8_ltrim(self._data, characters=to_strip) + return type(self)(result) + + def _str_rstrip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_rtrim_whitespace(self._data) + else: + result = pc.utf8_rtrim(self._data, characters=to_strip) + return type(self)(result) + + # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed + # def _str_removeprefix(self, prefix: str) -> Series: + # starts_with = pc.starts_with(self._data, pattern=prefix) + # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) + # result = pc.if_else(starts_with, removed, self._data) + # return type(self)(result) + + def _str_removesuffix(self, suffix: str) -> Series: + ends_with = pc.ends_with(self._data, pattern=suffix) + removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) + result = pc.if_else(ends_with, removed, self._data) + return type(self)(result) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 4aebe61412866..fc964d181bc11 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -17,7 +17,10 @@ Scalar, npt, ) -from pandas.compat import pa_version_under6p0 +from pandas.compat import ( + pa_version_under6p0, + pa_version_under8p0, +) from pandas.core.dtypes.common import ( is_bool_dtype, @@ -278,21 +281,45 @@ def _str_contains( return super()._str_contains(pat, case, flags, na, regex) if regex: - if case is False: - fallback_performancewarning() - return super()._str_contains(pat, case, flags, na, regex) - else: - result = pc.match_substring_regex(self._data, pat) + result = pc.match_substring_regex(self._data, pat, ignore_case=not case) else: - if case: - result = pc.match_substring(self._data, pat) - else: - result = pc.match_substring(pc.utf8_upper(self._data), pat.upper()) + result = pc.match_substring(self._data, pat, ignore_case=not case) result = BooleanDtype().__from_arrow__(result) if not isna(na): result[isna(result)] = bool(na) return result + def _str_count(self, pat: str, flags: int = 0): + if flags: + fallback_performancewarning() + return super()._str_count(pat, flags) + result = pc.count_substring_regex(self._data, pat) + type_mapping = {pa.int32(): Int64Dtype(), pa.int64(): Int64Dtype()} + pd_result = result.to_pandas(types_mapper=type_mapping.get) + if pa_version_under8p0: + # Bug in pyarrow not respecting type_mapper + pd_result = pd_result.astype(Int64Dtype()) + return pd_result + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if start != 0 and end is not None: + slices = pc.utf8_slice_codeunits(self._data, start, stop=end) + result = pc.find_substring(slices, sub) + not_found = pc.equal(result, -1) + offset_result = pc.add(result, end - start) + result = pc.if_else(not_found, result, offset_result) + elif start == 0 and end is None: + slices = self._data + result = pc.find_substring(slices, sub) + else: + return super()._str_find(sub, start, end) + type_mapping = {pa.int32(): Int64Dtype(), pa.int64(): Int64Dtype()} + pd_result = result.to_pandas(types_mapper=type_mapping.get) + if pa_version_under8p0: + # Bug in pyarrow not respecting type_mapper + pd_result = pd_result.astype(Int64Dtype()) + return pd_result + def _str_startswith(self, pat: str, na=None): pat = f"^{re.escape(pat)}" return self._str_contains(pat, na=na, regex=True) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 92c36db397f60..0b6157c5f05ac 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -72,6 +72,7 @@ PeriodArray, TimedeltaArray, ) +from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.arrays.masked import ( BaseMaskedArray, BaseMaskedDtype, @@ -387,7 +388,9 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray: # All of the functions implemented here are ordinal, so we can # operate on the tz-naive equivalents npvalues = values._ndarray.view("M8[ns]") - elif isinstance(values.dtype, StringDtype): + elif isinstance(values.dtype, StringDtype) or ( + isinstance(values.dtype, ArrowDtype) and values.dtype.kind == "U" + ): # StringArray npvalues = values.to_numpy(object, na_value=np.nan) else: @@ -403,9 +406,11 @@ def _reconstruct_ea_result( """ Construct an ExtensionArray result from an ndarray result. """ - dtype: BaseMaskedDtype | StringDtype + dtype: BaseMaskedDtype | StringDtype | ArrowDtype - if isinstance(values.dtype, StringDtype): + if isinstance(values.dtype, StringDtype) or ( + isinstance(values.dtype, ArrowDtype) and values.dtype.kind == "U" + ): dtype = values.dtype string_array_cls = dtype.construct_array_type() return string_array_cls._from_sequence(res_values, dtype=dtype) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index fcbeaa3834d4b..bf92b6258bc72 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -169,7 +169,7 @@ get_group_index_sorter, nargsort, ) -from pandas.core.strings import StringMethods +from pandas.core.strings.accessor import StringMethods from pandas.io.formats.printing import ( PrettyDict, diff --git a/pandas/core/series.py b/pandas/core/series.py index 63420309f33fc..463e4824950b7 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -161,7 +161,7 @@ ensure_key_mapped, nargsort, ) -from pandas.core.strings import StringMethods +from pandas.core.strings.accessor import StringMethods from pandas.core.tools.datetimes import to_datetime import pandas.io.formats.format as fmt diff --git a/pandas/core/strings/__init__.py b/pandas/core/strings/__init__.py index 28aba7c9ce0b3..3f1d2615bbeba 100644 --- a/pandas/core/strings/__init__.py +++ b/pandas/core/strings/__init__.py @@ -26,8 +26,4 @@ # - PandasArray # - Categorical # - ArrowStringArray - -from pandas.core.strings.accessor import StringMethods -from pandas.core.strings.base import BaseStringArrayMethods - -__all__ = ["StringMethods", "BaseStringArrayMethods"] +# - ArrowExtensionArray diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 803dbe32bc16b..51818a7b17de3 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -175,11 +175,14 @@ class StringMethods(NoNewAttributesMixin): # * extractall def __init__(self, data) -> None: + from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.arrays.string_ import StringDtype self._inferred_dtype = self._validate(data) self._is_categorical = is_categorical_dtype(data.dtype) - self._is_string = isinstance(data.dtype, StringDtype) + self._is_string = isinstance(data.dtype, StringDtype) or ( + isinstance(data.dtype, ArrowDtype) and data.dtype.kind == "U" + ) self._data = data self._index = self._name = None @@ -286,7 +289,7 @@ def cons_row(x): # propagate nan values to match longest sequence (GH 18450) max_len = max(len(x) for x in result) result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result + x * max_len if len(x) == 0 or isna(x[0]) else x for x in result ] if not isinstance(expand, bool): @@ -3265,9 +3268,10 @@ def _result_dtype(arr): # workaround #27953 # ideally we just pass `dtype=arr.dtype` unconditionally, but this fails # when the list of values is empty. + from pandas.core.arrays.arrow import ArrowDtype from pandas.core.arrays.string_ import StringDtype - if isinstance(arr.dtype, StringDtype): + if isinstance(arr.dtype, (ArrowDtype, StringDtype)): return arr.dtype else: return object diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index d912ab59ab025..3dde59f95fc94 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -25,6 +25,7 @@ import numpy as np import pytest +from pandas._libs import lib from pandas.compat import ( PY311, is_ci_environment, @@ -67,6 +68,11 @@ def dtype(request): return ArrowDtype(pyarrow_dtype=request.param) +@pytest.fixture +def nullable_string_dtype(): + return ArrowDtype(pa.string()) + + @pytest.fixture def data(dtype): pa_dtype = dtype.pyarrow_dtype @@ -1484,6 +1490,108 @@ def test_astype_from_non_pyarrow(data): tm.assert_extension_array_equal(result, data) +@pytest.mark.filterwarnings("ignore:Falling back") +def test_string_methods(nullable_string_dtype, any_string_method): + method_name, args, kwargs = any_string_method + + data = ["a", "bb", np.nan, "ccc"] + a = pd.Series(data, dtype=object) + b = pd.Series(data, dtype=nullable_string_dtype) + + if method_name == "decode": + with pytest.raises(TypeError, match="a bytes-like object is required"): + getattr(b.str, method_name)(*args, **kwargs) + return + + expected = getattr(a.str, method_name)(*args, **kwargs) + result = getattr(b.str, method_name)(*args, **kwargs) + + if isinstance(expected, pd.Series): + if expected.dtype == "object" and lib.is_string_array( + expected.dropna().values, + ): + assert result.dtype == nullable_string_dtype + result = result.astype(object) + + elif expected.dtype == "object" and lib.is_bool_array( + expected.values, skipna=True + ): + assert result.dtype == ArrowDtype(pa.bool_()) + result = result.astype(object) + + elif expected.dtype == "bool": + assert result.dtype == ArrowDtype(pa.bool_()) + result = result.astype("bool") + + elif expected.dtype == "float" and expected.isna().any(): + assert isinstance(result.dtype, ArrowDtype) + assert pa.types.is_integer(result.dtype.pyarrow_dtype) + result = result.astype("Int64") + expected = expected.astype("Int64") + + elif pa.types.is_binary(result.dtype.pyarrow_dtype): + result = result.astype(object) + + elif pa.types.is_list(result.dtype.pyarrow_dtype): + result = result.astype(object) + + elif isinstance(expected, pd.DataFrame): + columns = expected.select_dtypes(include="object").columns + assert all(result[columns].dtypes == nullable_string_dtype) + result[columns] = result[columns].astype(object) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected,return_type", + [ + ("count", [2, None], "int32"), + ("find", [0, None], "int32"), + ("index", [0, None], "int64"), + ("rindex", [2, None], "int64"), + ], +) +def test_string_array_numeric_integer_return( + nullable_string_dtype, method, expected, return_type +): + s = pd.Series(["aba", None], dtype=nullable_string_dtype) + result = getattr(s.str, method)("a") + expected = pd.Series(expected, dtype=ArrowDtype(getattr(pa, return_type)())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ("isdigit", [False, None, True]), + ("isalpha", [True, None, False]), + ("isalnum", [True, None, True]), + ("isnumeric", [False, None, True]), + ], +) +def test_string_array_boolean_return(nullable_string_dtype, method, expected): + s = pd.Series(["a", None, "1"], dtype=nullable_string_dtype) + result = getattr(s.str, method)() + expected = pd.Series(expected, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +def test_string_array_extract(nullable_string_dtype): + # https://github.com/pandas-dev/pandas/issues/30969 + # Only expand=False & multiple groups was failing + + a = pd.Series(["a1", "b2", "cc"], dtype=nullable_string_dtype) + b = pd.Series(["a1", "b2", "cc"], dtype="object") + pat = r"(\w)(\d)" + + result = a.str.extract(pat, expand=False) + expected = b.str.extract(pat, expand=False) + assert all(result.dtypes == nullable_string_dtype) + + result = result.astype(object) + tm.assert_equal(result, expected) + + def test_astype_float_from_non_pyarrow_str(): # GH50430 ser = pd.Series(["1.0"]) diff --git a/pandas/tests/strings/conftest.py b/pandas/tests/strings/conftest.py index cdc2b876194e6..46df7c00a03a1 100644 --- a/pandas/tests/strings/conftest.py +++ b/pandas/tests/strings/conftest.py @@ -1,132 +1,6 @@ import numpy as np import pytest -from pandas import Series -from pandas.core import strings - -_any_string_method = [ - ("cat", (), {"sep": ","}), - ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}), - ("center", (10,), {}), - ("contains", ("a",), {}), - ("count", ("a",), {}), - ("decode", ("UTF-8",), {}), - ("encode", ("UTF-8",), {}), - ("endswith", ("a",), {}), - ("endswith", ("a",), {"na": True}), - ("endswith", ("a",), {"na": False}), - ("extract", ("([a-z]*)",), {"expand": False}), - ("extract", ("([a-z]*)",), {"expand": True}), - ("extractall", ("([a-z]*)",), {}), - ("find", ("a",), {}), - ("findall", ("a",), {}), - ("get", (0,), {}), - # because "index" (and "rindex") fail intentionally - # if the string is not found, search only for empty string - ("index", ("",), {}), - ("join", (",",), {}), - ("ljust", (10,), {}), - ("match", ("a",), {}), - ("fullmatch", ("a",), {}), - ("normalize", ("NFC",), {}), - ("pad", (10,), {}), - ("partition", (" ",), {"expand": False}), - ("partition", (" ",), {"expand": True}), - ("repeat", (3,), {}), - ("replace", ("a", "z"), {}), - ("rfind", ("a",), {}), - ("rindex", ("",), {}), - ("rjust", (10,), {}), - ("rpartition", (" ",), {"expand": False}), - ("rpartition", (" ",), {"expand": True}), - ("slice", (0, 1), {}), - ("slice_replace", (0, 1, "z"), {}), - ("split", (" ",), {"expand": False}), - ("split", (" ",), {"expand": True}), - ("startswith", ("a",), {}), - ("startswith", ("a",), {"na": True}), - ("startswith", ("a",), {"na": False}), - ("removeprefix", ("a",), {}), - ("removesuffix", ("a",), {}), - # translating unicode points of "a" to "d" - ("translate", ({97: 100},), {}), - ("wrap", (2,), {}), - ("zfill", (10,), {}), -] + list( - zip( - [ - # methods without positional arguments: zip with empty tuple and empty dict - "capitalize", - "cat", - "get_dummies", - "isalnum", - "isalpha", - "isdecimal", - "isdigit", - "islower", - "isnumeric", - "isspace", - "istitle", - "isupper", - "len", - "lower", - "lstrip", - "partition", - "rpartition", - "rsplit", - "rstrip", - "slice", - "slice_replace", - "split", - "strip", - "swapcase", - "title", - "upper", - "casefold", - ], - [()] * 100, - [{}] * 100, - ) -) -ids, _, _ = zip(*_any_string_method) # use method name as fixture-id -missing_methods = { - f for f in dir(strings.StringMethods) if not f.startswith("_") -} - set(ids) - -# test that the above list captures all methods of StringMethods -assert not missing_methods - - -@pytest.fixture(params=_any_string_method, ids=ids) -def any_string_method(request): - """ - Fixture for all public methods of `StringMethods` - - This fixture returns a tuple of the method name and sample arguments - necessary to call the method. - - Returns - ------- - method_name : str - The name of the method in `StringMethods` - args : tuple - Sample values for the positional arguments - kwargs : dict - Sample values for the keyword arguments - - Examples - -------- - >>> def test_something(any_string_method): - ... s = Series(['a', 'b', np.nan, 'd']) - ... - ... method_name, args, kwargs = any_string_method - ... method = getattr(s.str, method_name) - ... # will not raise - ... method(*args, **kwargs) - """ - return request.param - - # subset of the full set from pandas/conftest.py _any_allowed_skipna_inferred_dtype = [ ("string", ["a", np.nan, "c"]), diff --git a/pandas/tests/strings/test_api.py b/pandas/tests/strings/test_api.py index 7a6c7e69047bc..036eaf5f0cae4 100644 --- a/pandas/tests/strings/test_api.py +++ b/pandas/tests/strings/test_api.py @@ -8,14 +8,14 @@ _testing as tm, get_option, ) -from pandas.core import strings +from pandas.core.strings.accessor import StringMethods def test_api(any_string_dtype): # GH 6106, GH 9322 - assert Series.str is strings.StringMethods - assert isinstance(Series([""], dtype=any_string_dtype).str, strings.StringMethods) + assert Series.str is StringMethods + assert isinstance(Series([""], dtype=any_string_dtype).str, StringMethods) def test_api_mi_raises(): @@ -45,7 +45,7 @@ def test_api_per_dtype(index_or_series, dtype, any_skipna_inferred_dtype): ] if inferred_dtype in types_passing_constructor: # GH 6106 - assert isinstance(t.str, strings.StringMethods) + assert isinstance(t.str, StringMethods) else: # GH 9184, GH 23011, GH 23163 msg = "Can only use .str accessor with string values.*" @@ -138,7 +138,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype, request): s = Series(list("aabb"), dtype=any_string_dtype) s = s + " " + s c = s.astype("category") - assert isinstance(c.str, strings.StringMethods) + assert isinstance(c.str, StringMethods) method_name, args, kwargs = any_string_method diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index 6f6acb7a996b2..1fd5c00bf3764 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -53,10 +53,7 @@ def test_contains(any_string_dtype): np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object), dtype=any_string_dtype, ) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = values.str.contains("FOO|mmm", case=False) + result = values.str.contains("FOO|mmm", case=False) expected = Series(np.array([True, False, True, True]), dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -172,10 +169,7 @@ def test_contains_moar(any_string_dtype): ) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = s.str.contains("a", case=False) + result = s.str.contains("a", case=False) expected = Series( [True, False, False, True, True, False, np.nan, True, False, True], dtype=expected_dtype, @@ -196,10 +190,7 @@ def test_contains_moar(any_string_dtype): ) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = s.str.contains("ba", case=False) + result = s.str.contains("ba", case=False) expected = Series( [False, False, False, True, True, False, np.nan, True, False, False], dtype=expected_dtype, @@ -715,10 +706,7 @@ def test_match_na_kwarg(any_string_dtype): def test_match_case_kwarg(any_string_dtype): values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = values.str.match("ab", case=False) + result = values.str.match("ab", case=False) expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean" expected = Series([True, True, True, True], dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -761,10 +749,7 @@ def test_fullmatch_case_kwarg(any_string_dtype): expected = Series([True, True, False, False], dtype=expected_dtype) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = ser.str.fullmatch("ab", case=False) + result = ser.str.fullmatch("ab", case=False) tm.assert_series_equal(result, expected) with tm.maybe_produces_warning( @@ -842,7 +827,10 @@ def test_find(any_string_dtype): expected = np.array([v.rfind("EF") for v in np.array(ser)], dtype=np.int64) tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) - result = ser.str.find("EF", 3) + with tm.maybe_produces_warning( + PerformanceWarning, any_string_dtype == "string[pyarrow]" + ): + result = ser.str.find("EF", 3) expected = Series([4, 3, 7, 4, -1], dtype=expected_dtype) tm.assert_series_equal(result, expected) expected = np.array([v.find("EF", 3) for v in np.array(ser)], dtype=np.int64) @@ -890,7 +878,10 @@ def test_find_nan(any_string_dtype): expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) tm.assert_series_equal(result, expected) - result = ser.str.find("EF", 3) + with tm.maybe_produces_warning( + PerformanceWarning, any_string_dtype == "string[pyarrow]" + ): + result = ser.str.find("EF", 3) expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) tm.assert_series_equal(result, expected) diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index a9335e156d9db..4a296f197e39f 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -599,7 +599,7 @@ def test_normalize_index(): ], ) def test_index_str_accessor_visibility(values, inferred_type, index_or_series): - from pandas.core.strings import StringMethods + from pandas.core.strings.accessor import StringMethods obj = index_or_series(values) if index_or_series is Index: