Skip to content

Commit f4ce38e

Browse files
committed
ENH: make sure return dtypes for nan funcs are consistent
1 parent 5852e72 commit f4ce38e

File tree

3 files changed

+56
-42
lines changed

3 files changed

+56
-42
lines changed

pandas/core/nanops.py

+37-26
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,10 @@ def nanall(values, axis=None, skipna=True):
244244
@bottleneck_switch(zero_value=0)
245245
def nansum(values, axis=None, skipna=True):
246246
values, mask, dtype, dtype_max = _get_values(values, skipna, 0)
247-
the_sum = values.sum(axis, dtype=dtype_max)
247+
dtype_sum = dtype_max
248+
if is_float_dtype(dtype):
249+
dtype_sum = dtype
250+
the_sum = values.sum(axis, dtype=dtype_sum)
248251
the_sum = _maybe_null_out(the_sum, axis, mask)
249252

250253
return _wrap_results(the_sum, dtype)
@@ -288,7 +291,7 @@ def get_median(x):
288291
return np.nan
289292
return algos.median(_values_from_object(x[mask]))
290293

291-
if values.dtype != np.float64:
294+
if not is_float_dtype(values):
292295
values = values.astype('f8')
293296
values[mask] = np.nan
294297

@@ -317,10 +320,10 @@ def get_median(x):
317320
return _wrap_results(get_median(values) if notempty else np.nan, dtype)
318321

319322

320-
def _get_counts_nanvar(mask, axis, ddof):
321-
count = _get_counts(mask, axis)
322-
323-
d = count-ddof
323+
def _get_counts_nanvar(mask, axis, ddof, dtype=float):
324+
dtype = _get_dtype(dtype)
325+
count = _get_counts(mask, axis, dtype=dtype)
326+
d = count - dtype.type(ddof)
324327

325328
# always return NaN, never inf
326329
if np.isscalar(count):
@@ -341,15 +344,19 @@ def _nanvar(values, axis=None, skipna=True, ddof=1):
341344
if is_any_int_dtype(values):
342345
values = values.astype('f8')
343346

344-
count, d = _get_counts_nanvar(mask, axis, ddof)
347+
if is_float_dtype(values):
348+
count, d = _get_counts_nanvar(mask, axis, ddof, values.dtype)
349+
else:
350+
count, d = _get_counts_nanvar(mask, axis, ddof)
345351

346352
if skipna:
347353
values = values.copy()
348354
np.putmask(values, mask, 0)
349355

350356
X = _ensure_numeric(values.sum(axis))
351357
XX = _ensure_numeric((values ** 2).sum(axis))
352-
return np.fabs((XX - X ** 2 / count) / d)
358+
result = np.fabs((XX - X * X / count) / d)
359+
return result
353360

354361
@disallow('M8')
355362
@bottleneck_switch(ddof=1)
@@ -375,9 +382,9 @@ def nansem(values, axis=None, skipna=True, ddof=1):
375382
mask = isnull(values)
376383
if not is_floating_dtype(values):
377384
values = values.astype('f8')
378-
count, _ = _get_counts_nanvar(mask, axis, ddof)
385+
count, _ = _get_counts_nanvar(mask, axis, ddof, values.dtype)
379386

380-
return np.sqrt(var)/np.sqrt(count)
387+
return np.sqrt(var) / np.sqrt(count)
381388

382389

383390
@bottleneck_switch()
@@ -467,25 +474,27 @@ def nanargmin(values, axis=None, skipna=True):
467474
def nanskew(values, axis=None, skipna=True):
468475

469476
mask = isnull(values)
470-
if not is_floating_dtype(values):
477+
if not is_float_dtype(values):
471478
values = values.astype('f8')
472-
473-
count = _get_counts(mask, axis)
479+
count = _get_counts(mask, axis)
480+
else:
481+
count = _get_counts(mask, axis, dtype=values.dtype)
474482

475483
if skipna:
476484
values = values.copy()
477485
np.putmask(values, mask, 0)
478486

487+
typ = values.dtype.type
479488
A = values.sum(axis) / count
480-
B = (values ** 2).sum(axis) / count - A ** 2
481-
C = (values ** 3).sum(axis) / count - A ** 3 - 3 * A * B
489+
B = (values ** 2).sum(axis) / count - A ** typ(2)
490+
C = (values ** 3).sum(axis) / count - A ** typ(3) - typ(3) * A * B
482491

483492
# floating point error
484493
B = _zero_out_fperr(B)
485494
C = _zero_out_fperr(C)
486495

487-
result = ((np.sqrt((count ** 2 - count)) * C) /
488-
((count - 2) * np.sqrt(B) ** 3))
496+
result = ((np.sqrt(count * count - count) * C) /
497+
((count - typ(2)) * np.sqrt(B) ** typ(3)))
489498

490499
if isinstance(result, np.ndarray):
491500
result = np.where(B == 0, 0, result)
@@ -502,19 +511,21 @@ def nanskew(values, axis=None, skipna=True):
502511
def nankurt(values, axis=None, skipna=True):
503512

504513
mask = isnull(values)
505-
if not is_floating_dtype(values):
514+
if not is_float_dtype(values):
506515
values = values.astype('f8')
507-
508-
count = _get_counts(mask, axis)
516+
count = _get_counts(mask, axis)
517+
else:
518+
count = _get_counts(mask, axis, dtype=values.dtype)
509519

510520
if skipna:
511521
values = values.copy()
512522
np.putmask(values, mask, 0)
513523

524+
typ = values.dtype.type
514525
A = values.sum(axis) / count
515-
B = (values ** 2).sum(axis) / count - A ** 2
516-
C = (values ** 3).sum(axis) / count - A ** 3 - 3 * A * B
517-
D = (values ** 4).sum(axis) / count - A ** 4 - 6 * B * A * A - 4 * C * A
526+
B = (values ** 2).sum(axis) / count - A ** typ(2)
527+
C = (values ** 3).sum(axis) / count - A ** typ(3) - typ(3) * A * B
528+
D = (values ** 4).sum(axis) / count - A ** typ(4) - typ(6) * B * A * A - typ(4) * C * A
518529

519530
B = _zero_out_fperr(B)
520531
D = _zero_out_fperr(D)
@@ -526,8 +537,8 @@ def nankurt(values, axis=None, skipna=True):
526537
if B == 0:
527538
return 0
528539

529-
result = (((count * count - 1.) * D / (B * B) - 3 * ((count - 1.) ** 2)) /
530-
((count - 2.) * (count - 3.)))
540+
result = (((count * count - typ(1)) * D / (B * B) - typ(3) * ((count - typ(1)) ** typ(2))) /
541+
((count - typ(2)) * (count - typ(3))))
531542

532543
if isinstance(result, np.ndarray):
533544
result = np.where(B == 0, 0, result)
@@ -598,7 +609,7 @@ def _zero_out_fperr(arg):
598609
if isinstance(arg, np.ndarray):
599610
return np.where(np.abs(arg) < 1e-14, 0, arg)
600611
else:
601-
return 0 if np.abs(arg) < 1e-14 else arg
612+
return arg.dtype.type(0) if np.abs(arg) < 1e-14 else arg
602613

603614

604615
@disallow('M8','m8')

pandas/tests/test_nanops.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -340,14 +340,18 @@ def test_nanmean_overflow(self):
340340
self.assertEqual(result, np_result)
341341
self.assertTrue(result.dtype == np.float64)
342342

343-
# check returned dtype
344-
for dtype in [np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]:
343+
def test_returned_dtype(self):
344+
from pandas import Series
345+
for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64, np.float128]:
345346
s = Series(range(10), dtype=dtype)
346-
result = s.mean()
347-
if is_integer_dtype(dtype):
348-
self.assertTrue(result.dtype == np.float64)
349-
else:
350-
self.assertTrue(result.dtype == dtype)
347+
for method in ['mean', 'std', 'var', 'skew', 'kurt']:
348+
result = getattr(s, method)()
349+
if is_integer_dtype(dtype):
350+
self.assertTrue(result.dtype == np.float64,
351+
"return dtype expected from %s is np.float64, got %s instead" % (method, result.dtype))
352+
else:
353+
self.assertTrue(result.dtype == dtype,
354+
"return dtype expected from %s is %s, got %s instead" % (method, dtype, result.dtype))
351355

352356
def test_nanmedian(self):
353357
self.check_funs(nanops.nanmedian, np.median,

pandas/tests/test_series.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,6 @@ def test_nansum_buglet(self):
511511
assert_almost_equal(result, 1)
512512

513513
def test_overflow(self):
514-
515514
# GH 6915
516515
# overflowing on the smaller int dtypes
517516
for dtype in ['int32','int64']:
@@ -534,25 +533,25 @@ def test_overflow(self):
534533
result = s.max()
535534
self.assertEqual(int(result),v[-1])
536535

537-
for dtype in ['float32','float64']:
538-
v = np.arange(5000000,dtype=dtype)
536+
for dtype in ['float32', 'float64']:
537+
v = np.arange(5000000, dtype=dtype)
539538
s = Series(v)
540539

541540
# no bottleneck
542541
result = s.sum(skipna=False)
543-
self.assertTrue(np.allclose(float(result),v.sum(dtype='float64')))
542+
self.assertEqual(result, v.sum(dtype=dtype))
544543
result = s.min(skipna=False)
545-
self.assertTrue(np.allclose(float(result),0.0))
544+
self.assertTrue(np.allclose(float(result), 0.0))
546545
result = s.max(skipna=False)
547-
self.assertTrue(np.allclose(float(result),v[-1]))
546+
self.assertTrue(np.allclose(float(result), v[-1]))
548547

549548
# use bottleneck if available
550549
result = s.sum()
551-
self.assertTrue(np.allclose(float(result),v.sum(dtype='float64')))
550+
self.assertEqual(result, v.sum(dtype=dtype))
552551
result = s.min()
553-
self.assertTrue(np.allclose(float(result),0.0))
552+
self.assertTrue(np.allclose(float(result), 0.0))
554553
result = s.max()
555-
self.assertTrue(np.allclose(float(result),v[-1]))
554+
self.assertTrue(np.allclose(float(result), v[-1]))
556555

557556
class SafeForSparse(object):
558557
pass

0 commit comments

Comments
 (0)