
BUG: preserve categorical & sparse types when grouping / pivot & preserve dtypes on ufuncs #26550


Status: Closed · wants to merge 10 commits · showing changes from all commits
62 changes: 62 additions & 0 deletions doc/source/whatsnew/v0.25.0.rst
@@ -316,6 +316,68 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
s
s.str.startswith(b'a')

.. _whatsnew_0250.api_breaking.ufuncs:

ufuncs on Extension Dtype
^^^^^^^^^^^^^^^^^^^^^^^^^

Operations with ``numpy`` ufuncs on DataFrames backed by extension arrays, including sparse dtypes, will now
preserve the input dtype in the result; previously these operations coerced to a dense dtype. (:issue:`23743`)

.. ipython:: python

df = pd.DataFrame(
{'A': pd.Series([1, np.nan, 3],
dtype=pd.SparseDtype('float64', np.nan))})
df
df.dtypes

*Previous Behavior*:

.. code-block:: python

In [3]: np.sqrt(df).dtypes
Out[3]:
A float64
dtype: object

*New Behavior*:

.. ipython:: python

np.sqrt(df).dtypes

.. _whatsnew_0250.api_breaking.groupby_categorical:

Categorical dtypes are preserved during groupby
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Previously, columns that were categorical but not the groupby key(s) were converted to ``object`` dtype during groupby operations. pandas will now preserve these dtypes. (:issue:`18502`)

.. ipython:: python

df = pd.DataFrame(
{'payload': [-1, -2, -1, -2],
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
df
df.dtypes

*Previous Behavior*:

.. code-block:: python

In [5]: df.groupby('payload').first().col.dtype
Out[5]: dtype('O')

*New Behavior*:

.. ipython:: python

df.groupby('payload').first().col.dtype


.. _whatsnew_0250.api_breaking.incompatible_index_unions:

Incompatible Index Type Unions
25 changes: 23 additions & 2 deletions pandas/core/arrays/sparse.py
@@ -25,7 +25,8 @@
infer_dtype_from_scalar)
from pandas.core.dtypes.common import (
is_array_like, is_bool_dtype, is_datetime64_any_dtype, is_dtype_equal,
-    is_integer, is_object_dtype, is_scalar, is_string_dtype, pandas_dtype)
+    is_float_dtype, is_integer, is_integer_dtype, is_object_dtype, is_scalar,
+    is_string_dtype, pandas_dtype)
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.core.dtypes.generic import (
ABCIndexClass, ABCSeries, ABCSparseArray, ABCSparseSeries)
@@ -1926,8 +1927,28 @@ def make_sparse(arr, kind='block', fill_value=None, dtype=None, copy=False):

    index = _make_index(length, indices, kind)
    sparsified_values = arr[mask]

    if dtype is not None:

        # careful about casting here as we could easily specify a dtype that
        # cannot hold the resulting values, e.g. integer when we have floats;
        # if the cast is not safe then convert the dtype. Note that if there
        # are NaNs in the source array, casting to integer will raise.

        # TODO: ideally this would be done by 'safe' casting in
        # astype_nansafe, but too many cases rely on the current behavior,
        # and casting='safe' doesn't really work properly in numpy
        if is_integer_dtype(dtype) and is_float_dtype(sparsified_values.dtype):
            result = astype_nansafe(sparsified_values, dtype=dtype)
            if np.allclose(result, sparsified_values, rtol=0):
                return result, index, fill_value

            dtype = find_common_type([dtype, sparsified_values.dtype])

        sparsified_values = astype_nansafe(sparsified_values, dtype=dtype)

    # TODO: copy
    return sparsified_values, index, fill_value
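For illustration, a rough standalone sketch of the check above, using plain NumPy (`cast_if_lossless`, `values`, and `dtype` are hypothetical stand-ins, not the pandas code path itself):

    import numpy as np

    def cast_if_lossless(values, dtype):
        # Keep an integer cast of float data only when it round-trips
        # exactly; NaNs or fractional parts make np.allclose return False.
        if np.issubdtype(dtype, np.integer) and values.dtype.kind == 'f':
            result = values.astype(dtype)
            if np.allclose(result, values, rtol=0):
                return result
            # otherwise fall back to a dtype that can hold the values
            dtype = np.result_type(dtype, values.dtype)
        return values.astype(dtype)

    cast_if_lossless(np.array([1.0, 2.0]), np.int64)  # int64 array([1, 2])
    cast_if_lossless(np.array([1.5, 2.0]), np.int64)  # stays float64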

Expand Down
8 changes: 5 additions & 3 deletions pandas/core/dtypes/cast.py
@@ -605,7 +605,7 @@ def conv(r, dtype):
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]


-def astype_nansafe(arr, dtype, copy=True, skipna=False):
+def astype_nansafe(arr, dtype, copy=True, skipna=False, casting='unsafe'):
"""
Cast the elements of an array to a given dtype in a nan-safe manner.

@@ -616,8 +616,10 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
copy : bool, default True
        If False, a view will be attempted but may fail if,
        e.g., the item sizes don't align.
    skipna : bool, default False
        Whether or not we should skip NaN when casting as a string-type.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default 'unsafe'
        Casting rule passed through to ``numpy.ndarray.astype``.

Raises
------
@@ -703,7 +705,7 @@

if copy or is_object_dtype(arr) or is_object_dtype(dtype):
# Explicit copy, or required since NumPy can't view from / to object.
-        return arr.astype(dtype, copy=True)
+        return arr.astype(dtype, copy=True, casting=casting)

return arr.view(dtype)

48 changes: 47 additions & 1 deletion pandas/core/frame.py
@@ -16,7 +16,7 @@
import sys
import warnings
from textwrap import dedent
-from typing import FrozenSet, List, Optional, Set, Type, Union
+from typing import FrozenSet, List, Optional, Set, Tuple, Type, Union

import numpy as np
import numpy.ma as ma
@@ -2641,6 +2641,52 @@ def transpose(self, *args, **kwargs):

T = property(transpose)

# ----------------------------------------------------------------------
# Array Interface

# This is also set in IndexOpsMixin
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
__array_priority__ = 1000

def __array__(self, dtype=None):
return com.values_from_object(self)

def __array_wrap__(self, result: np.ndarray,
context: Optional[Tuple] = None) -> 'DataFrame':
"""
We are called post ufunc; reconstruct the original object and dtypes.

Parameters
----------
result : np.ndarray
context : tuple, optional

Returns
-------
DataFrame
"""

d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
result = self._constructor(result, **d)

# we try to cast extension array types back to the original
# TODO: this fails with duplicates, ugh
if self._data.any_extension_types:
result = result.astype(self.dtypes,
copy=False,
errors='ignore',
casting='same_kind')

return result.__finalize__(self)
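A sketch of the round trip this enables, reusing the sparse example from the whatsnew entry above:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'A': pd.Series([1.0, np.nan, 4.0],
                                      dtype=pd.SparseDtype('float64', np.nan))})
    np.sqrt(df).dtypes  # A    Sparse[float64, nan]; the dtype survives the ufunc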

    # ideally we would define this to avoid the getattr checks, but
    # it is slower
    # @property
    # def __array_interface__(self):
    #     """ provide numpy array interface method """
    #     values = self.values
    #     return dict(typestr=values.dtype.str, shape=values.shape,
    #                 data=values)

# ----------------------------------------------------------------------
# Picklability

24 changes: 5 additions & 19 deletions pandas/core/generic.py
@@ -1919,25 +1919,6 @@ def empty(self):
# ----------------------------------------------------------------------
# Array Interface

-    # This is also set in IndexOpsMixin
-    # GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
-    __array_priority__ = 1000
-
-    def __array__(self, dtype=None):
-        return com.values_from_object(self)
-
-    def __array_wrap__(self, result, context=None):
-        d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
-        return self._constructor(result, **d).__finalize__(self)
-
-    # ideally we would define this to avoid the getattr checks, but
-    # it is slower
-    # @property
-    # def __array_interface__(self):
-    #     """ provide numpy array interface method """
-    #     values = self.values
-    #     return dict(typestr=values.dtype.str, shape=values.shape, data=values)

def to_dense(self):
"""
Return dense representation of NDFrame (as opposed to sparse).
@@ -5693,6 +5674,11 @@ def astype(self, dtype, copy=True, errors='raise', **kwargs):
**kwargs)
return self._constructor(new_data).__finalize__(self)

        if not results:
Contributor: What case hits this? I'm not immediately seeing it.

Contributor (author): Empty frames :->

            if copy:
                self = self.copy()
            return self

        # GH 19920: retain column metadata after concat
        result = pd.concat(results, axis=1, copy=False)
        result.columns = self.columns
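The empty-frame case from the thread above, roughly:

    import pandas as pd

    pd.DataFrame().astype('int64')  # no blocks, so `results` is empty; the
                                    # guard returns a (copied) empty frame
                                    # instead of failing in pd.concat below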
11 changes: 9 additions & 2 deletions pandas/core/groupby/generic.py
@@ -156,12 +156,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,

                obj = self.obj[data.items[locs]]
                s = groupby(obj, self.grouper)
-               result = s.aggregate(lambda x: alt(x, axis=self.axis))
+               try:
Member: What's this for? The link between this and the overall PR isn't immediately obvious.

Contributor (author): This handles blocks that raise NotImplementedError and then can't be aggregated, e.g. Categoricals with string categories aggregated with mean (when numeric_only=False is passed).

Contributor: Can you give an example? I'm also having trouble seeing this.

Member: If those raise NotImplementedError, can we limit the scope of the catching to just that?

                    result = s.aggregate(lambda x: alt(x, axis=self.axis))
                except Exception:
                    # we may have an exception when trying to aggregate;
                    # continue and exclude the block
                    pass

                finally:

                    dtype = block.values.dtype

                    # see if we can cast the block back to the original dtype
-                   result = block._try_coerce_and_cast_result(result)
+                   result = block._try_coerce_and_cast_result(result, dtype=dtype)
                    newb = block.make_block(result)

                new_items.append(locs)
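The scenario from the thread above, sketched (assuming string categories and an aggregation that has no cython path for them):

    import pandas as pd

    df = pd.DataFrame({'key': [1, 1, 2],
                       'cat': pd.Categorical(['a', 'b', 'a'])})
    # Series(['a', 'b']).mean() raises, so the categorical block is excluded
    # rather than failing the whole aggregation:
    df.groupby('key').mean()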
42 changes: 32 additions & 10 deletions pandas/core/groupby/groupby.py
@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
elif is_extension_array_dtype(dtype):
# The function can return something of any type, so check
# if the type is compatible with the calling EA.

# return the same type (Series) as our caller
try:
result = obj._values._from_sequence(result, dtype=dtype)
except Exception:
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
"""
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
try:
-            return self._cython_agg_general('mean', **kwargs)
+            return self._cython_agg_general(
+                'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
except GroupByError:
raise
except Exception: # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
Median of values within each group.
"""
try:
-            return self._cython_agg_general('median', **kwargs)
+            return self._cython_agg_general(
+                'median',
+                alt=lambda x, axis: Series(x).median(axis=axis, **kwargs),
+                **kwargs)
except GroupByError:
raise
except Exception: # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
nv.validate_groupby_func('var', args, kwargs)
if ddof == 1:
try:
-                return self._cython_agg_general('var', **kwargs)
+                return self._cython_agg_general(
+                    'var',
+                    alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
+                    **kwargs)
except Exception:
f = lambda x: x.var(ddof=ddof, **kwargs)
with _group_selection_context(self):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
Series or DataFrame
Standard error of the mean of values within each group.
"""

return self.std(ddof=ddof) / np.sqrt(self.count())
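Numerically, this is the standard error of the mean per group; a quick check (example values assumed):

    import numpy as np
    import pandas as pd

    s = pd.Series([1.0, 2.0, 3.0, 4.0])
    g = s.groupby([0, 0, 1, 1])
    (g.std() / np.sqrt(g.count())).equals(g.sem())  # True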

@Substitution(name='groupby')
@@ -1290,7 +1299,7 @@ def _add_numeric_operations(cls):
"""

def groupby_function(name, alias, npfunc,
-                             numeric_only=True, _convert=False,
+                             numeric_only=True,
min_count=-1):

_local_template = """
@@ -1312,17 +1321,30 @@ def f(self, **kwargs):
kwargs['min_count'] = min_count

self._set_group_selection()

            # try a cython aggregation if we can
            try:
                return self._cython_agg_general(
                    alias, alt=npfunc, **kwargs)
            except AssertionError as e:
                raise SpecificationError(str(e))
            except Exception:
-               result = self.aggregate(
-                   lambda x: npfunc(x, axis=self.axis))
-               if _convert:
-                   result = result._convert(datetime=True)
-               return result
+               pass

            # apply a non-cython aggregation
            result = self.aggregate(
                lambda x: npfunc(x, axis=self.axis))

            # coerce the resulting columns if we can
            if isinstance(result, DataFrame):
                for col in result.columns:
                    result[col] = self._try_cast(
                        result[col], self.obj[col])
            else:
                result = self._try_cast(
                    result, self.obj)

            return result
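Effect of the coercion step above, sketched with the whatsnew example (assuming the categorical column takes the non-cython fallback):

    import pandas as pd

    df = pd.DataFrame({'key': [1, 1, 2],
                       'cat': pd.Categorical(['a', 'b', 'b'], ordered=True)})
    res = df.groupby('key').first()
    res['cat'].dtype  # category (restored via _try_cast), not object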

set_function_name(f, name, cls)

6 changes: 3 additions & 3 deletions pandas/core/groupby/ops.py
@@ -19,7 +19,7 @@
from pandas.core.dtypes.common import (
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
-    is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
+    is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
is_timedelta64_dtype, needs_i8_conversion)
from pandas.core.dtypes.missing import _maybe_fill, isna

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,

# categoricals are only 1d, so we
# are not setup for dim transforming
-        if is_categorical_dtype(values):
+        if is_categorical_dtype(values) or is_sparse(values):
            raise NotImplementedError(
-                "categoricals are not support in cython ops ATM")
+                "{} are not supported in cython ops".format(values.dtype))
elif is_datetime64_any_dtype(values):
if how in ['add', 'prod', 'cumsum', 'cumprod']:
raise NotImplementedError(