From 8234eace4e408e5b55171a06fd71af09ec2e32eb Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 14 Nov 2019 16:12:08 -0800 Subject: [PATCH 1/6] typing in groupby --- pandas/core/groupby/generic.py | 49 +++++++++++++++++++--------------- pandas/core/groupby/groupby.py | 12 ++++----- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 002d8640f109d..6c6a91c1bfd17 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -334,7 +334,9 @@ def _aggregate_multiple_funcs(self, arg, _level): return DataFrame(results, columns=columns) - def _wrap_series_output(self, output, index, names=None): + def _wrap_series_output( + self, output: dict, index: Index, names=None + ) -> Union[Series, DataFrame]: """ common agg/transform wrapping logic """ output = output[self._selection_name] @@ -346,18 +348,22 @@ def _wrap_series_output(self, output, index, names=None): name = self._selected_obj.name return Series(output, index=index, name=name) - def _wrap_aggregated_output(self, output, names=None): + def _wrap_aggregated_output( + self, output: dict, names=None + ) -> Union[Series, DataFrame]: result = self._wrap_series_output( output=output, index=self.grouper.result_index, names=names ) return self._reindex_output(result)._convert(datetime=True) - def _wrap_transformed_output(self, output, names=None): + def _wrap_transformed_output( + self, output: dict, names=None + ) -> Union[Series, DataFrame]: return self._wrap_series_output( output=output, index=self.obj.index, names=names ) - def _wrap_applied_output(self, keys, values, not_indexed_same=False): + def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False): if len(keys) == 0: # GH #6265 return Series([], name=self._selection_name, index=keys) @@ -389,8 +395,8 @@ def _get_index() -> Index: result = Series(data=values, index=_get_index(), name=self._selection_name) return self._reindex_output(result) - def _aggregate_named(self, func, *args, **kwargs): - result = OrderedDict() + def _aggregate_named(self, func, *args, **kwargs) -> OrderedDict: + result = OrderedDict() # type: OrderedDict for name, group in self: group.name = name @@ -455,18 +461,16 @@ def transform(self, func, *args, **kwargs): result.index = self._selected_obj.index return result - def _transform_fast(self, func, func_nm) -> Series: + def _transform_fast(self, func: Callable, func_nm: str) -> Series: """ fast version of transform, only applicable to builtin/cythonizable functions """ - if isinstance(func, str): - func = getattr(self, func) ids, _, ngroup = self.grouper.group_info - cast = self._transform_should_cast(func_nm) + should_cast = self._transform_should_cast(func_nm) out = algorithms.take_1d(func()._values, ids) - if cast: + if should_cast: out = self._try_cast(out, self.obj) return Series(out, index=self.obj.index, name=self.obj.name) @@ -1081,6 +1085,7 @@ def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame: def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame: # only for axis==0 + should_cast = self._transform_should_cast(func) obj = self._obj_with_exclusions result = OrderedDict() # type: dict cannot_agg = [] @@ -1089,9 +1094,8 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame: data = obj[item] colg = SeriesGroupBy(data, selection=item, grouper=self.grouper) - cast = self._transform_should_cast(func) try: - result[item] = colg.aggregate(func, *args, **kwargs) + res = colg.aggregate(func, *args, **kwargs) except ValueError as err: if "Must produce aggregated value" in str(err): @@ -1103,8 +1107,9 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame: continue else: - if cast: - result[item] = self._try_cast(result[item], data) + if should_cast: + res = self._try_cast(res, data) + result[item] = res result_columns = obj.columns if cannot_agg: @@ -1127,7 +1132,7 @@ def _decide_output_index(self, output, labels): return output_keys - def _wrap_applied_output(self, keys, values, not_indexed_same=False): + def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False): if len(keys) == 0: return DataFrame(index=keys) @@ -1379,13 +1384,15 @@ def transform(self, func, *args, **kwargs): return self._transform_fast(result, obj, func) - def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFrame: + def _transform_fast( + self, result: DataFrame, obj: DataFrame, func_nm: str + ) -> DataFrame: """ Fast transform path for aggregations """ # if there were groups with no observations (Categorical only?) # try casting data to original dtype - cast = self._transform_should_cast(func_nm) + should_cast = self._transform_should_cast(func_nm) # for each col, reshape to to size of original frame # by take operation @@ -1393,7 +1400,7 @@ def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFra output = [] for i, _ in enumerate(result.columns): res = algorithms.take_1d(result.iloc[:, i].values, ids) - if cast: + if should_cast: res = self._try_cast(res, obj.iloc[:, i]) output.append(res) @@ -1591,7 +1598,7 @@ def _insert_inaxis_grouper_inplace(self, result): if in_axis: result.insert(0, name, lev) - def _wrap_aggregated_output(self, output, names=None): + def _wrap_aggregated_output(self, output: dict, names=None): agg_axis = 0 if self.axis == 1 else 1 agg_labels = self._obj_with_exclusions._get_axis(agg_axis) @@ -1610,7 +1617,7 @@ def _wrap_aggregated_output(self, output, names=None): return self._reindex_output(result)._convert(datetime=True) - def _wrap_transformed_output(self, output, names=None) -> DataFrame: + def _wrap_transformed_output(self, output: dict, names=None) -> DataFrame: return DataFrame(output, index=self.obj.index) def _wrap_agged_blocks(self, items, blocks): diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 280f1e88b0ea8..8fee6c64cb5b7 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -831,6 +831,7 @@ def _transform_should_cast(self, func_nm: str) -> bool: ) def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs): + should_cast = self._transform_should_cast(how) output = collections.OrderedDict() # type: dict for name, obj in self._iterate_slices(): is_numeric = is_numeric_dtype(obj.dtype) @@ -841,20 +842,19 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs): result, names = self.grouper.transform(obj.values, how, **kwargs) except NotImplementedError: continue - if self._transform_should_cast(how): - output[name] = self._try_cast(result, obj) - else: - output[name] = result + if should_cast: + result = self._try_cast(result, obj) + output[name] = result if len(output) == 0: raise DataError("No numeric types to aggregate") return self._wrap_transformed_output(output, names) - def _wrap_aggregated_output(self, output, names=None): + def _wrap_aggregated_output(self, output: dict, names=None): raise AbstractMethodError(self) - def _wrap_transformed_output(self, output, names=None): + def _wrap_transformed_output(self, output: dict, names=None): raise AbstractMethodError(self) def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False): From efc67016cf332d057bb77c9b12c9301ce15a57b1 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 14 Nov 2019 16:22:01 -0800 Subject: [PATCH 2/6] annotations --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 6c6a91c1bfd17..7f11f6e34a19f 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -595,7 +595,7 @@ def describe(self, **kwargs): return result.unstack() def value_counts( - self, normalize=False, sort=True, ascending=False, bins=None, dropna=True + self, normalize: bool = False, sort: bool = True, ascending: bool = False, bins=None, dropna: bool = True ): from pandas.core.reshape.tile import cut From 836dc1a1329da68b12e8fdfcf65b85d3d5fe4be1 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 14 Nov 2019 16:31:06 -0800 Subject: [PATCH 3/6] blackify --- pandas/core/groupby/generic.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 7f11f6e34a19f..249a035f36b47 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -595,7 +595,12 @@ def describe(self, **kwargs): return result.unstack() def value_counts( - self, normalize: bool = False, sort: bool = True, ascending: bool = False, bins=None, dropna: bool = True + self, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + bins=None, + dropna: bool = True, ): from pandas.core.reshape.tile import cut From a91a489e30e7023ffa5a3aca23bd6cd383d83136 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Sat, 16 Nov 2019 14:42:41 -0800 Subject: [PATCH 4/6] OrderedDict->dict --- pandas/core/groupby/generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index d0b22a3a5f179..ad7c1ae6d6c5d 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -383,8 +383,8 @@ def _get_index() -> Index: result = Series(data=values, index=_get_index(), name=self._selection_name) return self._reindex_output(result) - def _aggregate_named(self, func, *args, **kwargs) -> OrderedDict: - result = OrderedDict() # type: OrderedDict + def _aggregate_named(self, func, *args, **kwargs) -> dict: + result = {} # type: dict for name, group in self: group.name = name From 2c67e86785b3c4cbf7a9b5ac4adbdaaaaf7e504c Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Sat, 16 Nov 2019 19:56:58 -0800 Subject: [PATCH 5/6] dict->Dict --- pandas/core/groupby/generic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index ad7c1ae6d6c5d..b7295ed2d4b0e 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -10,7 +10,7 @@ from functools import partial from textwrap import dedent import typing -from typing import Any, Callable, FrozenSet, Iterable, Sequence, Type, Union, cast +from typing import Any, Callable, Dict, FrozenSet, Iterable, Sequence, Type, Union, cast import warnings import numpy as np @@ -383,8 +383,8 @@ def _get_index() -> Index: result = Series(data=values, index=_get_index(), name=self._selection_name) return self._reindex_output(result) - def _aggregate_named(self, func, *args, **kwargs) -> dict: - result = {} # type: dict + def _aggregate_named(self, func, *args, **kwargs) -> Dict: + result = {} # type: Dict for name, group in self: group.name = name From 7fcda0673859b90713d8e2d6e242bba2c3159ed3 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Sun, 17 Nov 2019 07:40:45 -0800 Subject: [PATCH 6/6] update type syntax --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index b7295ed2d4b0e..3e578b4878312 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -384,7 +384,7 @@ def _get_index() -> Index: return self._reindex_output(result) def _aggregate_named(self, func, *args, **kwargs) -> Dict: - result = {} # type: Dict + result: Dict = {} for name, group in self: group.name = name