@@ -251,7 +251,7 @@ def test_groupby_levels_and_columns(self):
251251 by_columns .columns = pd .Index (by_columns .columns , dtype = np .int64 )
252252 tm .assert_frame_equal (by_levels , by_columns )
253253
254- def test_groupby_categorical_index_and_columns (self ):
254+ def test_groupby_categorical_index_and_columns (self , observed ):
255255 # GH18432
256256 columns = ['A' , 'B' , 'A' , 'B' ]
257257 categories = ['B' , 'A' ]
@@ -260,17 +260,26 @@ def test_groupby_categorical_index_and_columns(self):
260260 categories = categories ,
261261 ordered = True )
262262 df = DataFrame (data = data , columns = cat_columns )
263- result = df .groupby (axis = 1 , level = 0 ).sum ()
263+ result = df .groupby (axis = 1 , level = 0 , observed = observed ).sum ()
264264 expected_data = 2 * np .ones ((5 , 2 ), int )
265- expected_columns = CategoricalIndex (categories ,
266- categories = categories ,
267- ordered = True )
265+
266+ if observed :
267+ # when observed is False the result undergoes a reindex to the
268+ # categories (handled by the else branch below); when observed is
269+ # True that reindex does not happen, so adjust the expected columns
270+ expected_columns = CategoricalIndex (['A' , 'B' ],
271+ categories = categories ,
272+ ordered = True )
273+ else :
274+ expected_columns = CategoricalIndex (categories ,
275+ categories = categories ,
276+ ordered = True )
268277 expected = DataFrame (data = expected_data , columns = expected_columns )
269278 assert_frame_equal (result , expected )
270279
271280 # test transposed version
272281 df = DataFrame (data .T , index = cat_columns )
273- result = df .groupby (axis = 0 , level = 0 ).sum ()
282+ result = df .groupby (axis = 0 , level = 0 , observed = observed ).sum ()
274283 expected = DataFrame (data = expected_data .T , index = expected_columns )
275284 assert_frame_equal (result , expected )
276285
@@ -572,11 +581,11 @@ def test_get_group(self):
572581 pytest .raises (ValueError ,
573582 lambda : g .get_group (('foo' , 'bar' , 'baz' )))
574583
575- def test_get_group_empty_bins (self ):
584+ def test_get_group_empty_bins (self , observed ):
576585
577586 d = pd .DataFrame ([3 , 1 , 7 , 6 ])
578587 bins = [0 , 5 , 10 , 15 ]
579- g = d .groupby (pd .cut (d [0 ], bins ))
588+ g = d .groupby (pd .cut (d [0 ], bins ), observed = observed )
580589
581590 # TODO: should prob allow a str of Interval work as well
582591 # IOW '(0, 5]'
0 commit comments