From e3780b0026449f6b97a87d05365a6bb8d9dccbfe Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 6 Dec 2022 18:02:43 -0800 Subject: [PATCH 01/12] start adding string methods --- pandas/core/arrays/arrow/array.py | 219 +++++++++++++++++++++++++++++- pandas/core/indexes/base.py | 2 +- pandas/core/series.py | 2 +- pandas/core/strings/__init__.py | 8 +- 4 files changed, 224 insertions(+), 7 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 254ff8894b36c..8ecd1b263ea2e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1,14 +1,19 @@ from __future__ import annotations +import re from typing import ( TYPE_CHECKING, Any, + Callable, TypeVar, + Literal, + Sequence, cast, ) import numpy as np +from pandas._libs import missing as libmissing from pandas._typing import ( ArrayLike, Dtype, @@ -16,6 +21,7 @@ Iterator, PositionalIndexer, SortKind, + Scalar, TakeIndexer, npt, ) @@ -42,6 +48,7 @@ unpack_tuple_and_ellipses, validate_indices, ) +from pandas.core.strings.object_array import ObjectStringArrayMixin if not pa_version_under6p0: import pyarrow as pa @@ -133,7 +140,7 @@ def to_pyarrow_type( return pa_dtype -class ArrowExtensionArray(OpsMixin, ExtensionArray): +class ArrowExtensionArray(OpsMixin, ObjectStringArrayMixin, ExtensionArray): """ Pandas ExtensionArray backed by a PyArrow ChunkedArray. @@ -193,6 +200,7 @@ def __init__(self, values: pa.Array | pa.ChunkedArray) -> None: f"Unsupported type '{type(values)}' for ArrowExtensionArray" ) self._dtype = ArrowDtype(self._data.type) + assert self._dtype.na_value is self._str_na_value @classmethod def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): @@ -1090,3 +1098,212 @@ def _replace_with_indices( return pc.if_else(mask, None, chunk) return pc.replace_with_mask(chunk, mask, value) + + # ------------------------------------------------------------------------ + # String methods interface (Series.str.) + + # asserted to match self._dtype.na_value + _str_na_value = libmissing.NA + + def _str_count(self, pat: str, flags: int = 0): + if flags: + fallback_performancewarning() + return super()._str_count(pat, flags) + return type(self)(pc.count_substring_regex(self._data, pat)) + + def _str_pad( + self, + width: int, + side: Literal["left", "right", "both"] = "left", + fillchar: str = " ", + ): + if side == "left": + pa_pad = pc.utf8_lpad + elif side == "right": + pa_pad = pc.utf8_rpad + elif side == "both": + pa_pad = pc.utf8_center + else: + raise ValueError(f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'") + return type(self)(pa_pad(self._data, fillchar)) + + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): + # TODO: de-duplicate with StringArray method. This method is moreless copy and + # paste. + + from pandas.arrays import ( + BooleanArray, + IntegerArray, + ) + + if dtype is None: + dtype = self.dtype + if na_value is None: + na_value = self.dtype.na_value + + mask = isna(self) + arr = np.asarray(self) + + if is_integer_dtype(dtype) or is_bool_dtype(dtype): + constructor: type[IntegerArray] | type[BooleanArray] + if is_integer_dtype(dtype): + constructor = IntegerArray + else: + constructor = BooleanArray + + na_value_is_na = isna(na_value) + if na_value_is_na: + na_value = 1 + result = lib.map_infer_mask( + arr, + f, + mask.view("uint8"), + convert=False, + na_value=na_value, + # error: Argument 1 to "dtype" has incompatible type + # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected + # "Type[object]" + dtype=np.dtype(dtype), # type: ignore[arg-type] + ) + + if not na_value_is_na: + mask[:] = False + + return constructor(result, mask) + + elif is_string_dtype(dtype) and not is_object_dtype(dtype): + # i.e. StringDtype + result = lib.map_infer_mask( + arr, f, mask.view("uint8"), convert=False, na_value=na_value + ) + result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True) + return type(self)(result) + else: + # This is when the result type is object. We reach this when + # -> We know the result type is truly object (e.g. .encode returns bytes + # or .findall returns a list). + # -> We don't know the result type. E.g. `.get` can return anything. + return lib.map_infer_mask(arr, f, mask.view("uint8")) + + def _str_contains( + self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True + ): + if flags: + fallback_performancewarning() + return super()._str_contains(pat, case, flags, na, regex) + + if regex: + pa_contains = pc.match_substring_regex + else: + pa_contains = pc.match_substring + result = pa_contains(self._data, pat, ignore_case=not case) + if not isna(na): + result = result.fill_null(na) + return type(self)(result) + + def _str_startswith(self, pat: str, na=None): + return type(self)(pc.starts_with(self._data, pat)) + + def _str_endswith(self, pat: str, na=None): + return type(self)(pc.ends_with(self._data, pat)) + + def _str_replace( + self, + pat: str | re.Pattern, + repl: str | Callable, + n: int = -1, + case: bool = True, + flags: int = 0, + regex: bool = True, + ): + if isinstance(pat, re.Pattern) or callable(repl) or not case or flags: + fallback_performancewarning() + return super()._str_replace(pat, repl, n, case, flags, regex) + + func = pc.replace_substring_regex if regex else pc.replace_substring + result = func(self._data, pattern=pat, replacement=repl, max_replacements=n) + return type(self)(result) + + def _str_repeat(self, repeats: int | Sequence[int]): + if not isinstance(repeats, int): + fallback_performancewarning() + return super()._str_repeat(repeats) + elif pa_version_under7p0: + fallback_performancewarning("7") + return super()._str_repeat(repeats) + else: + return type(self)(pc.binary_repeat(self._data, repeats)) + + def _str_match( + self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.startswith("^"): + pat = f"^{pat}" + return self._str_contains(pat, case, flags, na, regex=True) + + def _str_fullmatch( + self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.endswith("$") or pat.endswith("//$"): + pat = f"{pat}$" + return self._str_match(pat, case, flags, na) + + def _str_isalnum(self): + return type(self)(pc.utf8_is_alnum(self._data)) + + def _str_isalpha(self): + return type(self)(pc.utf8_is_alpha(self._data)) + + def _str_isdecimal(self): + return type(self)(pc.utf8_is_decimal(self._data)) + + def _str_isdigit(self): + return type(self)(pc.utf8_is_digit(self._data)) + + def _str_islower(self): + return type(self)(pc.utf8_is_lower(self._data)) + + def _str_isnumeric(self): + return type(self)(pc.utf8_is_numeric(self._data)) + + def _str_isspace(self): + return type(self)(pc.utf8_is_space(self._data)) + + def _str_istitle(self): + return type(self)(pc.utf8_is_title(self._data)) + + def _str_isupper(self): + return type(self)(pc.utf8_is_upper(self._data)) + + def _str_len(self): + return type(self)(pc.utf8_length(self._data)) + + def _str_lower(self): + return type(self)(pc.utf8_lower(self._data)) + + def _str_upper(self): + return type(self)(pc.utf8_upper(self._data)) + + def _str_strip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_trim_whitespace(self._data) + else: + result = pc.utf8_trim(self._data, characters=to_strip) + return type(self)(result) + + def _str_lstrip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_ltrim_whitespace(self._data) + else: + result = pc.utf8_ltrim(self._data, characters=to_strip) + return type(self)(result) + + def _str_rstrip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_rtrim_whitespace(self._data) + else: + result = pc.utf8_rtrim(self._data, characters=to_strip) + return type(self)(result) + diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 8e41a7018c670..f50f0cbd4c32b 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -169,7 +169,7 @@ get_group_index_sorter, nargsort, ) -from pandas.core.strings import StringMethods +from pandas.core.strings.accessor import StringMethods from pandas.io.formats.printing import ( PrettyDict, diff --git a/pandas/core/series.py b/pandas/core/series.py index 1e5f565934b50..cfa51dc6916cb 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -148,7 +148,7 @@ ensure_key_mapped, nargsort, ) -from pandas.core.strings import StringMethods +from pandas.core.strings.accessor import StringMethods from pandas.core.tools.datetimes import to_datetime import pandas.io.formats.format as fmt diff --git a/pandas/core/strings/__init__.py b/pandas/core/strings/__init__.py index 28aba7c9ce0b3..9accf288de3ac 100644 --- a/pandas/core/strings/__init__.py +++ b/pandas/core/strings/__init__.py @@ -27,7 +27,7 @@ # - Categorical # - ArrowStringArray -from pandas.core.strings.accessor import StringMethods -from pandas.core.strings.base import BaseStringArrayMethods - -__all__ = ["StringMethods", "BaseStringArrayMethods"] +# from pandas.core.strings.accessor import StringMethods +# from pandas.core.strings.base import BaseStringArrayMethods +# +# __all__ = ["StringMethods", "BaseStringArrayMethods"] From fa836348875b0cd4e13613f726d79c2a80c38d2f Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 8 Dec 2022 19:20:45 -0800 Subject: [PATCH 02/12] Finish adding methods --- pandas/core/arrays/arrow/array.py | 72 +++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 8ecd1b263ea2e..f79c85f2a3a4f 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -5,9 +5,9 @@ TYPE_CHECKING, Any, Callable, - TypeVar, Literal, Sequence, + TypeVar, cast, ) @@ -20,8 +20,8 @@ FillnaOptions, Iterator, PositionalIndexer, - SortKind, Scalar, + SortKind, TakeIndexer, npt, ) @@ -1124,7 +1124,9 @@ def _str_pad( elif side == "both": pa_pad = pc.utf8_center else: - raise ValueError(f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'") + raise ValueError( + f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" + ) return type(self)(pa_pad(self._data, fillchar)) def _str_map( @@ -1250,6 +1252,50 @@ def _str_fullmatch( pat = f"{pat}$" return self._str_match(pat, case, flags, na) + def _str_encode(self, encoding: str, errors: str = "strict"): + if errors != "strict" or encoding.lower() != "utf-8": + fallback_performancewarning() + super()._str_encode(encoding, errors) + return type(self)(self._data.cast(pa.binary())) + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if start != 0 or end is not None: + slices = pc.utf8_slice_codeunits(self._data, start, stop=end) + else: + slices = self._data + return type(self)(pc.find_substring(slices, sub)) + + def _str_get(self, i: int): + lengths = pc.utf8_length(self._data) + if i > 0: + out_of_bounds = pc.greater_equal(lengths, i) + start = i + stop = i + 1 + step = 1 + else: + out_of_bounds = pc.greater(lengths, -i) + start = i + stop = i - 1 + step = -1 + selected = pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) + masked_selected = pc.replace_with_mask(selected, out_of_bounds, None) + return type(self)(masked_selected) + + def _str_join(self, sep: str): + return type(self)(pc.binary_join(self._data, sep)) + + def _str_slice( + self, start: int | None = None, stop: int | None = None, step: int | None = None + ): + return type(self)( + pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) + ) + + def _str_slice_replace(self, start=None, stop=None, repl=None): + if repl is None: + repl = "" + return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl)) + def _str_isalnum(self): return type(self)(pc.utf8_is_alnum(self._data)) @@ -1274,9 +1320,18 @@ def _str_isspace(self): def _str_istitle(self): return type(self)(pc.utf8_is_title(self._data)) + def _str_capitalize(self): + return type(self)(pc.utf8_capitalize(self._data)) + + def _str_title(self): + return type(self)(pc.utf8_title(self._data)) + def _str_isupper(self): return type(self)(pc.utf8_is_upper(self._data)) + def _str_swapcase(self): + return type(self)(pc.utf8_swapcase(self._data)) + def _str_len(self): return type(self)(pc.utf8_length(self._data)) @@ -1307,3 +1362,14 @@ def _str_rstrip(self, to_strip=None): result = pc.utf8_rtrim(self._data, characters=to_strip) return type(self)(result) + def _str_removeprefix(self, prefix: str) -> Series: + starts_with = pc.starts_with(self._data, prefix) + removed = pc.utf8_slice_codeunits(self._data, len(prefix)) + result = pc.replace_with_mask(self._data, starts_with, removed) + return type(self)(result) + + def _str_removesuffix(self, suffix: str) -> Series: + ends_with = pc.ends_with(self._data, suffix) + removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) + result = pc.replace_with_mask(self._data, ends_with, removed) + return type(self)(result) From f242d812d2e61bce17fc79e504332ccf84c3a4d1 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 13 Dec 2022 18:09:10 -0800 Subject: [PATCH 03/12] Add some tests that are failing --- pandas/conftest.py | 124 +++++++++++++++++++++++++++ pandas/core/arrays/arrow/array.py | 108 ++++++++++++----------- pandas/core/arrays/arrow/dtype.py | 3 + pandas/tests/extension/test_arrow.py | 97 +++++++++++++++++++++ pandas/tests/strings/conftest.py | 124 --------------------------- pandas/tests/strings/test_strings.py | 2 +- 6 files changed, 283 insertions(+), 175 deletions(-) diff --git a/pandas/conftest.py b/pandas/conftest.py index 0d6af91d32dea..7e6228369c8ad 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -66,6 +66,7 @@ ) import pandas._testing as tm from pandas.core import ops +from pandas.core.strings.accessor import StringMethods from pandas.core.indexes.api import ( Index, MultiIndex, @@ -1921,3 +1922,126 @@ def warsaw(request): tzinfo for Europe/Warsaw using pytz, dateutil, or zoneinfo. """ return request.param + + +_any_string_method = [ + ("cat", (), {"sep": ","}), + ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}), + ("center", (10,), {}), + ("contains", ("a",), {}), + ("count", ("a",), {}), + ("decode", ("UTF-8",), {}), + ("encode", ("UTF-8",), {}), + ("endswith", ("a",), {}), + ("endswith", ("a",), {"na": True}), + ("endswith", ("a",), {"na": False}), + ("extract", ("([a-z]*)",), {"expand": False}), + ("extract", ("([a-z]*)",), {"expand": True}), + ("extractall", ("([a-z]*)",), {}), + ("find", ("a",), {}), + ("findall", ("a",), {}), + ("get", (0,), {}), + # because "index" (and "rindex") fail intentionally + # if the string is not found, search only for empty string + ("index", ("",), {}), + ("join", (",",), {}), + ("ljust", (10,), {}), + ("match", ("a",), {}), + ("fullmatch", ("a",), {}), + ("normalize", ("NFC",), {}), + ("pad", (10,), {}), + ("partition", (" ",), {"expand": False}), + ("partition", (" ",), {"expand": True}), + ("repeat", (3,), {}), + ("replace", ("a", "z"), {}), + ("rfind", ("a",), {}), + ("rindex", ("",), {}), + ("rjust", (10,), {}), + ("rpartition", (" ",), {"expand": False}), + ("rpartition", (" ",), {"expand": True}), + ("slice", (0, 1), {}), + ("slice_replace", (0, 1, "z"), {}), + ("split", (" ",), {"expand": False}), + ("split", (" ",), {"expand": True}), + ("startswith", ("a",), {}), + ("startswith", ("a",), {"na": True}), + ("startswith", ("a",), {"na": False}), + ("removeprefix", ("a",), {}), + ("removesuffix", ("a",), {}), + # translating unicode points of "a" to "d" + ("translate", ({97: 100},), {}), + ("wrap", (2,), {}), + ("zfill", (10,), {}), +] + list( + zip( + [ + # methods without positional arguments: zip with empty tuple and empty dict + "capitalize", + "cat", + "get_dummies", + "isalnum", + "isalpha", + "isdecimal", + "isdigit", + "islower", + "isnumeric", + "isspace", + "istitle", + "isupper", + "len", + "lower", + "lstrip", + "partition", + "rpartition", + "rsplit", + "rstrip", + "slice", + "slice_replace", + "split", + "strip", + "swapcase", + "title", + "upper", + "casefold", + ], + [()] * 100, + [{}] * 100, + ) +) +ids, _, _ = zip(*_any_string_method) # use method name as fixture-id +missing_methods = { + f for f in dir(StringMethods) if not f.startswith("_") +} - set(ids) + +# test that the above list captures all methods of StringMethods +assert not missing_methods + + +@pytest.fixture(params=_any_string_method, ids=ids) +def any_string_method(request): + """ + Fixture for all public methods of `StringMethods` + + This fixture returns a tuple of the method name and sample arguments + necessary to call the method. + + Returns + ------- + method_name : str + The name of the method in `StringMethods` + args : tuple + Sample values for the positional arguments + kwargs : dict + Sample values for the keyword arguments + + Examples + -------- + >>> def test_something(any_string_method): + ... s = Series(['a', 'b', np.nan, 'd']) + ... + ... method_name, args, kwargs = any_string_method + ... method = getattr(s.str, method_name) + ... # will not raise + ... method(*args, **kwargs) + """ + return request.param diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index f79c85f2a3a4f..510fcff1bebf1 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -13,7 +13,10 @@ import numpy as np -from pandas._libs import missing as libmissing +from pandas._libs import ( + lib, + missing as libmissing, +) from pandas._typing import ( ArrayLike, Dtype, @@ -1134,60 +1137,65 @@ def _str_map( ): # TODO: de-duplicate with StringArray method. This method is moreless copy and # paste. - - from pandas.arrays import ( - BooleanArray, - IntegerArray, - ) - if dtype is None: dtype = self.dtype if na_value is None: na_value = self.dtype.na_value - mask = isna(self) + mask = self.isna() arr = np.asarray(self) - if is_integer_dtype(dtype) or is_bool_dtype(dtype): - constructor: type[IntegerArray] | type[BooleanArray] - if is_integer_dtype(dtype): - constructor = IntegerArray - else: - constructor = BooleanArray - - na_value_is_na = isna(na_value) - if na_value_is_na: - na_value = 1 - result = lib.map_infer_mask( - arr, - f, - mask.view("uint8"), - convert=False, - na_value=na_value, - # error: Argument 1 to "dtype" has incompatible type - # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected - # "Type[object]" - dtype=np.dtype(dtype), # type: ignore[arg-type] - ) - - if not na_value_is_na: - mask[:] = False - - return constructor(result, mask) - - elif is_string_dtype(dtype) and not is_object_dtype(dtype): - # i.e. StringDtype - result = lib.map_infer_mask( - arr, f, mask.view("uint8"), convert=False, na_value=na_value - ) - result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True) - return type(self)(result) - else: - # This is when the result type is object. We reach this when - # -> We know the result type is truly object (e.g. .encode returns bytes - # or .findall returns a list). - # -> We don't know the result type. E.g. `.get` can return anything. - return lib.map_infer_mask(arr, f, mask.view("uint8")) + np_result = lib.map_infer_mask(arr, f, mask.view("uint8"), convert=False, na_value=na_value) + try: + return type(self)(pa.array(np_result, mask=mask, from_pandas=True)) + except pa.ArrowInvalid: + return np_result + # if is_integer_dtype(dtype) or is_bool_dtype(dtype): + # + # from pandas.arrays import ( + # BooleanArray, + # IntegerArray, + # ) + # + # constructor: type[IntegerArray] | type[BooleanArray] + # if is_integer_dtype(dtype): + # constructor = IntegerArray + # else: + # constructor = BooleanArray + # + # na_value_is_na = isna(na_value) + # if na_value_is_na: + # na_value = 1 + # result = lib.map_infer_mask( + # arr, + # f, + # mask.view("uint8"), + # convert=False, + # na_value=na_value, + # # error: Argument 1 to "dtype" has incompatible type + # # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected + # # "Type[object]" + # dtype=np.dtype(dtype), # type: ignore[arg-type] + # ) + # + # if not na_value_is_na: + # mask[:] = False + # + # return constructor(result, mask) + # + # elif is_string_dtype(dtype) and not is_object_dtype(dtype): + # # i.e. StringDtype + # result = lib.map_infer_mask( + # arr, f, mask.view("uint8"), convert=False, na_value=na_value + # ) + # result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True) + # return type(self)(result) + # else: + # # This is when the result type is object. We reach this when + # # -> We know the result type is truly object (e.g. .encode returns bytes + # # or .findall returns a list). + # # -> We don't know the result type. E.g. `.get` can return anything. + # return lib.map_infer_mask(arr, f, mask.view("uint8")) def _str_contains( self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True @@ -1255,7 +1263,7 @@ def _str_fullmatch( def _str_encode(self, encoding: str, errors: str = "strict"): if errors != "strict" or encoding.lower() != "utf-8": fallback_performancewarning() - super()._str_encode(encoding, errors) + return super()._str_encode(encoding, errors) return type(self)(self._data.cast(pa.binary())) def _str_find(self, sub: str, start: int = 0, end: int | None = None): @@ -1291,7 +1299,7 @@ def _str_slice( pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) ) - def _str_slice_replace(self, start=None, stop=None, repl=None): + def _str_slice_replace(self, start: int | None = None, stop: int | None = None, repl: str | None = None): if repl is None: repl = "" return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl)) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index f5f87bea83b8f..a60afb6156e00 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -95,6 +95,9 @@ def name(self) -> str: # type: ignore[override] @cache_readonly def numpy_dtype(self) -> np.dtype: """Return an instance of the related numpy dtype""" + if pa.types.is_string(pa.string()): + # pa.string().to_pandas_dtype() = object which we don't want + return np.dtype(str) try: return np.dtype(self.pyarrow_dtype.to_pandas_dtype()) except (NotImplementedError, TypeError): diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 557cdd96bf00c..d41c6f5e64dca 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -56,6 +56,11 @@ def dtype(request): return ArrowDtype(pyarrow_dtype=request.param) +@pytest.fixture +def nullable_string_dtype(): + return ArrowDtype(pa.string()) + + @pytest.fixture def data(dtype): pa_dtype = dtype.pyarrow_dtype @@ -1421,3 +1426,95 @@ def test_astype_from_non_pyarrow(data): assert not isinstance(pd_array.dtype, ArrowDtype) assert isinstance(result.dtype, ArrowDtype) tm.assert_extension_array_equal(result, data) + + +@pytest.mark.filterwarnings("ignore:Falling back") +def test_string_array(nullable_string_dtype, any_string_method): + method_name, args, kwargs = any_string_method + + data = ["a", "bb", np.nan, "ccc"] + a = pd.Series(data, dtype=object) + b = pd.Series(data, dtype=nullable_string_dtype) + + if method_name == "decode": + with pytest.raises(TypeError, match="a bytes-like object is required"): + getattr(b.str, method_name)(*args, **kwargs) + return + + expected = getattr(a.str, method_name)(*args, **kwargs) + result = getattr(b.str, method_name)(*args, **kwargs) + + if isinstance(expected, pd.Series): + # if expected.dtype == "object" and lib.is_string_array( + # expected.dropna().values, + # ): + # assert result.dtype == nullable_string_dtype + # result = result.astype(object) + # + # elif expected.dtype == "object" and lib.is_bool_array( + # expected.values, skipna=True + # ): + # assert result.dtype == "boolean" + # result = result.astype(object) + + if expected.dtype == "bool": + assert result.dtype == "boolean" + result = result.astype("bool") + + elif expected.dtype == "float" and expected.isna().any(): + assert result.dtype == "Int64" + result = result.astype("float") + + elif isinstance(expected, pd.DataFrame): + columns = expected.select_dtypes(include="object").columns + assert all(result[columns].dtypes == nullable_string_dtype) + result[columns] = result[columns].astype(object) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ("count", [2, None]), + ("find", [0, None]), + ("index", [0, None]), + ("rindex", [2, None]), + ], +) +def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected): + s = pd.Series(["aba", None], dtype=nullable_string_dtype) + result = getattr(s.str, method)("a") + expected = pd.Series(expected, dtype="Int64") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ("isdigit", [False, None, True]), + ("isalpha", [True, None, False]), + ("isalnum", [True, None, True]), + ("isnumeric", [False, None, True]), + ], +) +def test_string_array_boolean_array(nullable_string_dtype, method, expected): + s = pd.Series(["a", None, "1"], dtype=nullable_string_dtype) + result = getattr(s.str, method)() + expected = pd.Series(expected, dtype="boolean") + tm.assert_series_equal(result, expected) + + +def test_string_array_extract(nullable_string_dtype): + # https://github.com/pandas-dev/pandas/issues/30969 + # Only expand=False & multiple groups was failing + + a = pd.Series(["a1", "b2", "cc"], dtype=nullable_string_dtype) + b = pd.Series(["a1", "b2", "cc"], dtype="object") + pat = r"(\w)(\d)" + + result = a.str.extract(pat, expand=False) + expected = b.str.extract(pat, expand=False) + assert all(result.dtypes == nullable_string_dtype) + + result = result.astype(object) + tm.assert_equal(result, expected) \ No newline at end of file diff --git a/pandas/tests/strings/conftest.py b/pandas/tests/strings/conftest.py index cdc2b876194e6..c1d58babcf9aa 100644 --- a/pandas/tests/strings/conftest.py +++ b/pandas/tests/strings/conftest.py @@ -2,130 +2,6 @@ import pytest from pandas import Series -from pandas.core import strings - -_any_string_method = [ - ("cat", (), {"sep": ","}), - ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}), - ("center", (10,), {}), - ("contains", ("a",), {}), - ("count", ("a",), {}), - ("decode", ("UTF-8",), {}), - ("encode", ("UTF-8",), {}), - ("endswith", ("a",), {}), - ("endswith", ("a",), {"na": True}), - ("endswith", ("a",), {"na": False}), - ("extract", ("([a-z]*)",), {"expand": False}), - ("extract", ("([a-z]*)",), {"expand": True}), - ("extractall", ("([a-z]*)",), {}), - ("find", ("a",), {}), - ("findall", ("a",), {}), - ("get", (0,), {}), - # because "index" (and "rindex") fail intentionally - # if the string is not found, search only for empty string - ("index", ("",), {}), - ("join", (",",), {}), - ("ljust", (10,), {}), - ("match", ("a",), {}), - ("fullmatch", ("a",), {}), - ("normalize", ("NFC",), {}), - ("pad", (10,), {}), - ("partition", (" ",), {"expand": False}), - ("partition", (" ",), {"expand": True}), - ("repeat", (3,), {}), - ("replace", ("a", "z"), {}), - ("rfind", ("a",), {}), - ("rindex", ("",), {}), - ("rjust", (10,), {}), - ("rpartition", (" ",), {"expand": False}), - ("rpartition", (" ",), {"expand": True}), - ("slice", (0, 1), {}), - ("slice_replace", (0, 1, "z"), {}), - ("split", (" ",), {"expand": False}), - ("split", (" ",), {"expand": True}), - ("startswith", ("a",), {}), - ("startswith", ("a",), {"na": True}), - ("startswith", ("a",), {"na": False}), - ("removeprefix", ("a",), {}), - ("removesuffix", ("a",), {}), - # translating unicode points of "a" to "d" - ("translate", ({97: 100},), {}), - ("wrap", (2,), {}), - ("zfill", (10,), {}), -] + list( - zip( - [ - # methods without positional arguments: zip with empty tuple and empty dict - "capitalize", - "cat", - "get_dummies", - "isalnum", - "isalpha", - "isdecimal", - "isdigit", - "islower", - "isnumeric", - "isspace", - "istitle", - "isupper", - "len", - "lower", - "lstrip", - "partition", - "rpartition", - "rsplit", - "rstrip", - "slice", - "slice_replace", - "split", - "strip", - "swapcase", - "title", - "upper", - "casefold", - ], - [()] * 100, - [{}] * 100, - ) -) -ids, _, _ = zip(*_any_string_method) # use method name as fixture-id -missing_methods = { - f for f in dir(strings.StringMethods) if not f.startswith("_") -} - set(ids) - -# test that the above list captures all methods of StringMethods -assert not missing_methods - - -@pytest.fixture(params=_any_string_method, ids=ids) -def any_string_method(request): - """ - Fixture for all public methods of `StringMethods` - - This fixture returns a tuple of the method name and sample arguments - necessary to call the method. - - Returns - ------- - method_name : str - The name of the method in `StringMethods` - args : tuple - Sample values for the positional arguments - kwargs : dict - Sample values for the keyword arguments - - Examples - -------- - >>> def test_something(any_string_method): - ... s = Series(['a', 'b', np.nan, 'd']) - ... - ... method_name, args, kwargs = any_string_method - ... method = getattr(s.str, method_name) - ... # will not raise - ... method(*args, **kwargs) - """ - return request.param - # subset of the full set from pandas/conftest.py _any_allowed_skipna_inferred_dtype = [ diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 4385f71dc653f..8fbff46451539 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -606,7 +606,7 @@ def test_normalize_index(): ], ) def test_index_str_accessor_visibility(values, inferred_type, index_or_series): - from pandas.core.strings import StringMethods + from pandas.core.strings.accessor import StringMethods obj = index_or_series(values) if index_or_series is Index: From def31cc44d9fff30488ac133c4e288a26ded3d8b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 14 Dec 2022 19:25:57 -0800 Subject: [PATCH 04/12] More test adjustment --- pandas/conftest.py | 22 +++++----- pandas/core/arrays/arrow/array.py | 35 +++++++++++----- pandas/core/strings/accessor.py | 3 +- pandas/tests/extension/test_arrow.py | 62 ++++++++++++++++------------ 4 files changed, 74 insertions(+), 48 deletions(-) diff --git a/pandas/conftest.py b/pandas/conftest.py index 7e6228369c8ad..526ccdfe2c7fd 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -1950,24 +1950,24 @@ def warsaw(request): ("fullmatch", ("a",), {}), ("normalize", ("NFC",), {}), ("pad", (10,), {}), - ("partition", (" ",), {"expand": False}), - ("partition", (" ",), {"expand": True}), + # ("partition", (" ",), {"expand": False}), + # ("partition", (" ",), {"expand": True}), ("repeat", (3,), {}), ("replace", ("a", "z"), {}), ("rfind", ("a",), {}), ("rindex", ("",), {}), ("rjust", (10,), {}), - ("rpartition", (" ",), {"expand": False}), - ("rpartition", (" ",), {"expand": True}), + # ("rpartition", (" ",), {"expand": False}), + # ("rpartition", (" ",), {"expand": True}), ("slice", (0, 1), {}), ("slice_replace", (0, 1, "z"), {}), ("split", (" ",), {"expand": False}), - ("split", (" ",), {"expand": True}), + # ("split", (" ",), {"expand": True}), ("startswith", ("a",), {}), ("startswith", ("a",), {"na": True}), ("startswith", ("a",), {"na": False}), - ("removeprefix", ("a",), {}), - ("removesuffix", ("a",), {}), + # ("removeprefix", ("a",), {}), + # ("removesuffix", ("a",), {}), # translating unicode points of "a" to "d" ("translate", ({97: 100},), {}), ("wrap", (2,), {}), @@ -1991,11 +1991,11 @@ def warsaw(request): "len", "lower", "lstrip", - "partition", - "rpartition", + # "partition", + # "rpartition", "rsplit", "rstrip", - "slice", + # "slice", "slice_replace", "split", "strip", @@ -2014,7 +2014,7 @@ def warsaw(request): } - set(ids) # test that the above list captures all methods of StringMethods -assert not missing_methods +#assert not missing_methods @pytest.fixture(params=_any_string_method, ids=ids) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 510fcff1bebf1..b826fcfb2296e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1130,7 +1130,7 @@ def _str_pad( raise ValueError( f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" ) - return type(self)(pa_pad(self._data, fillchar)) + return type(self)(pa_pad(self._data, width, padding=fillchar)) def _str_map( self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True @@ -1214,10 +1214,16 @@ def _str_contains( return type(self)(result) def _str_startswith(self, pat: str, na=None): - return type(self)(pc.starts_with(self._data, pat)) + result = pc.starts_with(self._data, pat) + if na is not None: + result = result.fill_null(na) + return type(self)(result) def _str_endswith(self, pat: str, na=None): - return type(self)(pc.ends_with(self._data, pat)) + result = pc.ends_with(self._data, pat) + if na is not None: + result = result.fill_null(na) + return type(self)(result) def _str_replace( self, @@ -1275,26 +1281,35 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): def _str_get(self, i: int): lengths = pc.utf8_length(self._data) - if i > 0: - out_of_bounds = pc.greater_equal(lengths, i) + if i >= 0: + out_of_bounds = pc.greater_equal(i, lengths) start = i stop = i + 1 step = 1 else: - out_of_bounds = pc.greater(lengths, -i) + out_of_bounds = pc.greater(-i, lengths) start = i stop = i - 1 step = -1 - selected = pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) - masked_selected = pc.replace_with_mask(selected, out_of_bounds, None) - return type(self)(masked_selected) + not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True).combine_chunks()) + selected = pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step).combine_chunks().drop_null() + result = pa.array([None] * self._data.length(), type=self._data.type) + result = pc.replace_with_mask(result, not_out_of_bounds, selected) + return type(self)(result) def _str_join(self, sep: str): - return type(self)(pc.binary_join(self._data, sep)) + if pa.types.is_list(self.dtype.pyarrow_dtype): + return type(self)(pc.binary_join(self._data, sep)) + else: + return super()._str_join(sep) def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None ): + if start is None: + start = 0 + if step is None: + step = 1 return type(self)( pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) ) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 05fbb68e1f19b..3d08b78acfc12 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3266,8 +3266,9 @@ def _result_dtype(arr): # ideally we just pass `dtype=arr.dtype` unconditionally, but this fails # when the list of values is empty. from pandas.core.arrays.string_ import StringDtype + from pandas.core.arrays.arrow import ArrowDtype - if isinstance(arr.dtype, StringDtype): + if isinstance(arr.dtype, (ArrowDtype, StringDtype)): return arr.dtype else: return object diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index d41c6f5e64dca..0995c8122e146 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -25,6 +25,8 @@ import numpy as np import pytest +from pandas._libs import lib + from pandas.compat import ( is_ci_environment, is_platform_windows, @@ -1429,7 +1431,7 @@ def test_astype_from_non_pyarrow(data): @pytest.mark.filterwarnings("ignore:Falling back") -def test_string_array(nullable_string_dtype, any_string_method): +def test_string_methods(nullable_string_dtype, any_string_method): method_name, args, kwargs = any_string_method data = ["a", "bb", np.nan, "ccc"] @@ -1445,25 +1447,33 @@ def test_string_array(nullable_string_dtype, any_string_method): result = getattr(b.str, method_name)(*args, **kwargs) if isinstance(expected, pd.Series): - # if expected.dtype == "object" and lib.is_string_array( - # expected.dropna().values, - # ): - # assert result.dtype == nullable_string_dtype - # result = result.astype(object) - # - # elif expected.dtype == "object" and lib.is_bool_array( - # expected.values, skipna=True - # ): - # assert result.dtype == "boolean" - # result = result.astype(object) - - if expected.dtype == "bool": - assert result.dtype == "boolean" + if expected.dtype == "object" and lib.is_string_array( + expected.dropna().values, + ): + assert result.dtype == nullable_string_dtype + result = result.astype(object) + + elif expected.dtype == "object" and lib.is_bool_array( + expected.values, skipna=True + ): + assert result.dtype == ArrowDtype(pa.bool_()) + result = result.astype(object) + + elif expected.dtype == "bool": + assert result.dtype == ArrowDtype(pa.bool_()) result = result.astype("bool") elif expected.dtype == "float" and expected.isna().any(): - assert result.dtype == "Int64" - result = result.astype("float") + assert isinstance(result.dtype, ArrowDtype) + assert pa.types.is_integer(result.dtype.pyarrow_dtype) + result = result.astype("Int64") + expected = expected.astype("Int64") + + elif pa.types.is_binary(result.dtype.pyarrow_dtype): + result = result.astype(object) + + elif pa.types.is_list(result.dtype.pyarrow_dtype): + result = result.astype(object) elif isinstance(expected, pd.DataFrame): columns = expected.select_dtypes(include="object").columns @@ -1473,18 +1483,18 @@ def test_string_array(nullable_string_dtype, any_string_method): @pytest.mark.parametrize( - "method,expected", + "method,expected,return_type", [ - ("count", [2, None]), - ("find", [0, None]), - ("index", [0, None]), - ("rindex", [2, None]), + ("count", [2, None], "int32"), + ("find", [0, None], "int32"), + ("index", [0, None], "int64"), + ("rindex", [2, None], "int64"), ], ) -def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected): +def test_string_array_numeric_integer_return(nullable_string_dtype, method, expected, return_type): s = pd.Series(["aba", None], dtype=nullable_string_dtype) result = getattr(s.str, method)("a") - expected = pd.Series(expected, dtype="Int64") + expected = pd.Series(expected, dtype=ArrowDtype(getattr(pa, return_type)())) tm.assert_series_equal(result, expected) @@ -1497,10 +1507,10 @@ def test_string_array_numeric_integer_array(nullable_string_dtype, method, expec ("isnumeric", [False, None, True]), ], ) -def test_string_array_boolean_array(nullable_string_dtype, method, expected): +def test_string_array_boolean_return(nullable_string_dtype, method, expected): s = pd.Series(["a", None, "1"], dtype=nullable_string_dtype) result = getattr(s.str, method)() - expected = pd.Series(expected, dtype="boolean") + expected = pd.Series(expected, dtype=ArrowDtype(pa.bool_())) tm.assert_series_equal(result, expected) From e6f5cebb0652e19fa6724e35b5bbf4a557326450 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 16 Dec 2022 19:31:09 -0800 Subject: [PATCH 05/12] All tests passing --- pandas/conftest.py | 26 ++++++++++------------ pandas/core/arrays/arrow/array.py | 37 ++++++++++++++++++++++++++----- pandas/core/strings/accessor.py | 9 +++++--- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/pandas/conftest.py b/pandas/conftest.py index 526ccdfe2c7fd..ac079ae300516 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -66,11 +66,11 @@ ) import pandas._testing as tm from pandas.core import ops -from pandas.core.strings.accessor import StringMethods from pandas.core.indexes.api import ( Index, MultiIndex, ) +from pandas.core.strings.accessor import StringMethods try: import pyarrow as pa @@ -1950,15 +1950,15 @@ def warsaw(request): ("fullmatch", ("a",), {}), ("normalize", ("NFC",), {}), ("pad", (10,), {}), - # ("partition", (" ",), {"expand": False}), - # ("partition", (" ",), {"expand": True}), + ("partition", (" ",), {"expand": False}), + ("partition", (" ",), {"expand": True}), ("repeat", (3,), {}), ("replace", ("a", "z"), {}), ("rfind", ("a",), {}), ("rindex", ("",), {}), ("rjust", (10,), {}), - # ("rpartition", (" ",), {"expand": False}), - # ("rpartition", (" ",), {"expand": True}), + ("rpartition", (" ",), {"expand": False}), + ("rpartition", (" ",), {"expand": True}), ("slice", (0, 1), {}), ("slice_replace", (0, 1, "z"), {}), ("split", (" ",), {"expand": False}), @@ -1966,8 +1966,8 @@ def warsaw(request): ("startswith", ("a",), {}), ("startswith", ("a",), {"na": True}), ("startswith", ("a",), {"na": False}), - # ("removeprefix", ("a",), {}), - # ("removesuffix", ("a",), {}), + ("removeprefix", ("a",), {}), + ("removesuffix", ("a",), {}), # translating unicode points of "a" to "d" ("translate", ({97: 100},), {}), ("wrap", (2,), {}), @@ -1991,11 +1991,11 @@ def warsaw(request): "len", "lower", "lstrip", - # "partition", - # "rpartition", + "partition", + "rpartition", "rsplit", "rstrip", - # "slice", + "slice", "slice_replace", "split", "strip", @@ -2009,12 +2009,10 @@ def warsaw(request): ) ) ids, _, _ = zip(*_any_string_method) # use method name as fixture-id -missing_methods = { - f for f in dir(StringMethods) if not f.startswith("_") -} - set(ids) +missing_methods = {f for f in dir(StringMethods) if not f.startswith("_")} - set(ids) # test that the above list captures all methods of StringMethods -#assert not missing_methods +assert not missing_methods @pytest.fixture(params=_any_string_method, ids=ids) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 7eb4a061659ed..b0717367755ae 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1342,9 +1342,27 @@ def _str_join(self, sep: str): else: return super()._str_join(sep) + def _str_partition(self, sep: str, expand: bool): + result = super()._str_partition(sep, expand) + if expand: + # StringMethods._wrap_result needs numpy-like nested object + result = result.to_numpy(dtype=object) + return result + + def _str_rpartition(self, sep: str, expand: bool): + result = super()._str_rpartition(sep, expand) + if expand: + # StringMethods._wrap_result needs numpy-like nested object + result = result.to_numpy(dtype=object) + return result + def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None ): + if stop is None: + # TODO: TODO: Should work once https://github.com/apache/arrow/issues/14991 + # is fixed + return super()._str_slice(start, stop, step) if start is None: start = 0 if step is None: @@ -1356,8 +1374,14 @@ def _str_slice( def _str_slice_replace( self, start: int | None = None, stop: int | None = None, repl: str | None = None ): + if stop is None: + # Not implemented as of pyarrow=10 + fallback_performancewarning() + return super()._str_slice_replace(start, stop, repl) if repl is None: repl = "" + if start is None: + start = 0 return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl)) def _str_isalnum(self): @@ -1426,14 +1450,15 @@ def _str_rstrip(self, to_strip=None): result = pc.utf8_rtrim(self._data, characters=to_strip) return type(self)(result) - def _str_removeprefix(self, prefix: str) -> Series: - starts_with = pc.starts_with(self._data, prefix) - removed = pc.utf8_slice_codeunits(self._data, len(prefix)) - result = pc.replace_with_mask(self._data, starts_with, removed) - return type(self)(result) + # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed + # def _str_removeprefix(self, prefix: str) -> Series: + # starts_with = pc.starts_with(self._data, prefix) + # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) + # result = pc.if_else(starts_with, removed, self._data) + # return type(self)(result) def _str_removesuffix(self, suffix: str) -> Series: ends_with = pc.ends_with(self._data, suffix) removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) - result = pc.replace_with_mask(self._data, ends_with, removed) + result = pc.if_else(ends_with, removed, self._data) return type(self)(result) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 3d08b78acfc12..78213d33efbe6 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -175,11 +175,14 @@ class StringMethods(NoNewAttributesMixin): # * extractall def __init__(self, data) -> None: + from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.arrays.string_ import StringDtype self._inferred_dtype = self._validate(data) self._is_categorical = is_categorical_dtype(data.dtype) - self._is_string = isinstance(data.dtype, StringDtype) + self._is_string = isinstance(data.dtype, StringDtype) or ( + isinstance(data.dtype, ArrowDtype) and data.dtype.kind == "U" + ) self._data = data self._index = self._name = None @@ -286,7 +289,7 @@ def cons_row(x): # propagate nan values to match longest sequence (GH 18450) max_len = max(len(x) for x in result) result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result + x * max_len if len(x) == 0 or isna(x[0]) else x for x in result ] if not isinstance(expand, bool): @@ -3265,8 +3268,8 @@ def _result_dtype(arr): # workaround #27953 # ideally we just pass `dtype=arr.dtype` unconditionally, but this fails # when the list of values is empty. - from pandas.core.arrays.string_ import StringDtype from pandas.core.arrays.arrow import ArrowDtype + from pandas.core.arrays.string_ import StringDtype if isinstance(arr.dtype, (ArrowDtype, StringDtype)): return arr.dtype From adb70b8c1e6cbe7881b07f8add402287ccd4babe Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Sat, 17 Dec 2022 13:07:06 -0800 Subject: [PATCH 06/12] Add whatsnew --- doc/source/whatsnew/v2.0.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index c58ee3818cbd9..a03a63a2e0c50 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -95,7 +95,7 @@ Other enhancements - Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`) - Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`) - :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`) -- +- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string()`` (:issue:`?`) .. --------------------------------------------------------------------------- .. _whatsnew_200.notable_bug_fixes: From 0af8223268fa7d317657919eac895e5bac364dc2 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Sat, 17 Dec 2022 13:10:11 -0800 Subject: [PATCH 07/12] Remove commented out code --- pandas/core/arrays/arrow/array.py | 50 ------------------------------- pandas/core/strings/__init__.py | 6 +--- 2 files changed, 1 insertion(+), 55 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index ec80550b426cc..f9b5673a94c61 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1258,10 +1258,6 @@ def _str_pad( def _str_map( self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True ): - # TODO: de-duplicate with StringArray method. This method is moreless copy and - # paste. - if dtype is None: - dtype = self.dtype if na_value is None: na_value = self.dtype.na_value @@ -1275,52 +1271,6 @@ def _str_map( return type(self)(pa.array(np_result, mask=mask, from_pandas=True)) except pa.ArrowInvalid: return np_result - # if is_integer_dtype(dtype) or is_bool_dtype(dtype): - # - # from pandas.arrays import ( - # BooleanArray, - # IntegerArray, - # ) - # - # constructor: type[IntegerArray] | type[BooleanArray] - # if is_integer_dtype(dtype): - # constructor = IntegerArray - # else: - # constructor = BooleanArray - # - # na_value_is_na = isna(na_value) - # if na_value_is_na: - # na_value = 1 - # result = lib.map_infer_mask( - # arr, - # f, - # mask.view("uint8"), - # convert=False, - # na_value=na_value, - # # error: Argument 1 to "dtype" has incompatible type - # # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected - # # "Type[object]" - # dtype=np.dtype(dtype), # type: ignore[arg-type] - # ) - # - # if not na_value_is_na: - # mask[:] = False - # - # return constructor(result, mask) - # - # elif is_string_dtype(dtype) and not is_object_dtype(dtype): - # # i.e. StringDtype - # result = lib.map_infer_mask( - # arr, f, mask.view("uint8"), convert=False, na_value=na_value - # ) - # result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True) - # return type(self)(result) - # else: - # # This is when the result type is object. We reach this when - # # -> We know the result type is truly object (e.g. .encode returns bytes - # # or .findall returns a list). - # # -> We don't know the result type. E.g. `.get` can return anything. - # return lib.map_infer_mask(arr, f, mask.view("uint8")) def _str_contains( self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True diff --git a/pandas/core/strings/__init__.py b/pandas/core/strings/__init__.py index 9accf288de3ac..3f1d2615bbeba 100644 --- a/pandas/core/strings/__init__.py +++ b/pandas/core/strings/__init__.py @@ -26,8 +26,4 @@ # - PandasArray # - Categorical # - ArrowStringArray - -# from pandas.core.strings.accessor import StringMethods -# from pandas.core.strings.base import BaseStringArrayMethods -# -# __all__ = ["StringMethods", "BaseStringArrayMethods"] +# - ArrowExtensionArray From 5b22ccdfde70046d941e11260016a49844fa82b3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Sat, 17 Dec 2022 13:18:21 -0800 Subject: [PATCH 08/12] Add whatsnew note --- doc/source/whatsnew/v2.0.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index a03a63a2e0c50..347309739abb8 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -95,7 +95,7 @@ Other enhancements - Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`) - Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`) - :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`) -- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string()`` (:issue:`?`) +- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string()`` (:issue:`50325`) .. --------------------------------------------------------------------------- .. _whatsnew_200.notable_bug_fixes: From 19ae7b5794de86cb83de3a9b79681b4561965398 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 19 Dec 2022 12:46:51 -0800 Subject: [PATCH 09/12] Fix bug --- doc/source/whatsnew/v2.0.0.rst | 2 +- pandas/core/arrays/arrow/dtype.py | 2 +- pandas/tests/strings/test_api.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index b9e00451f241c..a69af0538ae06 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -98,7 +98,7 @@ Other enhancements - Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`) - Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`) - :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`) -- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string()`` (:issue:`50325`) +- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string())`` (:issue:`50325`) .. --------------------------------------------------------------------------- .. _whatsnew_200.notable_bug_fixes: diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index a60afb6156e00..3e3213b48670f 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -95,7 +95,7 @@ def name(self) -> str: # type: ignore[override] @cache_readonly def numpy_dtype(self) -> np.dtype: """Return an instance of the related numpy dtype""" - if pa.types.is_string(pa.string()): + if pa.types.is_string(self.pyarrow_dtype): # pa.string().to_pandas_dtype() = object which we don't want return np.dtype(str) try: diff --git a/pandas/tests/strings/test_api.py b/pandas/tests/strings/test_api.py index 7a6c7e69047bc..036eaf5f0cae4 100644 --- a/pandas/tests/strings/test_api.py +++ b/pandas/tests/strings/test_api.py @@ -8,14 +8,14 @@ _testing as tm, get_option, ) -from pandas.core import strings +from pandas.core.strings.accessor import StringMethods def test_api(any_string_dtype): # GH 6106, GH 9322 - assert Series.str is strings.StringMethods - assert isinstance(Series([""], dtype=any_string_dtype).str, strings.StringMethods) + assert Series.str is StringMethods + assert isinstance(Series([""], dtype=any_string_dtype).str, StringMethods) def test_api_mi_raises(): @@ -45,7 +45,7 @@ def test_api_per_dtype(index_or_series, dtype, any_skipna_inferred_dtype): ] if inferred_dtype in types_passing_constructor: # GH 6106 - assert isinstance(t.str, strings.StringMethods) + assert isinstance(t.str, StringMethods) else: # GH 9184, GH 23011, GH 23163 msg = "Can only use .str accessor with string values.*" @@ -138,7 +138,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype, request): s = Series(list("aabb"), dtype=any_string_dtype) s = s + " " + s c = s.astype("category") - assert isinstance(c.str, strings.StringMethods) + assert isinstance(c.str, StringMethods) method_name, args, kwargs = any_string_method From 40fbe97a2241ae3d5c95a81bedffd3aa8767a0a2 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 20 Dec 2022 15:56:12 -0800 Subject: [PATCH 10/12] Fix methods for inherited ArrowStringARray --- pandas/core/arrays/arrow/array.py | 43 ++++++++++++----------- pandas/core/arrays/string_arrow.py | 34 +++++++++++++----- pandas/core/groupby/ops.py | 11 ++++-- pandas/tests/extension/test_arrow.py | 19 ++++++++-- pandas/tests/strings/test_find_replace.py | 35 +++++++----------- 5 files changed, 85 insertions(+), 57 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index f9b5673a94c61..ca73cbf72610c 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1248,7 +1248,9 @@ def _str_pad( elif side == "right": pa_pad = pc.utf8_rpad elif side == "both": - pa_pad = pc.utf8_center + # https://github.com/apache/arrow/issues/15053 + # pa_pad = pc.utf8_center + return super()._str_pad(width, side, fillchar) else: raise ValueError( f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" @@ -1341,18 +1343,20 @@ def _str_fullmatch( pat = f"{pat}$" return self._str_match(pat, case, flags, na) - def _str_encode(self, encoding: str, errors: str = "strict"): - if errors != "strict" or encoding.lower() != "utf-8": - fallback_performancewarning() - return super()._str_encode(encoding, errors) - return type(self)(self._data.cast(pa.binary())) - def _str_find(self, sub: str, start: int = 0, end: int | None = None): - if start != 0 or end is not None: + if start != 0 and end is not None: slices = pc.utf8_slice_codeunits(self._data, start, stop=end) - else: + result = pc.find_substring(slices, sub) + not_found = pc.equal(result, -1) + offset_result = pc.add(result, end - start) + result = pc.if_else(not_found, result, offset_result) + elif start == 0 and end is None: slices = self._data - return type(self)(pc.find_substring(slices, sub)) + result = pc.find_substring(slices, sub) + else: + fallback_performancewarning() + return super()._str_find(sub, start, end) + return type(self)(result) def _str_get(self, i: int): lengths = pc.utf8_length(self._data) @@ -1366,32 +1370,31 @@ def _str_get(self, i: int): start = i stop = i - 1 step = -1 - not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True).combine_chunks()) - selected = ( - pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) - .combine_chunks() - .drop_null() - ) + not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True)) + selected = pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) result = pa.array([None] * self._data.length(), type=self._data.type) - result = pc.replace_with_mask(result, not_out_of_bounds, selected) + result = pc.if_else(not_out_of_bounds, selected, result) return type(self)(result) def _str_join(self, sep: str): - if pa.types.is_list(self.dtype.pyarrow_dtype): + if isinstance(self._dtype, ArrowDtype) and pa.types.is_list( + self.dtype.pyarrow_dtype + ): + # Check ArrowDtype as ArrowString inherits and uses StringDtype return type(self)(pc.binary_join(self._data, sep)) else: return super()._str_join(sep) def _str_partition(self, sep: str, expand: bool): result = super()._str_partition(sep, expand) - if expand: + if expand and isinstance(result, type(self)): # StringMethods._wrap_result needs numpy-like nested object result = result.to_numpy(dtype=object) return result def _str_rpartition(self, sep: str, expand: bool): result = super()._str_rpartition(sep, expand) - if expand: + if expand and isinstance(result, type(self)): # StringMethods._wrap_result needs numpy-like nested object result = result.to_numpy(dtype=object) return result diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 97262b1f4bb21..0a901f58f396c 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -276,21 +276,37 @@ def _str_contains( return super()._str_contains(pat, case, flags, na, regex) if regex: - if case is False: - fallback_performancewarning() - return super()._str_contains(pat, case, flags, na, regex) - else: - result = pc.match_substring_regex(self._data, pat) + result = pc.match_substring_regex(self._data, pat, ignore_case=not case) else: - if case: - result = pc.match_substring(self._data, pat) - else: - result = pc.match_substring(pc.utf8_upper(self._data), pat.upper()) + result = pc.match_substring(self._data, pat, ignore_case=not case) result = BooleanDtype().__from_arrow__(result) if not isna(na): result[isna(result)] = bool(na) return result + def _str_count(self, pat: str, flags: int = 0): + if flags: + fallback_performancewarning() + return super()._str_count(pat, flags) + result = pc.count_substring_regex(self._data, pat) + type_mapping = {pa.int32(): Int64Dtype(), pa.int64(): Int64Dtype()} + return result.to_pandas(types_mapper=type_mapping.get) + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if start != 0 and end is not None: + slices = pc.utf8_slice_codeunits(self._data, start, stop=end) + result = pc.find_substring(slices, sub) + not_found = pc.equal(result, -1) + offset_result = pc.add(result, end - start) + result = pc.if_else(not_found, result, offset_result) + elif start == 0 and end is None: + slices = self._data + result = pc.find_substring(slices, sub) + else: + return super()._str_find(sub, start, end) + type_mapping = {pa.int32(): Int64Dtype(), pa.int64(): Int64Dtype()} + return result.to_pandas(types_mapper=type_mapping.get) + def _str_startswith(self, pat: str, na=None): pat = f"^{re.escape(pat)}" return self._str_contains(pat, na=na, regex=True) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index c20fe34a178f5..27a4e47da9dbc 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -72,6 +72,7 @@ PeriodArray, TimedeltaArray, ) +from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.arrays.masked import ( BaseMaskedArray, BaseMaskedDtype, @@ -385,7 +386,9 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray: # All of the functions implemented here are ordinal, so we can # operate on the tz-naive equivalents npvalues = values._ndarray.view("M8[ns]") - elif isinstance(values.dtype, StringDtype): + elif isinstance(values.dtype, StringDtype) or ( + isinstance(values.dtype, ArrowDtype) and values.dtype.kind == "U" + ): # StringArray npvalues = values.to_numpy(object, na_value=np.nan) else: @@ -401,9 +404,11 @@ def _reconstruct_ea_result( """ Construct an ExtensionArray result from an ndarray result. """ - dtype: BaseMaskedDtype | StringDtype + dtype: BaseMaskedDtype | StringDtype | ArrowDtype - if isinstance(values.dtype, StringDtype): + if isinstance(values.dtype, StringDtype) or ( + isinstance(values.dtype, ArrowDtype) and values.dtype.kind == "U" + ): dtype = values.dtype string_array_cls = dtype.construct_array_type() return string_array_cls._from_sequence(res_values, dtype=dtype) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index c76896e58d1ce..1d8aca2105211 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -38,7 +38,10 @@ import pandas as pd import pandas._testing as tm -from pandas.api.types import is_bool_dtype +from pandas.api.types import ( + is_bool_dtype, + is_string_dtype, +) from pandas.tests.extension import base pa = pytest.importorskip("pyarrow", minversion="1.0.1") @@ -510,7 +513,11 @@ def test_groupby_extension_apply( def test_in_numeric_groupby(self, data_for_grouping, request): pa_dtype = data_for_grouping.dtype.pyarrow_dtype - if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_string(pa_dtype) + ): request.node.add_marker( pytest.mark.xfail( reason="ArrowExtensionArray doesn't support .sum() yet.", @@ -619,7 +626,6 @@ def test_get_common_dtype(self, dtype, request): and (pa_dtype.unit != "ns" or pa_dtype.tz is not None) ) or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns") - or pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) ): request.node.add_marker( @@ -632,6 +638,13 @@ def test_get_common_dtype(self, dtype, request): ) super().test_get_common_dtype(dtype) + def test_is_not_string_type(self, dtype): + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_string(pa_dtype): + assert is_string_dtype(dtype) + else: + super().test_is_not_string_type(dtype) + class TestBaseIndex(base.BaseIndexTests): pass diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index 6f6acb7a996b2..1fd5c00bf3764 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -53,10 +53,7 @@ def test_contains(any_string_dtype): np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object), dtype=any_string_dtype, ) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = values.str.contains("FOO|mmm", case=False) + result = values.str.contains("FOO|mmm", case=False) expected = Series(np.array([True, False, True, True]), dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -172,10 +169,7 @@ def test_contains_moar(any_string_dtype): ) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = s.str.contains("a", case=False) + result = s.str.contains("a", case=False) expected = Series( [True, False, False, True, True, False, np.nan, True, False, True], dtype=expected_dtype, @@ -196,10 +190,7 @@ def test_contains_moar(any_string_dtype): ) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = s.str.contains("ba", case=False) + result = s.str.contains("ba", case=False) expected = Series( [False, False, False, True, True, False, np.nan, True, False, False], dtype=expected_dtype, @@ -715,10 +706,7 @@ def test_match_na_kwarg(any_string_dtype): def test_match_case_kwarg(any_string_dtype): values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = values.str.match("ab", case=False) + result = values.str.match("ab", case=False) expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean" expected = Series([True, True, True, True], dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -761,10 +749,7 @@ def test_fullmatch_case_kwarg(any_string_dtype): expected = Series([True, True, False, False], dtype=expected_dtype) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): - result = ser.str.fullmatch("ab", case=False) + result = ser.str.fullmatch("ab", case=False) tm.assert_series_equal(result, expected) with tm.maybe_produces_warning( @@ -842,7 +827,10 @@ def test_find(any_string_dtype): expected = np.array([v.rfind("EF") for v in np.array(ser)], dtype=np.int64) tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) - result = ser.str.find("EF", 3) + with tm.maybe_produces_warning( + PerformanceWarning, any_string_dtype == "string[pyarrow]" + ): + result = ser.str.find("EF", 3) expected = Series([4, 3, 7, 4, -1], dtype=expected_dtype) tm.assert_series_equal(result, expected) expected = np.array([v.find("EF", 3) for v in np.array(ser)], dtype=np.int64) @@ -890,7 +878,10 @@ def test_find_nan(any_string_dtype): expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) tm.assert_series_equal(result, expected) - result = ser.str.find("EF", 3) + with tm.maybe_produces_warning( + PerformanceWarning, any_string_dtype == "string[pyarrow]" + ): + result = ser.str.find("EF", 3) expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) tm.assert_series_equal(result, expected) From 6c0009560af3456b75898df8429dee77a5e00c5a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 21 Dec 2022 11:42:36 -0800 Subject: [PATCH 11/12] Fix some library compat --- pandas/core/arrays/arrow/array.py | 16 +++++++++------- pandas/core/arrays/string_arrow.py | 17 ++++++++++++++--- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index ca73cbf72610c..f6e8d41b13c24 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1255,7 +1255,7 @@ def _str_pad( raise ValueError( f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" ) - return type(self)(pa_pad(self._data, width, padding=fillchar)) + return type(self)(pa_pad(self._data, width=width, padding=fillchar)) def _str_map( self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True @@ -1291,13 +1291,13 @@ def _str_contains( return type(self)(result) def _str_startswith(self, pat: str, na=None): - result = pc.starts_with(self._data, pat) + result = pc.starts_with(self._data, pattern=pat) if na is not None: result = result.fill_null(na) return type(self)(result) def _str_endswith(self, pat: str, na=None): - result = pc.ends_with(self._data, pat) + result = pc.ends_with(self._data, pattern=pat) if na is not None: result = result.fill_null(na) return type(self)(result) @@ -1371,7 +1371,9 @@ def _str_get(self, i: int): stop = i - 1 step = -1 not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True)) - selected = pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) + selected = pc.utf8_slice_codeunits( + self._data, start=start, stop=stop, step=step + ) result = pa.array([None] * self._data.length(), type=self._data.type) result = pc.if_else(not_out_of_bounds, selected, result) return type(self)(result) @@ -1411,7 +1413,7 @@ def _str_slice( if step is None: step = 1 return type(self)( - pc.utf8_slice_codeunits(self._data, start, stop=stop, step=step) + pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step) ) def _str_slice_replace( @@ -1495,13 +1497,13 @@ def _str_rstrip(self, to_strip=None): # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed # def _str_removeprefix(self, prefix: str) -> Series: - # starts_with = pc.starts_with(self._data, prefix) + # starts_with = pc.starts_with(self._data, pattern=prefix) # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) # result = pc.if_else(starts_with, removed, self._data) # return type(self)(result) def _str_removesuffix(self, suffix: str) -> Series: - ends_with = pc.ends_with(self._data, suffix) + ends_with = pc.ends_with(self._data, pattern=suffix) removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) result = pc.if_else(ends_with, removed, self._data) return type(self)(result) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 0a901f58f396c..58a793b222c68 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -15,7 +15,10 @@ Scalar, npt, ) -from pandas.compat import pa_version_under6p0 +from pandas.compat import ( + pa_version_under6p0, + pa_version_under8p0, +) from pandas.core.dtypes.common import ( is_bool_dtype, @@ -290,7 +293,11 @@ def _str_count(self, pat: str, flags: int = 0): return super()._str_count(pat, flags) result = pc.count_substring_regex(self._data, pat) type_mapping = {pa.int32(): Int64Dtype(), pa.int64(): Int64Dtype()} - return result.to_pandas(types_mapper=type_mapping.get) + pd_result = result.to_pandas(types_mapper=type_mapping.get) + if pa_version_under8p0: + # Bug in pyarrow not respecting type_mapper + pd_result = pd_result.astype(Int64Dtype()) + return pd_result def _str_find(self, sub: str, start: int = 0, end: int | None = None): if start != 0 and end is not None: @@ -305,7 +312,11 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): else: return super()._str_find(sub, start, end) type_mapping = {pa.int32(): Int64Dtype(), pa.int64(): Int64Dtype()} - return result.to_pandas(types_mapper=type_mapping.get) + pd_result = result.to_pandas(types_mapper=type_mapping.get) + if pa_version_under8p0: + # Bug in pyarrow not respecting type_mapper + pd_result = pd_result.astype(Int64Dtype()) + return pd_result def _str_startswith(self, pat: str, na=None): pat = f"^{re.escape(pat)}" From ffee70a1ffad30deefe40d332f9966f439bc69bd Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 28 Dec 2022 11:47:30 -0800 Subject: [PATCH 12/12] Fix one typing --- pandas/core/arrays/arrow/array.py | 2 +- pandas/tests/strings/conftest.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index a9414a6711678..280e367e3263c 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -208,7 +208,7 @@ def __init__(self, values: pa.Array | pa.ChunkedArray) -> None: f"Unsupported type '{type(values)}' for ArrowExtensionArray" ) self._dtype = ArrowDtype(self._data.type) - assert self._dtype.na_value is self._str_na_value + assert self._dtype.na_value is self._str_na_value # type: ignore[has-type] @classmethod def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): diff --git a/pandas/tests/strings/conftest.py b/pandas/tests/strings/conftest.py index c1d58babcf9aa..46df7c00a03a1 100644 --- a/pandas/tests/strings/conftest.py +++ b/pandas/tests/strings/conftest.py @@ -1,8 +1,6 @@ import numpy as np import pytest -from pandas import Series - # subset of the full set from pandas/conftest.py _any_allowed_skipna_inferred_dtype = [ ("string", ["a", np.nan, "c"]),