diff --git a/asv_bench/benchmarks/strings.py b/asv_bench/benchmarks/strings.py index 2e109e59c1c6d..700393cc72492 100644 --- a/asv_bench/benchmarks/strings.py +++ b/asv_bench/benchmarks/strings.py @@ -235,13 +235,31 @@ class Split(Dtypes): def setup(self, dtype, expand): super().setup(dtype) - self.s = self.s.str.join("--") + self.s = self.s.str.join("-") def time_split(self, dtype, expand): - self.s.str.split("--", expand=expand) + self.s.str.split("-", expand=expand) def time_rsplit(self, dtype, expand): - self.s.str.rsplit("--", expand=expand) + self.s.str.rsplit("-", expand=expand) + + +class SplitPattern(Dtypes): + + params = (Dtypes.params, [None, "--"]) + param_names = ["dtype", "pat"] + + def setup(self, dtype, pat): + super().setup(dtype) + if pat is None: + pat = " " + self.s = self.s.str.join(pat) + + def time_split(self, dtype, pat): + self.s.str.split(pat) + + def time_rsplit(self, dtype, pat): + self.s.str.rsplit(pat) class Extract(Dtypes): diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index d5ee28eb7017e..3b20df5fdf82b 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -904,6 +904,29 @@ def _str_lower(self): def _str_upper(self): return type(self)(pc.utf8_upper(self._data)) + def _str_split(self, pat=None, n=-1, expand=False): + if pa_version_under3p0 or (pat is not None and len(pat) > 1): + return super()._str_split(pat=pat, n=n, expand=expand) + + if n is None or n == 0: + n = -1 + + if pat is None: + result = pc.utf8_split_whitespace(self._data, max_splits=n) + else: + result = pc.split_pattern(self._data, pattern=pat, max_splits=n) + + mask = np.array(result.is_null()) + result = np.array(result) + result = lib.map_infer_mask( + result, + lambda x: x.tolist(), + mask.view(np.uint8), + na_value=self.dtype.na_value, + dtype=np.dtype(object), + ) + return result + def _str_strip(self, to_strip=None): if pa_version_under4p0: return super()._str_strip(to_strip) diff --git a/pandas/tests/strings/test_split_partition.py b/pandas/tests/strings/test_split_partition.py index f3f5acd0d2f1c..b80656f52dbf0 100644 --- a/pandas/tests/strings/test_split_partition.py +++ b/pandas/tests/strings/test_split_partition.py @@ -144,11 +144,14 @@ def test_split_blank_string(any_string_dtype): def test_split_noargs(any_string_dtype): # #1859 + expected = ["Travis", "Oliphant"] + s = Series(["Wes McKinney", "Travis Oliphant"], dtype=any_string_dtype) result = s.str.split() - expected = ["Travis", "Oliphant"] assert result[1] == expected - result = s.str.rsplit() + + s = Series(["Wes McKinney", "Travis Oliphant", np.nan], dtype=any_string_dtype) + result = s.str.split() assert result[1] == expected