From 6d6868327b980e0fe5186d63367da23f885850f0 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 8 Jan 2024 20:34:52 -0500 Subject: [PATCH 01/14] fix find --- pandas/core/arrays/arrow/array.py | 30 +++++++++++++++++----------- pandas/tests/extension/test_arrow.py | 20 ++++++++++--------- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 23b5c6385c13b..f7a1ca702891a 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2328,19 +2328,25 @@ def _str_fullmatch( return self._str_match(pat, case, flags, na) def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: - if start != 0 and end is not None: - slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) - result = pc.find_substring(slices, sub) - not_found = pc.equal(result, -1) - start_offset = max(0, start) - offset_result = pc.add(result, start_offset) - result = pc.if_else(not_found, result, offset_result) - elif start == 0 and end is None: - slices = self._pa_array - result = pc.find_substring(slices, sub) + if (start == 0 or start is None) and end is None: + result = pc.find_substring(self._pa_array, sub) else: - raise NotImplementedError( - f"find not implemented with {sub=}, {start=}, {end=}" + result = pc.find_substring(self._pa_array, sub) + length = pc.utf8_length(self._pa_array) + if start is None: + start = pa.scalar(0, result.type) + elif start < 0: + start = pc.add(start, length) + if end is None: + end = length + elif end < 0: + end = pc.add(end, length) + found = pc.not_equal(pa.scalar(-1, type=result.type), result) + found_in_bounds = pc.and_( + pc.greater_equal(result, start), pc.less(result, end) + ) + result = pc.if_else( + pc.and_(found, found_in_bounds), result, pa.scalar(-1, type=result.type) ) return type(self)(result) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 8d0bb85b2a01f..913a0f1c9a29e 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1919,13 +1919,13 @@ def test_str_fullmatch(pat, case, na, exp): @pytest.mark.parametrize( - "sub, start, end, exp, exp_typ", - [["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [1, None], pa.int64()]], + "sub, start, end, exp", + [["ab", 0, None, [0, None]], ["bc", 1, 3, [1, None]], ["ab", 1, None, [-1, None]]], ) -def test_str_find(sub, start, end, exp, exp_typ): +def test_str_find(sub, start, end, exp): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) result = ser.str.find(sub, start=start, end=end) - expected = pd.Series(exp, dtype=ArrowDtype(exp_typ)) + expected = pd.Series(exp, dtype=ArrowDtype(pa.int32())) tm.assert_series_equal(result, expected) @@ -1933,14 +1933,16 @@ def test_str_find_negative_start(): # GH 56411 ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) result = ser.str.find(sub="b", start=-1000, end=3) - expected = pd.Series([1, None], dtype=ArrowDtype(pa.int64())) + expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32())) tm.assert_series_equal(result, expected) -def test_str_find_notimplemented(): - ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - with pytest.raises(NotImplementedError, match="find not implemented"): - ser.str.find("ab", start=1) +def test_str_find_negative_start_negative_end(): + # GH 56411 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=-6, end=-3) + expected = pd.Series([3, None], dtype=ArrowDtype(pa.int32())) + tm.assert_series_equal(result, expected) @pytest.mark.parametrize( From 533f54c7b960652764ad968f87fb4f372a7b7249 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 8 Jan 2024 21:12:13 -0500 Subject: [PATCH 02/14] gh reference --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/arrays/arrow/array.py | 27 ++++++++++++--------------- pandas/tests/extension/test_arrow.py | 19 ++++++++++++------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 0b04a1d313a6d..4a83f8003739a 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -803,6 +803,7 @@ Strings - Bug in :meth:`DataFrame.reindex` not matching :class:`Index` with ``string[pyarrow_numpy]`` dtype (:issue:`56106`) - Bug in :meth:`Index.str.cat` always casting result to object dtype (:issue:`56157`) - Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`) +- Bug in :meth:`Series.str.find` when ``start < 0`` and ``end < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56791`) - Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`) - Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`) - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for :class:`ArrowDtype` with ``pyarrow.string`` dtype (:issue:`56579`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index f7a1ca702891a..860a2cbb44061 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2331,23 +2331,20 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: if (start == 0 or start is None) and end is None: result = pc.find_substring(self._pa_array, sub) else: - result = pc.find_substring(self._pa_array, sub) - length = pc.utf8_length(self._pa_array) if start is None: - start = pa.scalar(0, result.type) + start_offset = 0 + start = 0 elif start < 0: - start = pc.add(start, length) - if end is None: - end = length - elif end < 0: - end = pc.add(end, length) - found = pc.not_equal(pa.scalar(-1, type=result.type), result) - found_in_bounds = pc.and_( - pc.greater_equal(result, start), pc.less(result, end) - ) - result = pc.if_else( - pc.and_(found, found_in_bounds), result, pa.scalar(-1, type=result.type) - ) + start_offset = pc.add(start, pc.utf8_length(self._pa_array)) + start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) + else: + start_offset = start + slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) + result = pc.find_substring(slices, sub) + not_found = pc.equal(result, pa.scalar(-1, type=result.type)) + + offset_result = pc.add(result, start_offset) + result = pc.if_else(not_found, result, offset_result) return type(self)(result) def _str_join(self, sep: str) -> Self: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 913a0f1c9a29e..020747be608a9 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1919,13 +1919,18 @@ def test_str_fullmatch(pat, case, na, exp): @pytest.mark.parametrize( - "sub, start, end, exp", - [["ab", 0, None, [0, None]], ["bc", 1, 3, [1, None]], ["ab", 1, None, [-1, None]]], + "sub, start, end, exp, exp_type", + [ + ["ab", 0, None, [0, None], pa.int32()], + ["bc", 1, 3, [1, None], pa.int64()], + ["ab", 1, None, [-1, None], pa.int64()], + ["ab", -3, -3, [-1, None], pa.int64()], + ], ) -def test_str_find(sub, start, end, exp): +def test_str_find(sub, start, end, exp, exp_type): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) result = ser.str.find(sub, start=start, end=end) - expected = pd.Series(exp, dtype=ArrowDtype(pa.int32())) + expected = pd.Series(exp, dtype=ArrowDtype(exp_type)) tm.assert_series_equal(result, expected) @@ -1933,15 +1938,15 @@ def test_str_find_negative_start(): # GH 56411 ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) result = ser.str.find(sub="b", start=-1000, end=3) - expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32())) + expected = pd.Series([1, None], dtype=ArrowDtype(pa.int64())) tm.assert_series_equal(result, expected) def test_str_find_negative_start_negative_end(): - # GH 56411 + # GH 56791 ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) result = ser.str.find(sub="d", start=-6, end=-3) - expected = pd.Series([3, None], dtype=ArrowDtype(pa.int32())) + expected = pd.Series([3, None], dtype=ArrowDtype(pa.int64())) tm.assert_series_equal(result, expected) From beaf5f02f3c990b61b25cd1542033e014f494315 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 8 Jan 2024 21:34:55 -0500 Subject: [PATCH 03/14] add test for Nones --- pandas/tests/extension/test_arrow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 020747be608a9..f769ddbea31b8 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1925,6 +1925,7 @@ def test_str_fullmatch(pat, case, na, exp): ["bc", 1, 3, [1, None], pa.int64()], ["ab", 1, None, [-1, None], pa.int64()], ["ab", -3, -3, [-1, None], pa.int64()], + ["ab", None, None, [0, None], pa.int32()], ], ) def test_str_find(sub, start, end, exp, exp_type): From 68763b1176c2ead93d78c7e17dfe8355fc37bd93 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 8 Jan 2024 22:07:14 -0500 Subject: [PATCH 04/14] fix min version compat --- pandas/tests/extension/test_arrow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index f769ddbea31b8..d541e1042215c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1923,9 +1923,8 @@ def test_str_fullmatch(pat, case, na, exp): [ ["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [1, None], pa.int64()], - ["ab", 1, None, [-1, None], pa.int64()], + ["ab", 1, 3, [-1, None], pa.int64()], ["ab", -3, -3, [-1, None], pa.int64()], - ["ab", None, None, [0, None], pa.int32()], ], ) def test_str_find(sub, start, end, exp, exp_type): From 52a6f528af850c4b76e7fa90d4661d6d246f042e Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 8 Jan 2024 22:12:28 -0500 Subject: [PATCH 05/14] restore test --- pandas/tests/extension/test_arrow.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index d541e1042215c..be722863a11a5 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1942,6 +1942,18 @@ def test_str_find_negative_start(): tm.assert_series_equal(result, expected) +def test_str_find_no_end(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + if pa_version_under13p0: + # https://github.com/apache/arrow/issues/36311 + with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): + ser.str.find("ab", start=1) + else: + result = ser.str.find("ab", start=1) + expected = pd.Series([-1, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) + + def test_str_find_negative_start_negative_end(): # GH 56791 ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) From 771e2f64af1fab2545f9b10f9a8a5f4aaab21951 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 9 Jan 2024 11:12:20 -0500 Subject: [PATCH 06/14] improve test cases --- pandas/tests/extension/test_arrow.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index be722863a11a5..7248ef51c0bc5 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1962,6 +1962,27 @@ def test_str_find_negative_start_negative_end(): tm.assert_series_equal(result, expected) +def test_str_find_large_start(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + if pa_version_under13p0: + # https://github.com/apache/arrow/issues/36311 + with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): + ser.str.find(sub="d", start=16) + else: + result = ser.str.find(sub="d", start=16) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_str_find_negative_start_negative_end_no_match(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=-3, end=-6) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( "i, exp", [ From 0de96cb3a0a30ebfdce73238e8b9ce235e782cc7 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 9 Jan 2024 15:27:39 -0500 Subject: [PATCH 07/14] fix empty string --- pandas/core/arrays/arrow/array.py | 20 ++++++++++++++++---- pandas/tests/extension/test_arrow.py | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 860a2cbb44061..3a29ed3842a35 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2331,20 +2331,32 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: if (start == 0 or start is None) and end is None: result = pc.find_substring(self._pa_array, sub) else: + length = pc.utf8_length(self._pa_array) if start is None: start_offset = 0 start = 0 elif start < 0: - start_offset = pc.add(start, pc.utf8_length(self._pa_array)) + start_offset = pc.add(start, length) start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) else: start_offset = start slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) result = pc.find_substring(slices, sub) - not_found = pc.equal(result, pa.scalar(-1, type=result.type)) - + found = pc.not_equal(result, pa.scalar(-1, type=result.type)) + if end is None: + end = length + elif end < 0: + end = pc.add(end, length) + end = pc.if_else(pc.less(end, 0), 0, end) + found = pc.and_( + found, + pc.and_( + pc.less_equal(start_offset, end), + pc.less_equal(start_offset, length), + ), + ) offset_result = pc.add(result, start_offset) - result = pc.if_else(not_found, result, offset_result) + result = pc.if_else(found, offset_result, -1) return type(self)(result) def _str_join(self, sep: str) -> Self: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 7248ef51c0bc5..e7aab3497d7d3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -23,6 +23,7 @@ BytesIO, StringIO, ) +from itertools import combinations import operator import pickle import re @@ -1975,6 +1976,27 @@ def test_str_find_large_start(): tm.assert_series_equal(result, expected) +def _get_all_substrings(string): + length = len(string) + 1 + return [string[x:y] for x, y in combinations(range(length), r=2)] + + +@pytest.mark.xfail( + pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" +) +def test_str_find_e2e(): + string = "abcdefgh" + s = pd.Series([string], dtype=ArrowDtype(pa.string())) + substrings = _get_all_substrings(string) + ["", "az", "abce"] + offsets = list(range(-15, 15)) + [None] + for start in offsets: + for end in offsets: + for sub in substrings: + result = s.str.find(sub, start, end) + expected = pd.Series([string.find(sub, start, end)], dtype=result.dtype) + tm.assert_series_equal(result, expected) + + def test_str_find_negative_start_negative_end_no_match(): # GH 56791 ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) From f5416e32278d5b416f65a748a45f52346e748231 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 9 Jan 2024 20:54:02 -0500 Subject: [PATCH 08/14] inline --- pandas/tests/extension/test_arrow.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index e7aab3497d7d3..c53503d24bdb7 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1976,19 +1976,18 @@ def test_str_find_large_start(): tm.assert_series_equal(result, expected) -def _get_all_substrings(string): - length = len(string) + 1 - return [string[x:y] for x, y in combinations(range(length), r=2)] - - @pytest.mark.xfail( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) def test_str_find_e2e(): - string = "abcdefgh" + string = "abcaadef" s = pd.Series([string], dtype=ArrowDtype(pa.string())) - substrings = _get_all_substrings(string) + ["", "az", "abce"] offsets = list(range(-15, 15)) + [None] + substrings = [string[x:y] for x, y in combinations(range(len(string) + 1), r=2)] + [ + "", + "az", + "abce", + ] for start in offsets: for end in offsets: for sub in substrings: From 06105ce805ccc57813bd08331fb776f1c6fe84e3 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 10 Jan 2024 18:57:01 -0500 Subject: [PATCH 09/14] improve tests --- pandas/tests/extension/test_arrow.py | 29 ++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index c53503d24bdb7..2d562753c4789 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1979,21 +1979,26 @@ def test_str_find_large_start(): @pytest.mark.xfail( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) -def test_str_find_e2e(): - string = "abcaadef" - s = pd.Series([string], dtype=ArrowDtype(pa.string())) - offsets = list(range(-15, 15)) + [None] - substrings = [string[x:y] for x, y in combinations(range(len(string) + 1), r=2)] + [ +@pytest.mark.parametrize("start", list(range(-15, 15)) + [None]) +@pytest.mark.parametrize("end", list(range(-15, 15)) + [None]) +@pytest.mark.parametrize( + "sub", + ["abcaadef"[x:y] for x, y in combinations(range(len("abcaadef") + 1), r=2)] + + [ "", "az", "abce", - ] - for start in offsets: - for end in offsets: - for sub in substrings: - result = s.str.find(sub, start, end) - expected = pd.Series([string.find(sub, start, end)], dtype=result.dtype) - tm.assert_series_equal(result, expected) + ], +) +def test_str_find_e2e(start, end, sub): + s = pd.Series( + ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], + dtype=ArrowDtype(pa.string()), + ) + object_series = s.astype(pd.StringDtype()) + result = s.str.find(sub, start, end) + expected = object_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result, expected) def test_str_find_negative_start_negative_end_no_match(): From 7fa21eb24682ae587a0b3033942fbe1247f98921 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 10 Jan 2024 20:31:04 -0500 Subject: [PATCH 10/14] fix --- pandas/tests/extension/test_arrow.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 2d562753c4789..a6737d7999a6e 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1979,8 +1979,6 @@ def test_str_find_large_start(): @pytest.mark.xfail( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) -@pytest.mark.parametrize("start", list(range(-15, 15)) + [None]) -@pytest.mark.parametrize("end", list(range(-15, 15)) + [None]) @pytest.mark.parametrize( "sub", ["abcaadef"[x:y] for x, y in combinations(range(len("abcaadef") + 1), r=2)] @@ -1990,15 +1988,18 @@ def test_str_find_large_start(): "abce", ], ) -def test_str_find_e2e(start, end, sub): +def test_str_find_e2e(sub): s = pd.Series( ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], dtype=ArrowDtype(pa.string()), ) object_series = s.astype(pd.StringDtype()) - result = s.str.find(sub, start, end) - expected = object_series.str.find(sub, start, end).astype(result.dtype) - tm.assert_series_equal(result, expected) + offsets = list(range(-15, 15)) + [None] + for start in offsets: + for end in offsets: + result = s.str.find(sub, start, end) + expected = object_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result, expected) def test_str_find_negative_start_negative_end_no_match(): From 8e27d85647921d2db4b49264cfbf1cec90868713 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 10 Jan 2024 20:35:37 -0500 Subject: [PATCH 11/14] Revert "fix" This reverts commit 7fa21eb24682ae587a0b3033942fbe1247f98921. --- pandas/tests/extension/test_arrow.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index a6737d7999a6e..2d562753c4789 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1979,6 +1979,8 @@ def test_str_find_large_start(): @pytest.mark.xfail( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) +@pytest.mark.parametrize("start", list(range(-15, 15)) + [None]) +@pytest.mark.parametrize("end", list(range(-15, 15)) + [None]) @pytest.mark.parametrize( "sub", ["abcaadef"[x:y] for x, y in combinations(range(len("abcaadef") + 1), r=2)] @@ -1988,18 +1990,15 @@ def test_str_find_large_start(): "abce", ], ) -def test_str_find_e2e(sub): +def test_str_find_e2e(start, end, sub): s = pd.Series( ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], dtype=ArrowDtype(pa.string()), ) object_series = s.astype(pd.StringDtype()) - offsets = list(range(-15, 15)) + [None] - for start in offsets: - for end in offsets: - result = s.str.find(sub, start, end) - expected = object_series.str.find(sub, start, end).astype(result.dtype) - tm.assert_series_equal(result, expected) + result = s.str.find(sub, start, end) + expected = object_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result, expected) def test_str_find_negative_start_negative_end_no_match(): From 403bc4f5889ee8376021eacd285b3525fc031249 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 10 Jan 2024 20:36:31 -0500 Subject: [PATCH 12/14] fix --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 2d562753c4789..3b5ea67b2f2ca 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1976,7 +1976,7 @@ def test_str_find_large_start(): tm.assert_series_equal(result, expected) -@pytest.mark.xfail( +@pytest.mark.skipif( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) @pytest.mark.parametrize("start", list(range(-15, 15)) + [None]) From ef9a79161c78a224a3f1f1bb3ac03d9cdfc32cca Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Sat, 20 Jan 2024 13:37:33 -0500 Subject: [PATCH 13/14] merge --- doc/source/whatsnew/v2.2.0.rst | 1 - pandas/core/arrays/arrow/array.py | 16 ++++------------ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 7f1624001c8d3..d9ab0452c8334 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -823,7 +823,6 @@ Strings - Bug in :meth:`DataFrame.reindex` not matching :class:`Index` with ``string[pyarrow_numpy]`` dtype (:issue:`56106`) - Bug in :meth:`Index.str.cat` always casting result to object dtype (:issue:`56157`) - Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`) -- Bug in :meth:`Series.str.find` when ``start < 0`` and ``end < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56791`) - Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`) - Bug in :meth:`Series.str.fullmatch` when ``dtype=pandas.ArrowDtype(pyarrow.string()))`` allows partial matches when regex ends in literal //$ (:issue:`56652`) - Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 8478e82b69db9..8f8881ffdc40d 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2371,6 +2371,10 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: if (start == 0 or start is None) and end is None: result = pc.find_substring(self._pa_array, sub) else: + if sub == "": + # GH 56792 + result = self._apply_elementwise(lambda val: val.find(sub, start, end)) + return type(self)(pa.chunked_array(result)) length = pc.utf8_length(self._pa_array) if start is None: start_offset = 0 @@ -2383,18 +2387,6 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) result = pc.find_substring(slices, sub) found = pc.not_equal(result, pa.scalar(-1, type=result.type)) - if end is None: - end = length - elif end < 0: - end = pc.add(end, length) - end = pc.if_else(pc.less(end, 0), 0, end) - found = pc.and_( - found, - pc.and_( - pc.less_equal(start_offset, end), - pc.less_equal(start_offset, length), - ), - ) offset_result = pc.add(result, start_offset) result = pc.if_else(found, offset_result, -1) return type(self)(result) From 2cf0fbfa857865858403d34ee936091caa22cf7f Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 22 Jan 2024 09:09:25 -0500 Subject: [PATCH 14/14] inline --- pandas/core/arrays/arrow/array.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 8f8881ffdc40d..31932579d207b 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2375,12 +2375,11 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: # GH 56792 result = self._apply_elementwise(lambda val: val.find(sub, start, end)) return type(self)(pa.chunked_array(result)) - length = pc.utf8_length(self._pa_array) if start is None: start_offset = 0 start = 0 elif start < 0: - start_offset = pc.add(start, length) + start_offset = pc.add(start, pc.utf8_length(self._pa_array)) start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) else: start_offset = start