Skip to content

ENH: str extract with default value #38003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Other enhancements
- Added :meth:`MultiIndex.dtypes` (:issue:`37062`)
- Added ``end`` and ``end_day`` options for ``origin`` in :meth:`DataFrame.resample` (:issue:`37804`)
- Improve error message when ``usecols`` and ``names`` do not match for :func:`read_csv` and ``engine="c"`` (:issue:`29042`)
- :meth:`Series.str.extract` accepts the ``fill_value`` argument to fill non matching values (:issue:`38001`)

.. ---------------------------------------------------------------------------

Expand Down
29 changes: 18 additions & 11 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2218,7 +2218,7 @@ def findall(self, pat, flags=0):
return self._wrap_result(result, returns_string=False)

@forbid_nonstring_types(["bytes"])
def extract(self, pat, flags=0, expand=True):
def extract(self, pat, flags=0, expand=True, fill_value=None):
r"""
Extract capture groups in the regex `pat` as columns in a DataFrame.

Expand All @@ -2237,6 +2237,8 @@ def extract(self, pat, flags=0, expand=True):
If True, return DataFrame with one column per capture group.
If False, return a Series/Index if there is one capture group
or DataFrame if there are multiple capture groups.
fill_value: str, default None
Value to use as default value when the regex does not match.

Returns
-------
Expand Down Expand Up @@ -2300,7 +2302,7 @@ def extract(self, pat, flags=0, expand=True):
dtype: object
"""
# TODO: dispatch
return str_extract(self, pat, flags, expand=expand)
return str_extract(self, pat, flags, expand=expand, fill_value=fill_value)

@forbid_nonstring_types(["bytes"])
def extractall(self, pat, flags=0):
Expand Down Expand Up @@ -2950,18 +2952,21 @@ def cat_core(list_of_columns: List, sep: str):
return np.sum(arr_with_sep, axis=0)


def _groups_or_na_fun(regex):
def _groups_or_na_fun(regex, fill_value=None):
"""Used in both extract_noexpand and extract_frame"""
if regex.groups == 0:
raise ValueError("pattern contains no capture groups")
empty_row = [np.nan] * regex.groups
fill_value = [fill_value] * regex.groups

def f(x):
if not isinstance(x, str):
return empty_row
m = regex.search(x)
if m:
return [np.nan if item is None else item for item in m.groups()]
elif not m and fill_value:
return fill_value
else:
return empty_row

Expand All @@ -2987,7 +2992,7 @@ def _get_single_group_name(rx):
return None


def _str_extract_noexpand(arr, pat, flags=0):
def _str_extract_noexpand(arr, pat, flags=0, fill_value=None):
"""
Find groups in each string in the Series using passed regular
expression. This function is called from
Expand All @@ -2998,11 +3003,11 @@ def _str_extract_noexpand(arr, pat, flags=0):
from pandas import DataFrame, array

regex = re.compile(pat, flags=flags)
groups_or_na = _groups_or_na_fun(regex)
groups_or_na = _groups_or_na_fun(regex, fill_value)
result_dtype = _result_dtype(arr)

if regex.groups == 1:
result = np.array([groups_or_na(val)[0] for val in arr], dtype=object)
result = [groups_or_na(val)[0] for val in arr]
name = _get_single_group_name(regex)
# not dispatching, so we have to reconstruct here.
result = array(result, dtype=result_dtype)
Expand All @@ -3025,7 +3030,7 @@ def _str_extract_noexpand(arr, pat, flags=0):
return result, name


def _str_extract_frame(arr, pat, flags=0):
def _str_extract_frame(arr, pat, flags=0, fill_value=None):
"""
For each subject string in the Series, extract groups from the
first match of regular expression pat. This function is called from
Expand All @@ -3035,7 +3040,7 @@ def _str_extract_frame(arr, pat, flags=0):
from pandas import DataFrame

regex = re.compile(pat, flags=flags)
groups_or_na = _groups_or_na_fun(regex)
groups_or_na = _groups_or_na_fun(regex, fill_value=fill_value)
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
columns = [names.get(1 + i, i) for i in range(regex.groups)]

Expand All @@ -3054,14 +3059,16 @@ def _str_extract_frame(arr, pat, flags=0):
)


def str_extract(arr, pat, flags=0, expand=True):
def str_extract(arr, pat, flags=0, expand=True, fill_value=None):
if not isinstance(expand, bool):
raise ValueError("expand must be True or False")
if expand:
result = _str_extract_frame(arr._orig, pat, flags=flags)
result = _str_extract_frame(arr._orig, pat, flags=flags, fill_value=fill_value)
return result.__finalize__(arr._orig, method="str_extract")
else:
result, name = _str_extract_noexpand(arr._orig, pat, flags=flags)
result, name = _str_extract_noexpand(
arr._orig, pat, flags=flags, fill_value=fill_value
)
return arr._wrap_result(result, name=name, expand=expand)


Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3670,3 +3670,19 @@ def test_str_get_stringarray_multiple_nans():
result = s.str.get(2)
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"]))
tm.assert_series_equal(result, expected)


def test_str_extract_default_value_no_expand():
# GH 38001
df = DataFrame({"A": ["a84", "abcd", "99string", np.nan]})
result = df["A"].str.extract(r"(\d+)", expand=False, fill_value="missing")
expected = Series(["84", "missing", "99", np.nan], name="A")
tm.assert_series_equal(result, expected)


def test_str_extract_default_value_with_expand():
# GH 38001
df = DataFrame({"A": ["a84", "abcd", "99string", np.nan]})
result = df["A"].str.extract(r"(\d+)", expand=True, fill_value="missing")
expected = DataFrame({0: ["84", "missing", "99", np.nan]})
tm.assert_frame_equal(result, expected)