
Commit 7f2f70c  (1 parent: 83b216c)

ENH: refactoring, generic GroupBy.apply with inference

File tree: 3 files changed, +151 -48 lines

    pandas/core/groupby.py         +146 -45
    pandas/tests/test_groupby.py     +3  -2
    scripts/groupby_test.py          +2  -1
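
The headline change is a generic GroupBy.apply: the supplied function is called once per group and the results are combined, with the output shape inferred from what the function returns. A minimal usage sketch, not taken from the commit (the frame, the modulo grouper, and the lambda are illustrative; the combining behaviour follows the _wrap_applied_output hooks added below):

    >>> import numpy as np
    >>> from pandas import DataFrame
    >>> df = DataFrame(np.random.randn(6, 2), columns=['A', 'B'])
    >>> grouped = df.groupby(lambda i: i % 2)
    >>> means = grouped.apply(lambda g: g.mean())  # one Series per group -> DataFrame of group means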

pandas/core/groupby.py  (+146, -45)
@@ -10,6 +10,7 @@
 from pandas.core.internals import BlockManager
 from pandas.core.series import Series
 from pandas.core.panel import WidePanel
+from pandas.util.decorators import cache_readonly
 import pandas._tseries as _tseries
 
 
@@ -61,7 +62,8 @@ def name(self):
         else:
             return self._name
 
-    def _get_obj_with_exclusions(self):
+    @property
+    def _obj_with_exclusions(self):
         return self.obj
 
     @property
@@ -83,14 +85,14 @@ def __getattribute__(self, attr):
     def _make_wrapper(self, name):
         f = getattr(self.obj, name)
         if not isinstance(f, types.MethodType):
-            return self.aggregate(lambda self: getattr(self, name))
+            return self.apply(lambda self: getattr(self, name))
 
         f = getattr(type(self.obj), name)
 
         def wrapper(*args, **kwargs):
             def curried(x):
                 return f(x, *args, **kwargs)
-            return self.aggregate(curried)
+            return self.apply(curried)
 
         return wrapper
 
@@ -112,7 +114,7 @@ def __iter__(self):
 
         Returns
         -------
-        Generator yielding sequence of (groupName, subsetted object)
+        Generator yielding sequence of (name, subsetted object)
         for each group
         """
         if len(self.groupings) == 1:
@@ -131,12 +133,12 @@ def __iter__(self):
 
     def _multi_iter(self):
         tipo = type(self.obj)
-        if isinstance(self.obj, DataFrame):
-            data = self.obj
-        elif isinstance(self.obj, NDFrame):
+        data = self.obj
+        if (isinstance(self.obj, NDFrame) and
+            not isinstance(self.obj, DataFrame)):
             data = self.obj._data
-        else:
-            data = self.obj
+        elif isinstance(self.obj, Series):
+            tipo = Series
 
         def flatten(gen, level=0):
             ids = self.groupings[level].ids
@@ -154,6 +156,12 @@ def flatten(gen, level=0):
         for cats, data in flatten(gen):
             yield cats + (data,)
 
+    def apply(self, func):
+        """
+        Apply function, combine results together
+        """
+        return self._python_apply_general(func)
+
     def aggregate(self, func):
         raise NotImplementedError
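
In rough terms, the new apply makes a single pass over the groups using the iteration protocol documented above and hands the collected keys and per-group results to a subclass hook that decides how to combine them. A sketch of that flow (an illustration of the intent, not the actual call chain):

    >>> def apply_sketch(grouped, func):
    ...     keys, pieces = [], []
    ...     for name, group in grouped:   # one pass over the groups
    ...         keys.append(name)
    ...         pieces.append(func(group))
    ...     return keys, pieces           # _wrap_applied_output decides how to combine these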

@@ -243,7 +251,7 @@ def _doit(reschunk, ctchunk, gen, shape_axis=0):
         output = np.empty(group_shape + stride_shape,
                           dtype=float)
         output.fill(np.nan)
-        obj = self._get_obj_with_exclusions()
+        obj = self._obj_with_exclusions
         _doit(output, counts, gen_factory(obj),
               shape_axis=self.axis)
 
@@ -267,6 +275,37 @@ def _doit(reschunk, ctchunk, gen, shape_axis=0):
 
         return self._wrap_aggregated_output(output, mask)
 
+    def _python_apply_general(self, arg):
+        result_keys = []
+        result_values = []
+
+        key_as_tuple = len(self.groupings) > 1
+
+        not_indexed_same = False
+
+        for data in self:
+            if key_as_tuple:
+                key = data[:-1]
+            else:
+                key = data[0]
+
+            group = data[-1]
+            group.name = key
+
+            res = arg(group)
+
+            if not _is_indexed_like(res, group):
+                not_indexed_same = True
+
+            result_keys.append(key)
+            result_values.append(res)
+
+        return self._wrap_applied_output(result_keys, result_values,
+                                         not_indexed_same=not_indexed_same)
+
+    def _wrap_applied_output(self, *args, **kwargs):
+        raise NotImplementedError
+
     @property
     def _generator_factory(self):
         labels = [ping.labels for ping in self.groupings]
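
The inference rests on the not_indexed_same flag computed here: if each per-group result is indexed like its group, the pieces can be glued back onto the original labels; if the function reduces or reindexes, the output is keyed by group instead. A hedged Series illustration (the data and the dict grouper are made up, and dict mappers are assumed to be accepted as they are in scripts/groupby_test.py; the expected shapes follow this method together with the SeriesGroupBy wrapper further down):

    >>> import numpy as np
    >>> from pandas import Series
    >>> s = Series(np.arange(6.), index=['a', 'b', 'c', 'd', 'e', 'f'])
    >>> mapping = {'a': 'x', 'b': 'x', 'c': 'x', 'd': 'y', 'e': 'y', 'f': 'y'}
    >>> grouped = s.groupby(mapping)
    >>> grouped.apply(lambda g: g - g.mean())  # indexed like each group -> Series on the original labels
    >>> grouped.apply(np.mean)                 # scalar per group -> Series keyed by 'x' and 'y'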
@@ -282,6 +321,12 @@ def _generator_factory(self):
         return lambda obj: generate_groups(obj, labels, shape, axis=axis,
                                            factory=factory)
 
+def _is_indexed_like(obj, other):
+    if isinstance(obj, Series):
+        return obj.index.equals(other.index)
+    elif isinstance(obj, DataFrame):
+        return obj._indexed_same(other)
+
 class Grouping(object):
 
     def __init__(self, index, grouper=None, name=None, level=None):
@@ -470,6 +515,29 @@ def _wrap_aggregated_output(self, output, mask):
         name_list = self._get_names()
         return Series(output, index=name_list[0][1])
 
+    def _wrap_applied_output(self, keys, values, not_indexed_same=False):
+        if len(keys) == 0:
+            return Series([])
+
+        if isinstance(values[0], Series):
+            if not_indexed_same:
+                data_dict = dict(zip(keys, values))
+                result = DataFrame(data_dict).T
+                if len(self.groupings) > 1:
+                    result.index = MultiIndex.from_tuples(keys)
+                return result
+            else:
+                cat_values = np.concatenate([x.values for x in values])
+                cat_index = np.concatenate([np.asarray(x.index)
+                                            for x in values])
+                return Series(cat_values, index=cat_index)
+        else:
+            if len(self.groupings) > 1:
+                index = MultiIndex.from_tuples(keys)
+                return Series(values, index)
+            else:
+                return Series(values, keys)
+
     def _aggregate_multiple_funcs(self, arg):
         if not isinstance(arg, dict):
             arg = dict((func.__name__, func) for func in arg)
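
The not_indexed_same branch above is the genuinely new shape for SeriesGroupBy: per-group Series that do not share their group's labels are collected into a dict and come back as a DataFrame with one row per group key. A hedged example in the same spirit as the sketches above (data and labels are illustrative):

    >>> import numpy as np
    >>> from pandas import Series
    >>> s = Series(np.arange(6.), index=['a', 'b', 'c', 'd', 'e', 'f'])
    >>> grouped = s.groupby({'a': 'x', 'b': 'x', 'c': 'x', 'd': 'y', 'e': 'y', 'f': 'y'})
    >>> # each result is indexed by ['min', 'max'] rather than by the group's labels,
    >>> # so the pieces should come back as a DataFrame with rows 'x' and 'y'
    >>> grouped.apply(lambda g: Series([g.min(), g.max()], index=['min', 'max']))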
@@ -498,13 +566,13 @@ def _aggregate_named(self, arg):
 
         for name in self.primary:
             grp = self.get_group(name)
-            grp.groupName = name
+            grp.name = name
             output = arg(grp)
             result[name] = output
 
         return result
 
-    def transform(self, applyfunc):
+    def transform(self, func):
         """
         For given Series, group index by given mapper function or dict, take
         the sub-Series (reindex) for this group and call apply(applyfunc)
@@ -527,8 +595,8 @@ def transform(self, applyfunc):
 
         Example
         -------
-        series.fgroupby(lambda x: mapping[x],
-                        lambda x: (x - mean(x)) / std(x))
+        series.transform(lambda x: mapping[x],
+                         lambda x: (x - x.mean()) / x.std())
 
         Returns
         -------
@@ -538,9 +606,8 @@ def transform(self, applyfunc):
 
         for name, group in self:
             # XXX
-            group.groupName = name
-            res = applyfunc(group)
-
+            group.name = name
+            res = func(group)
             indexer, _ = self.obj.index.get_indexer(group.index)
             np.put(result, indexer, res)
 
@@ -600,7 +667,8 @@ def _iterate_slices(self):
 
             yield val, slicer(val)
 
-    def _get_obj_with_exclusions(self):
+    @cache_readonly
+    def _obj_with_exclusions(self):
         if len(self.exclusions) > 0:
             return self.obj.drop(self.exclusions, axis=1)
         else:
@@ -641,7 +709,7 @@ def aggregate(self, arg):
     def _aggregate_generic(self, agger, axis=0):
         result = {}
 
-        obj = self._get_obj_with_exclusions()
+        obj = self._obj_with_exclusions
 
         try:
             for name in self.primary:
@@ -668,7 +736,7 @@ def _aggregate_generic(self, agger, axis=0):
     def _aggregate_item_by_item(self, agger):
         # only for axis==0
 
-        obj = self._get_obj_with_exclusions()
+        obj = self._obj_with_exclusions
 
         result = {}
         cannot_agg = []
@@ -694,6 +762,30 @@ def _wrap_aggregated_output(self, output, mask):
 
         return result
 
+    def _wrap_applied_output(self, keys, values, not_indexed_same=False):
+        if len(keys) == 0:
+            # XXX
+            return DataFrame({})
+
+        if isinstance(values[0], DataFrame):
+            return _concat_frames(values)
+        else:
+            if len(self.groupings) > 1:
+                keys = MultiIndex.from_tuples(keys)
+
+            # obj = self._obj_with_exclusions
+
+            if self.axis == 0:
+                stacked_values = np.vstack(values)
+                columns = values[0].index
+                index = keys
+            else:
+                stacked_values = np.vstack(values)
+                index = values[0].index
+                columns = keys
+
+            return DataFrame(stacked_values, index=index, columns=columns)
+
     def transform(self, func):
         """
         For given DataFrame, group index by given mapper function or dict, take
@@ -715,8 +807,8 @@ def transform(self, func):
 
         Note
         ----
-        Each subframe is endowed the attribute 'groupName' in case
-        you need to know which group you are working on.
+        Each subframe is endowed the attribute 'name' in case you need to know
+        which group you are working on.
 
         Example
         --------
@@ -725,41 +817,49 @@ def transform(self, func):
         """
         applied = []
 
-        obj = self._get_obj_with_exclusions()
-        for val, inds in self.primary.indices.iteritems():
-            subframe = obj.take(inds, axis=self.axis)
-            subframe.groupName = val
+        obj = self._obj_with_exclusions
+        for name, group in self:
+            group.name = name
 
             try:
-                res = subframe.apply(func, axis=self.axis)
+                res = group.apply(func, axis=self.axis)
             except Exception: # pragma: no cover
-                res = func(subframe)
+                res = func(group)
 
             # broadcasting
             if isinstance(res, Series):
                 if res.index is obj.index:
-                    subframe.T.values[:] = res
+                    group.T.values[:] = res
                 else:
-                    subframe.values[:] = res
+                    group.values[:] = res
 
-                applied.append(subframe)
+                applied.append(group)
             else:
                 applied.append(res)
 
-        if self.axis == 0:
-            all_index = [np.asarray(x.index) for x in applied]
-            new_index = Index(np.concatenate(all_index))
-            new_columns = obj.columns
-        else:
-            all_columns = [np.asarray(x.columns) for x in applied]
-            new_columns = Index(np.concatenate(all_columns))
-            new_index = obj.index
+        return _concat_frames(applied, obj.index, obj.columns,
+                              axis=self.axis)
 
-        new_values = np.concatenate([x.values for x in applied],
-                                    axis=self.axis)
-        result = DataFrame(new_values, index=new_index, columns=new_columns)
-        return result.reindex(index=obj.index, columns=obj.columns)
+def _concat_frames(frames, index=None, columns=None, axis=0):
+    if axis == 0:
+        all_index = [np.asarray(x.index) for x in frames]
+        new_index = Index(np.concatenate(all_index))
 
+        if columns is None:
+            new_columns = frames[0].columns
+        else:
+            new_columns = columns
+    else:
+        all_columns = [np.asarray(x.columns) for x in frames]
+        new_columns = Index(np.concatenate(all_columns))
+        if index is None:
+            new_index = frames[0].index
+        else:
+            new_index = index
+
+    new_values = np.concatenate([x.values for x in frames], axis=axis)
+    result = DataFrame(new_values, index=new_index, columns=new_columns)
+    return result.reindex(index=index, columns=columns)
 
 class WidePanelGroupBy(GroupBy):
 
@@ -788,7 +888,7 @@ def aggregate(self, func):
     def _aggregate_generic(self, agger, axis=0):
         result = {}
 
-        obj = self._get_obj_with_exclusions()
+        obj = self._obj_with_exclusions
 
         for name in self.primary:
             data = self.get_group(name, obj=obj)
@@ -804,7 +904,8 @@ def _aggregate_generic(self, agger, axis=0):
 
         return result
 
-class LongPanelGroupBy(GroupBy):
+
+class NDArrayGroupBy(GroupBy):
     pass
 
#-------------------------------------------------------------------------------
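
Taken together, the DataFrameGroupBy._wrap_applied_output hook and the _concat_frames helper above mean that a frame-valued function no longer has to preserve each group's shape: the per-group frames are simply concatenated along the grouped axis. A hedged sketch for axis=0, mirroring the earlier illustrative data:

    >>> import numpy as np
    >>> from pandas import DataFrame
    >>> df = DataFrame(np.random.randn(6, 2), columns=['A', 'B'])
    >>> grouped = df.groupby(lambda i: i % 2)
    >>> grouped.apply(lambda g: g - g.mean())  # same-shaped frame per group, stacked back together
    >>> grouped.apply(lambda g: g[:2])         # a smaller frame per group works too: first two rows of each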

pandas/tests/test_groupby.py  (+3, -2)
@@ -111,7 +111,7 @@ def test_basic(self):
             1 : 20,
             2 : 30
         }
-        agged = grouped.agg(lambda x: group_constants[x.groupName] + x.mean())
+        agged = grouped.agg(lambda x: group_constants[x.name] + x.mean())
         self.assertEqual(agged[1], 21)
 
         # corner cases
@@ -180,8 +180,9 @@ def test_transform(self):
     def test_dispatch_transform(self):
         df = self.tsframe[::5].reindex(self.tsframe.index)
 
-        filled = df.groupby(lambda x: x.month).fillna(method='pad')
+        grouped = df.groupby(lambda x: x.month)
 
+        filled = grouped.fillna(method='pad')
         fillit = lambda x: x.fillna(method='pad')
         expected = df.groupby(lambda x: x.month).transform(fillit)
         assert_frame_equal(filled, expected)
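
This test reflects the _make_wrapper change in groupby.py: a frame method looked up on the grouped object is curried and routed through apply instead of aggregate. In the test's own terms (a sketch, with df being the fixture above), the dispatched call is expected to match the explicit spelling:

    >>> grouped = df.groupby(lambda x: x.month)
    >>> grouped.fillna(method='pad')                     # dispatched via _make_wrapper -> apply
    >>> grouped.apply(lambda g: g.fillna(method='pad'))  # the curried call it reduces to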

scripts/groupby_test.py  (+2, -1)
@@ -86,12 +86,13 @@
 # print 'got'
 # print result
 
-# tm.N = 10000
+tm.N = 10000
 
 mapping = {'A': 0, 'C': 1, 'B': 0, 'D': 1}
 tf = lambda x: x - x.mean()
 
 df = tm.makeTimeDataFrame()
+ts = df['A']
 
 # grouped = df.groupby(lambda x: x.strftime('%m/%y'))
 grouped = df.groupby(mapping, axis=1)
