Skip to content

Commit 5d4ce99

Browse files
committed
Merge pull request #7974 from dsm054/allow_nameless_callables
BUG: Allow __name__less callables as groupby hows (GH7929)
2 parents 4eb5e51 + e7c2e93 commit 5d4ce99

File tree

6 files changed

+90
-7
lines changed

6 files changed

+90
-7
lines changed

doc/source/v0.15.0.txt

+2-3
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,8 @@ Bug Fixes
389389
- Bug in ``GroupBy.transform()`` where int groups with a transform that
390390
didn't preserve the index were incorrectly truncated (:issue:`7972`).
391391

392-
393-
394-
392+
- Bug in ``groupby`` where callable objects without a ``__name__`` attribute would take the wrong path,
393+
and produce a ``DataFrame`` instead of a ``Series`` (:issue:`7929`).
395394

396395

397396
- Bug in ``read_html`` where the ``infer_types`` argument forced coercion of

pandas/core/common.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import csv
1010
import types
1111
from datetime import datetime, timedelta
12+
from functools import partial
1213

1314
from numpy.lib.format import read_array, write_array
1415
import numpy as np
@@ -2432,7 +2433,22 @@ def _is_sequence(x):
24322433
except (TypeError, AttributeError):
24332434
return False
24342435

2435-
2436+
def _get_callable_name(obj):
    """
    Return a best-effort name for a callable, or None if it has none.

    Lookup order:

    1. ``obj.__name__`` when that attribute exists.
    2. For ``functools.partial`` objects, the name of the wrapped
       ``func``, resolved recursively so that nested partials work.
    3. The class name of ``obj`` for any other callable, e.g. an
       instance of a class that defines ``__call__``.

    Parameters
    ----------
    obj : object
        Usually a callable; non-callables are tolerated.

    Returns
    -------
    name : str or None
        None (rather than the empty string) when no name could be
        found, which allows callers to distinguish between "no name"
        and a name of ''.
    """
    # typical case has name
    if hasattr(obj, '__name__'):
        return obj.__name__
    # partial objects carry no __name__ of their own; defer to the
    # function they wrap
    if isinstance(obj, partial):
        return _get_callable_name(obj.func)
    # fall back to class name
    if callable(obj):
        return obj.__class__.__name__
    # everything failed (probably because the argument
    # wasn't actually callable)
    return None
2451+
24362452
_string_dtypes = frozenset(map(_get_dtype_from_object, (compat.binary_type,
24372453
compat.text_type)))
24382454

pandas/core/groupby.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,8 @@ def apply(self, f, data, axis=0):
12211221
group_keys = self._get_group_keys()
12221222

12231223
# oh boy
1224-
if (f.__name__ not in _plotting_methods and
1224+
f_name = com._get_callable_name(f)
1225+
if (f_name not in _plotting_methods and
12251226
hasattr(splitter, 'fast_apply') and axis == 0):
12261227
try:
12271228
values, mutated = splitter.fast_apply(f, group_keys)
@@ -2185,11 +2186,11 @@ def _aggregate_multiple_funcs(self, arg):
21852186
if isinstance(f, compat.string_types):
21862187
columns.append(f)
21872188
else:
2188-
columns.append(f.__name__)
2189+
# protect against callables without names
2190+
columns.append(com._get_callable_name(f))
21892191
arg = lzip(columns, arg)
21902192

21912193
results = {}
2192-
21932194
for name, func in arg:
21942195
if name in results:
21952196
raise SpecificationError('Function names must be unique, '

pandas/tests/test_common.py

+20
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,26 @@ def __getitem__(self):
3838

3939
assert(not is_seq(A()))
4040

41+
def test_get_callable_name():
    # Exercise com._get_callable_name on plain functions, lambdas,
    # (nested) partials, callable instances and a non-callable.
    from functools import partial
    getname = com._get_callable_name

    def fn(x):
        return x
    lambda_ = lambda x: x
    part1 = partial(fn)
    part2 = partial(part1)

    class somecall(object):
        # accept the argument that is returned, so that calling an
        # instance does not hit an undefined name
        def __call__(self, x):
            return x

    assert getname(fn) == 'fn'
    # pin the exact name instead of only checking truthiness
    assert getname(lambda_) == '<lambda>'
    assert getname(part1) == 'fn'
    assert getname(part2) == 'fn'
    assert getname(somecall()) == 'somecall'
    assert getname(1) is None
60+
4161

4262
def test_notnull():
4363
assert notnull(1.)

pandas/tests/test_groupby.py

+21
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pandas.core.panel import Panel
2626
from pandas.tools.merge import concat
2727
from collections import defaultdict
28+
from functools import partial
2829
import pandas.core.common as com
2930
import numpy as np
3031

@@ -2910,6 +2911,24 @@ def test_multi_function_flexible_mix(self):
29102911
assert_frame_equal(result, expected)
29112912
assert_frame_equal(result2, expected)
29122913

2914+
def test_agg_callables(self):
    # GH 7929: each kind of callable below (builtin, numpy function,
    # lambdas, partial, callable instance) is checked against the
    # result of aggregating with the builtin ``sum``
    class fn_class(object):
        def __call__(self, x):
            return sum(x)

    df = DataFrame({'foo' : [1,2], 'bar' :[3,4]}).astype(np.int64)
    expected = df.groupby("foo").agg(sum)

    equiv_callables = [
        sum,
        np.sum,
        lambda x: sum(x),
        lambda x: x.sum(),
        partial(sum),
        fn_class(),
    ]
    for func in equiv_callables:
        result = df.groupby('foo').agg(func)
        assert_frame_equal(result, expected)
2931+
29132932
def test_set_group_name(self):
29142933
def f(group):
29152934
assert group.name is not None
@@ -4530,6 +4549,8 @@ def test_transform_doesnt_clobber_ints(self):
45304549
tm.assert_frame_equal(result, expected)
45314550

45324551

4552+
4553+
45334554
def assert_fp_equal(a, b):
45344555
assert (np.abs(a - b) < 1e-12).all()
45354556

pandas/tseries/tests/test_resample.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pylint: disable=E1101
22

33
from datetime import datetime, timedelta
4+
from functools import partial
45

56
from pandas.compat import range, lrange, zip, product
67
import numpy as np
@@ -140,6 +141,30 @@ def _ohlc(group):
140141
exc.args += ('how=%s' % arg,)
141142
raise
142143

144+
def test_resample_how_callables(self):
    # GH 7929: resampling with a plain function is compared against
    # resampling with a lambda, partials and a callable instance
    values = np.arange(5, dtype=np.int64)
    index = pd.DatetimeIndex(start='2014-01-01', periods=len(values), freq='d')
    df = pd.DataFrame({"A": values, "B": values}, index=index)

    def fn(x, a=1):
        return str(type(x))

    class fn_class:
        def __call__(self, x):
            return str(type(x))

    df_standard = df.resample("M", how=fn)

    # compute every alternative first, then compare, so the order of
    # resample calls and assertions matches the reference run
    alternatives = [
        lambda x: str(type(x)),
        partial(fn),
        partial(fn, a=2),
        fn_class(),
    ]
    resampled = [df.resample("M", how=how) for how in alternatives]
    for frame in resampled:
        assert_frame_equal(df_standard, frame)
167+
143168
def test_resample_basic_from_daily(self):
144169
# from daily
145170
dti = DatetimeIndex(
@@ -765,6 +790,7 @@ def test_resample_timegrouper(self):
765790
assert_frame_equal(result, expected)
766791

767792

793+
768794
def _simple_ts(start, end, freq='D'):
769795
rng = date_range(start, end, freq=freq)
770796
return Series(np.random.randn(len(rng)), index=rng)

0 commit comments

Comments
 (0)