Skip to content

WARN: PerformanceWarning for non-pyarrow fallback #46732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Other enhancements
- :meth:`pd.concat` now raises when ``levels`` is given but ``keys`` is None (:issue:`46653`)
- :meth:`pd.concat` now raises when ``levels`` contains duplicate values (:issue:`46653`)
- Added ``numeric_only`` argument to :meth:`DataFrame.corr`, :meth:`DataFrame.corrwith`, and :meth:`DataFrame.cov` (:issue:`46560`)
- A :class:`errors.PerformanceWarning` is now thrown when using ``string[pyarrow]`` dtype with methods that don't dispatch to ``pyarrow.compute`` methods (:issue:`42613`)

.. ---------------------------------------------------------------------------
.. _whatsnew_150.notable_bug_fixes:
Expand Down
15 changes: 15 additions & 0 deletions pandas/core/arrays/arrow/_arrow_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
from __future__ import annotations

import json
import warnings

import numpy as np
import pyarrow

from pandas.errors import PerformanceWarning
from pandas.util._exceptions import find_stack_level

from pandas.core.arrays.interval import VALID_CLOSED


def fallback_performancewarning(version: str | None = None):
"""
Raise a PerformanceWarning for falling back to ExtensionArray's
non-pyarrow method
"""
msg = "Falling back on a non-pyarrow code path which may decrease performance."
if version is not None:
msg += f" Upgrade to pyarrow >={version} to possibly suppress this warning."
warnings.warn(msg, PerformanceWarning, stacklevel=find_stack_level())


def pyarrow_array_to_numpy_and_mask(arr, dtype: np.dtype):
"""
Convert a primitive pyarrow.Array to a numpy array and boolean mask based
Expand Down
15 changes: 15 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
import pyarrow as pa
import pyarrow.compute as pc

from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning

ARROW_CMP_FUNCS = {
"eq": pc.equal,
"ne": pc.not_equal,
Expand Down Expand Up @@ -331,6 +333,7 @@ def _maybe_convert_setitem_value(self, value):

def isin(self, values):
if pa_version_under2p0:
fallback_performancewarning(version="2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you check that we are producing this in tests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing this now with tm.assert_produces_warning

return super().isin(values)

value_set = [
Expand Down Expand Up @@ -437,10 +440,12 @@ def _str_map(

def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex: bool = True):
if flags:
fallback_performancewarning()
return super()._str_contains(pat, case, flags, na, regex)

if regex:
if pa_version_under4p0 or case is False:
fallback_performancewarning(version="4")
return super()._str_contains(pat, case, flags, na, regex)
else:
result = pc.match_substring_regex(self._data, pat)
Expand All @@ -456,13 +461,15 @@ def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex: bool = True):

def _str_startswith(self, pat: str, na=None):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_startswith(pat, na)

pat = "^" + re.escape(pat)
return self._str_contains(pat, na=na, regex=True)

def _str_endswith(self, pat: str, na=None):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_endswith(pat, na)

pat = re.escape(pat) + "$"
Expand All @@ -484,6 +491,7 @@ def _str_replace(
or not case
or flags
):
fallback_performancewarning(version="4")
return super()._str_replace(pat, repl, n, case, flags, regex)

func = pc.replace_substring_regex if regex else pc.replace_substring
Expand All @@ -494,6 +502,7 @@ def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_match(pat, case, flags, na)

if not pat.startswith("^"):
Expand All @@ -504,6 +513,7 @@ def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_fullmatch(pat, case, flags, na)

if not pat.endswith("$") or pat.endswith("//$"):
Expand Down Expand Up @@ -536,6 +546,7 @@ def _str_isnumeric(self):

def _str_isspace(self):
if pa_version_under2p0:
fallback_performancewarning(version="2")
return super()._str_isspace()

result = pc.utf8_is_space(self._data)
Expand All @@ -551,6 +562,7 @@ def _str_isupper(self):

def _str_len(self):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_len()

result = pc.utf8_length(self._data)
Expand All @@ -564,6 +576,7 @@ def _str_upper(self):

def _str_strip(self, to_strip=None):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_strip(to_strip)

if to_strip is None:
Expand All @@ -574,6 +587,7 @@ def _str_strip(self, to_strip=None):

def _str_lstrip(self, to_strip=None):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_lstrip(to_strip)

if to_strip is None:
Expand All @@ -584,6 +598,7 @@ def _str_lstrip(self, to_strip=None):

def _str_rstrip(self, to_strip=None):
if pa_version_under4p0:
fallback_performancewarning(version="4")
return super()._str_rstrip(to_strip)

if to_strip is None:
Expand Down
23 changes: 19 additions & 4 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
This module tests the functionality of StringArray and ArrowStringArray.
Tests for the str accessors are in pandas/tests/strings/test_string_array.py
"""
from contextlib import nullcontext

import numpy as np
import pytest

from pandas.compat import pa_version_under2p0
from pandas.errors import PerformanceWarning
import pandas.util._test_decorators as td

from pandas.core.dtypes.common import is_dtype_equal
Expand All @@ -14,6 +18,13 @@
from pandas.core.arrays.string_arrow import ArrowStringArray


def maybe_perf_warn(using_pyarrow):
if using_pyarrow:
return tm.assert_produces_warning(PerformanceWarning, match="Falling back")
else:
return nullcontext()


@pytest.fixture
def dtype(string_storage):
return pd.StringDtype(storage=string_storage)
Expand Down Expand Up @@ -557,18 +568,22 @@ def test_to_numpy_na_value(dtype, nulls_fixture):
def test_isin(dtype, fixed_now_ts):
s = pd.Series(["a", "b", None], dtype=dtype)

result = s.isin(["a", "c"])
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
result = s.isin(["a", "c"])
expected = pd.Series([True, False, False])
tm.assert_series_equal(result, expected)

result = s.isin(["a", pd.NA])
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
result = s.isin(["a", pd.NA])
expected = pd.Series([True, False, True])
tm.assert_series_equal(result, expected)

result = s.isin([])
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
result = s.isin([])
expected = pd.Series([False, False, False])
tm.assert_series_equal(result, expected)

result = s.isin(["a", fixed_now_ts])
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
result = s.isin(["a", fixed_now_ts])
expected = pd.Series([True, False, False])
tm.assert_series_equal(result, expected)
Loading