Skip to content

Commit ad9cffd

Browse files
committed
BUG: enable numpy reductions to be passed to how parameter in convert, TimeGrouper tests, close #1045
1 parent 63207b8 commit ad9cffd

File tree

2 files changed

+124
-59
lines changed

2 files changed

+124
-59
lines changed

pandas/core/groupby.py

+97-59
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,6 @@ def generate_bins_generic(values, binner, closed, label):
764764

765765
return bins, labels
766766

767-
class CustomGrouper:
768-
pass
769767

770768
def _generate_time_binner(dtindex, offset,
771769
begin=None, end=None, nperiods=None):
@@ -789,7 +787,90 @@ def _generate_time_binner(dtindex, offset,
789787

790788
return DatetimeIndex(freq=offset, start=first, end=last, periods=nperiods)
791789

792-
class TimeGrouper(Grouper, CustomGrouper):
790+
791+
class BinGrouper(Grouper):
    """
    Grouper whose groups are contiguous slices of the data, delimited by
    positional bin edges

    Parameters
    ----------
    bins : sequence of int
        Positions at which one group ends and the next begins; used both
        as slice boundaries in ``get_iterator`` and as the bin argument of
        the Cython aggregation routines
    binlabels : sequence
        Label reported for each group
    """

    def __init__(self, bins, binlabels):
        self.bins = bins
        self.binlabels = binlabels

    def get_iterator(self, data, axis=0):
        """
        Groupby iterator

        Parameters
        ----------
        data : sliceable object
            Sliced positionally as ``data[start:edge]``
        axis : int, default 0
            Only axis 0 is supported

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group

        Raises
        ------
        NotImplementedError
            If ``axis == 1``. Because this is a generator, the error
            surfaces on first iteration rather than at call time.
        """
        if axis == 1:
            raise NotImplementedError

        start = 0
        for edge, label in zip(self.bins, self.binlabels):
            yield label, data[start:edge]
            start = edge

        # Everything past the last consumed edge is reported under the last
        # label. Slicing from ``start`` instead of the loop variable keeps
        # this well-defined when the loop body never ran (``start`` is then
        # 0), and the length check avoids indexing an empty label sequence.
        if len(self.binlabels) > 0:
            yield self.binlabels[-1], data[start:]

    @cache_readonly
    def ngroups(self):
        """Number of groups, taken as the number of bin labels"""
        return len(self.binlabels)

    #----------------------------------------------------------------------
    # cython aggregation

    # Maps the ``how`` keyword to the bin-based Cython routine. 'std' shares
    # the variance routine; presumably the square root is applied through
    # ``_cython_transforms`` (not defined in this class) -- TODO confirm
    _cython_functions = {
        'add': lib.group_add_bin,
        'mean': lib.group_mean_bin,
        'var': lib.group_var_bin,
        'std': lib.group_var_bin,
        'ohlc': lib.group_ohlc
    }

    # Number of output columns produced per input column; 1 unless listed
    _cython_arity = {
        'ohlc': 4,  # OHLC
    }

    def aggregate(self, values, how):
        """
        Aggregate ``values`` bin by bin with one of the Cython routines

        Parameters
        ----------
        values : ndarray, 1- or 2-dimensional
            Converted to float64 before aggregation
        how : str
            Key into ``_cython_functions``

        Returns
        -------
        (result, names) : tuple
            ``result`` holds the aggregated rows for bins with a positive
            count; ``names`` comes from ``_name_functions[how]()`` when
            that entry exists and is None otherwise

        Raises
        ------
        KeyError
            If ``how`` is not a key of ``_cython_functions``
        """
        values = com._ensure_float64(values)

        agg_func = self._cython_functions[how]
        arity = self._cython_arity.get(how, 1)

        if values.ndim == 1:
            # the Cython routines are handed a 2-D array, so promote 1-D
            # input to a single column and undo that at the end
            squeeze = True
            values = values[:, None]
            out_shape = (self.ngroups, arity)
        else:
            squeeze = False
            out_shape = (self.ngroups, values.shape[1] * arity)

        # ``_cython_transforms`` is not defined in this class; presumably
        # inherited from Grouper. Identity when ``how`` has no transform.
        trans_func = self._cython_transforms.get(how, lambda x: x)

        # will be filled in Cython function
        result = np.empty(out_shape, dtype=np.float64)
        counts = np.zeros(self.ngroups, dtype=np.int32)

        agg_func(result, counts, values, self.bins)
        result = trans_func(result)

        # keep only the rows whose bin received at least one observation
        result = lib.row_bool_subset(result, counts > 0)

        if squeeze:
            # NOTE(review): squeeze() drops every length-1 axis, so a single
            # surviving row also loses its row axis -- confirm callers cope
            result = result.squeeze()

        # ``_name_functions`` is not defined in this class; presumably
        # inherited from Grouper
        if how in self._name_functions:
            # TODO
            names = self._name_functions[how]()
        else:
            names = None

        return result, names
870+
871+
872+
873+
class TimeGrouper(BinGrouper):
793874
"""
794875
Custom groupby class for time-interval grouping
795876
@@ -860,6 +941,14 @@ def set_obj(self, obj):
860941
self.bins = bins
861942
self.binlabels = labels.view('M8[us]')
862943

944+
@property
945+
def names(self):
946+
return [self.obj.index.name]
947+
948+
@property
949+
def levels(self):
950+
return [self.binlabels]
951+
863952
@cache_readonly
864953
def ngroups(self):
865954
return len(self.binlabels)
@@ -873,56 +962,6 @@ def agg_series(self, obj, func):
873962
grouper = lib.SeriesBinGrouper(obj, func, self.bins, dummy)
874963
return grouper.get_result()
875964

876-
#----------------------------------------------------------------------
877-
# cython aggregation
878-
879-
_cython_functions = {
880-
'add' : lib.group_add_bin,
881-
'mean' : lib.group_mean_bin,
882-
'var' : lib.group_var_bin,
883-
'std' : lib.group_var_bin,
884-
'ohlc' : lib.group_ohlc
885-
}
886-
887-
_cython_arity = {
888-
'ohlc' : 4, # OHLC
889-
}
890-
891-
def aggregate(self, values, how):
892-
values = com._ensure_float64(values)
893-
894-
agg_func = self._cython_functions[how]
895-
arity = self._cython_arity.get(how, 1)
896-
897-
if values.ndim == 1:
898-
squeeze = True
899-
values = values[:, None]
900-
out_shape = (self.ngroups, arity)
901-
else:
902-
squeeze = False
903-
out_shape = (self.ngroups, values.shape[1] * arity)
904-
905-
trans_func = self._cython_transforms.get(how, lambda x: x)
906-
907-
# will be filled in Cython function
908-
result = np.empty(out_shape, dtype=np.float64)
909-
counts = np.zeros(self.ngroups, dtype=np.int32)
910-
911-
agg_func(result, counts, values, self.bins)
912-
result = trans_func(result)
913-
914-
result = lib.row_bool_subset(result, counts > 0)
915-
916-
if squeeze:
917-
result = result.squeeze()
918-
919-
if how in self._name_functions:
920-
# TODO
921-
names = self._name_functions[how]()
922-
else:
923-
names = None
924-
925-
return result, names
926965

927966
class Grouping(object):
928967
"""
@@ -962,7 +1001,7 @@ def __init__(self, index, grouper=None, name=None, level=None,
9621001
self._was_factor = False
9631002

9641003
# did we pass a custom grouper object? Do nothing
965-
if isinstance(grouper, CustomGrouper):
1004+
if isinstance(grouper, Grouper):
9661005
return
9671006

9681007
if level is not None:
@@ -1063,7 +1102,7 @@ def _get_grouper(obj, key=None, axis=0, level=None, sort=True):
10631102
level = None
10641103
key = group_axis
10651104

1066-
if isinstance(key, CustomGrouper):
1105+
if isinstance(key, Grouper):
10671106
key.set_obj(obj)
10681107
return key, []
10691108

@@ -1263,10 +1302,9 @@ def _get_index():
12631302
def _aggregate_named(self, func, *args, **kwargs):
12641303
result = {}
12651304

1266-
for name in self.grouper:
1267-
grp = self.get_group(name)
1268-
grp.name = name
1269-
output = func(grp, *args, **kwargs)
1305+
for name, group in self:
1306+
group.name = name
1307+
output = func(group, *args, **kwargs)
12701308
if isinstance(output, np.ndarray):
12711309
raise Exception('Must produce aggregated value')
12721310
result[name] = output

pandas/tests/test_timeseries.py

+27
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,33 @@ def test_datetimeindex_union_join_empty(self):
918918

919919
# TODO: test merge & concat with datetime64 block
920920

921+
class TestTimeGrouper(unittest.TestCase):
    """Compare TimeGrouper / convert results against grouping by year"""

    def setUp(self):
        # random series over 1000 periods starting 1/1/2000
        dates = date_range('1/1/2000', periods=1000)
        self.ts = Series(np.random.randn(1000), index=dates)

    def test_apply(self):
        """apply through a TimeGrouper matches apply through a year key"""
        def last_three_ordered(x):
            # last three entries of the ordered group
            return x.order()[-3:]

        by_time = self.ts.groupby(TimeGrouper('A', label='right',
                                              closed='right'))
        applied = by_time.apply(last_three_ordered)

        by_year = self.ts.groupby(lambda x: x.year)
        expected = by_year.apply(last_three_ordered)

        # the outer (group key) level differs between the two groupings,
        # so drop it from both before comparing
        applied.index = applied.index.droplevel(0)
        expected.index = expected.index.droplevel(0)
        assert_series_equal(applied, expected)

    def test_numpy_reduction(self):
        """a numpy reduction passed as ``how`` matches a groupby agg"""
        reducer = np.prod

        result = self.ts.convert('A', how=reducer, closed='right')
        expected = self.ts.groupby(lambda x: x.year).agg(reducer)

        # only the values are compared; align the index labels first
        expected.index = result.index

        assert_series_equal(result, expected)
921948

922949
class TestNewOffsets(unittest.TestCase):
923950

0 commit comments

Comments
 (0)