From 022ab401848c231fa744bc744679cf1d1793195f Mon Sep 17 00:00:00 2001 From: Andy Hayden Date: Fri, 7 Mar 2014 17:16:32 -0800 Subject: [PATCH] FIX filter selects selected columns TST for selected groupby add resample ohlc and filter --- pandas/core/groupby.py | 2 +- pandas/tests/test_groupby.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 1bdb3973ee92c..86590d2319447 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -2529,7 +2529,7 @@ def filter(self, func, dropna=True, *args, **kwargs): indices = [] - obj = self._obj_with_exclusions + obj = self._selected_obj gen = self.grouper.get_iterator(obj, axis=self.axis) fast_path, slow_path = self._define_paths(func, *args, **kwargs) diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index 3b613bb1705a3..adca8389b8939 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -3438,6 +3438,13 @@ def test_filter_and_transform_with_non_unique_string_index(self): actual = grouped_df.pid.transform(len) assert_series_equal(actual, expected) + def test_filter_has_access_to_grouped_cols(self): + df = DataFrame([[1, 2], [1, 3], [5, 6]], columns=['A', 'B']) + g = df.groupby('A') + # previously didn't have access to col A #???? + filt = g.filter(lambda x: x['A'].sum() == 2) + assert_frame_equal(filt, df.iloc[[0, 1]]) + def test_index_label_overlaps_location(self): # checking we don't have any label/location confusion in the # the wake of GH5375 @@ -3486,7 +3493,8 @@ def test_groupby_selection_with_methods(self): 'idxmin', 'idxmax', 'ffill', 'bfill', 'pct_change', - 'tshift' + 'tshift', + #'ohlc' ] for m in methods: @@ -3501,8 +3509,11 @@ def test_groupby_selection_with_methods(self): g_exp.apply(lambda x: x.sum())) assert_frame_equal(g.resample('D'), g_exp.resample('D')) + assert_frame_equal(g.resample('D', how='ohlc'), + g_exp.resample('D', how='ohlc')) - + assert_frame_equal(g.filter(lambda x: len(x) == 3), + g_exp.filter(lambda x: len(x) == 3)) def test_groupby_whitelist(self): from string import ascii_lowercase