Skip to content

Commit 7985efb

Browse files
committed
Complete rework
1 parent 53ae9d6 commit 7985efb

File tree

3 files changed

+103
-102
lines changed

3 files changed

+103
-102
lines changed

pandas/_libs/groupby.pyx

+47-53
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
378378
@cython.boundscheck(False)
379379
@cython.wraparound(False)
380380
def group_any_all(uint8_t[:] out,
381-
const int64_t[:] labels,
382381
const uint8_t[:] values,
382+
const int64_t[:] labels,
383383
const uint8_t[:] mask,
384384
object val_test,
385385
bint skipna):
@@ -560,7 +560,8 @@ def _group_var(floating[:, :] out,
560560
int64_t[:] counts,
561561
floating[:, :] values,
562562
const int64_t[:] labels,
563-
Py_ssize_t min_count=-1):
563+
Py_ssize_t min_count=-1,
564+
int64_t ddof=1):
564565
cdef:
565566
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
566567
floating val, ct, oldmean
@@ -600,10 +601,10 @@ def _group_var(floating[:, :] out,
600601
for i in range(ncounts):
601602
for j in range(K):
602603
ct = nobs[i, j]
603-
if ct < 2:
604+
if ct <= ddof:
604605
out[i, j] = NAN
605606
else:
606-
out[i, j] /= (ct - 1)
607+
out[i, j] /= (ct - ddof)
607608

608609

609610
group_var_float32 = _group_var['float']
@@ -714,12 +715,10 @@ group_ohlc_float64 = _group_ohlc['double']
714715

715716
@cython.boundscheck(False)
716717
@cython.wraparound(False)
717-
def group_quantile(floating[:, :] out,
718-
int64_t[:] counts,
719-
floating[:, :] values,
720-
const int64_t[:] labels,
721-
Py_ssize_t min_count,
722-
const uint8_t[:, :] mask,
718+
def group_quantile(ndarray[float64_t] out,
719+
numeric[:] values,
720+
ndarray[int64_t] labels,
721+
ndarray[uint8_t] mask,
723722
float64_t q,
724723
object interpolation):
725724
"""
@@ -742,12 +741,12 @@ def group_quantile(floating[:, :] out,
742741
provided `out` parameter.
743742
"""
744743
cdef:
745-
Py_ssize_t i, N=len(labels), K, ngroups, grp_sz=0, non_na_sz
744+
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
746745
Py_ssize_t grp_start=0, idx=0
747746
int64_t lab
748747
uint8_t interp
749748
float64_t q_idx, frac, val, next_val
750-
int64_t[:, :] non_na_counts, sort_arrs
749+
ndarray[int64_t] counts, non_na_counts, sort_arr
751750

752751
assert values.shape[0] == N
753752

@@ -763,64 +762,59 @@ def group_quantile(floating[:, :] out,
763762
}
764763
interp = inter_methods[interpolation]
765764

765+
counts = np.zeros_like(out, dtype=np.int64)
766766
non_na_counts = np.zeros_like(out, dtype=np.int64)
767-
sort_arrs = np.empty_like(values, dtype=np.int64)
768767
ngroups = len(counts)
769768

770-
N, K = (<object>values).shape
771-
772769
# First figure out the size of every group
773770
with nogil:
774771
for i in range(N):
775772
lab = labels[i]
776773
if lab == -1: # NA group label
777774
continue
775+
778776
counts[lab] += 1
779-
for j in range(K):
780-
if not mask[i, j]:
781-
non_na_counts[lab, j] += 1
777+
if not mask[i]:
778+
non_na_counts[lab] += 1
782779

783-
for j in range(K):
784-
order = (values[:, j], labels)
785-
r = np.lexsort(order).astype(np.int64, copy=False)
786-
# TODO: Need better way to assign r to column j
787-
for i in range(N):
788-
sort_arrs[i, j] = r[i]
780+
# Get an index of values sorted by labels and then values
781+
order = (values, labels)
782+
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
789783

790784
with nogil:
791785
for i in range(ngroups):
792786
# Figure out how many group elements there are
793787
grp_sz = counts[i]
794-
for j in range(K):
795-
non_na_sz = non_na_counts[i, j]
796-
if non_na_sz == 0:
797-
out[i, j] = NaN
788+
non_na_sz = non_na_counts[i]
789+
790+
if non_na_sz == 0:
791+
out[i] = NaN
792+
else:
793+
# Calculate where to retrieve the desired value
794+
# Casting to int will intentionally truncate result
795+
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
796+
797+
val = values[sort_arr[idx]]
798+
# If requested quantile falls evenly on a particular index
799+
# then write that index's value out. Otherwise interpolate
800+
q_idx = q * (non_na_sz - 1)
801+
frac = q_idx % 1
802+
803+
if frac == 0.0 or interp == INTERPOLATION_LOWER:
804+
out[i] = val
798805
else:
799-
# Calculate where to retrieve the desired value
800-
# Casting to int will intentionally truncate result
801-
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
802-
803-
val = values[sort_arrs[idx, j], j]
804-
# If requested quantile falls evenly on a particular index
805-
# then write that index's value out. Otherwise interpolate
806-
q_idx = q * (non_na_sz - 1)
807-
frac = q_idx % 1
808-
809-
if frac == 0.0 or interp == INTERPOLATION_LOWER:
810-
out[i, j] = val
811-
else:
812-
next_val = values[sort_arrs[idx + 1, j], j]
813-
if interp == INTERPOLATION_LINEAR:
814-
out[i, j] = val + (next_val - val) * frac
815-
elif interp == INTERPOLATION_HIGHER:
816-
out[i, j] = next_val
817-
elif interp == INTERPOLATION_MIDPOINT:
818-
out[i, j] = (val + next_val) / 2.0
819-
elif interp == INTERPOLATION_NEAREST:
820-
if frac > .5 or (frac == .5 and q > .5): # Always OK?
821-
out[i, j] = next_val
822-
else:
823-
out[i, j] = val
806+
next_val = values[sort_arr[idx + 1]]
807+
if interp == INTERPOLATION_LINEAR:
808+
out[i] = val + (next_val - val) * frac
809+
elif interp == INTERPOLATION_HIGHER:
810+
out[i] = next_val
811+
elif interp == INTERPOLATION_MIDPOINT:
812+
out[i] = (val + next_val) / 2.0
813+
elif interp == INTERPOLATION_NEAREST:
814+
if frac > .5 or (frac == .5 and q > .5): # Always OK?
815+
out[i] = next_val
816+
else:
817+
out[i] = val
824818

825819
# Advance the start offset into sort_arr to the beginning of the next group
826820
grp_start += grp_sz

pandas/core/groupby/generic.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1721,7 +1721,11 @@ def _wrap_aggregated_output(
17211721
DataFrame
17221722
"""
17231723
indexed_output = {key.position: val for key, val in output.items()}
1724-
columns = Index(key.label for key in output)
1724+
if self.axis == 0:
1725+
name = self._obj_with_exclusions.columns.name
1726+
else:
1727+
name = self._obj_with_exclusions.index.name
1728+
columns = Index([key.label for key in output], name=name)
17251729

17261730
result = self.obj._constructor(indexed_output)
17271731
result.columns = columns

pandas/core/groupby/groupby.py

+51-48
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,7 @@ def result_to_bool(result: np.ndarray, inference: Type) -> np.ndarray:
12601260
return self._get_cythonized_result(
12611261
"group_any_all",
12621262
aggregate=True,
1263+
numeric_only=False,
12631264
cython_dtype=np.dtype(np.uint8),
12641265
needs_values=True,
12651266
needs_mask=True,
@@ -1416,18 +1417,16 @@ def std(self, ddof: int = 1):
14161417
Series or DataFrame
14171418
Standard deviation of values within each group.
14181419
"""
1419-
result = self.var(ddof=ddof)
1420-
if result.ndim == 1:
1421-
result = np.sqrt(result)
1422-
else:
1423-
cols = result.columns.get_indexer_for(
1424-
result.columns.difference(self.exclusions).unique()
1425-
)
1426-
# TODO(GH-22046) - setting with iloc broken if labels are not unique
1427-
# .values to remove labels
1428-
result.iloc[:, cols] = np.sqrt(result.iloc[:, cols]).values
1429-
1430-
return result
1420+
return self._get_cythonized_result(
1421+
"group_var_float64",
1422+
aggregate=True,
1423+
needs_counts=True,
1424+
needs_values=True,
1425+
needs_2d=True,
1426+
cython_dtype=np.dtype(np.float64),
1427+
post_processing=lambda vals, inference: np.sqrt(vals),
1428+
ddof=ddof,
1429+
)
14311430

14321431
@Substitution(name="groupby")
14331432
@Appender(_common_see_also)
@@ -1756,6 +1755,7 @@ def _fill(self, direction, limit=None):
17561755

17571756
return self._get_cythonized_result(
17581757
"group_fillna_indexer",
1758+
numeric_only=False,
17591759
needs_mask=True,
17601760
cython_dtype=np.dtype(np.int64),
17611761
result_is_index=True,
@@ -2039,9 +2039,6 @@ def pre_processor(vals: np.ndarray) -> Tuple[np.ndarray, Optional[Type]]:
20392039
inference = "datetime64[ns]"
20402040
vals = np.asarray(vals).astype(np.float)
20412041

2042-
if vals.dtype != np.dtype(np.float64):
2043-
vals = vals.astype(np.float64)
2044-
20452042
return vals, inference
20462043

20472044
def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
@@ -2059,6 +2056,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
20592056
return self._get_cythonized_result(
20602057
"group_quantile",
20612058
aggregate=True,
2059+
numeric_only=False,
20622060
needs_values=True,
20632061
needs_mask=True,
20642062
cython_dtype=np.dtype(np.float64),
@@ -2348,7 +2346,11 @@ def _get_cythonized_result(
23482346
how: str,
23492347
cython_dtype: np.dtype,
23502348
aggregate: bool = False,
2349+
numeric_only: bool = True,
2350+
needs_counts: bool = False,
23512351
needs_values: bool = False,
2352+
needs_2d: bool = False,
2353+
min_count: Optional[int] = None,
23522354
needs_mask: bool = False,
23532355
needs_ngroups: bool = False,
23542356
result_is_index: bool = False,
@@ -2367,9 +2369,18 @@ def _get_cythonized_result(
23672369
aggregate : bool, default False
23682370
Whether the result should be aggregated to match the number of
23692371
groups
2372+
numeric_only : bool, default True
2373+
If True, slices whose values are not of a numeric dtype are skipped
2374+
needs_counts : bool, default False
2375+
Whether the counts should be a part of the Cython call
23702376
needs_values : bool, default False
23712377
Whether the values should be a part of the Cython call
23722378
signature
2379+
needs_2d : bool, default False
2380+
Whether the values and result of the Cython call signature
2381+
are 2-dimensional.
2382+
min_count : int, default None
2383+
When not None, passed as a positional argument to the Cython call, after the labels
23732384
needs_mask : bool, default False
23742385
Whether boolean mask needs to be part of the Cython call
23752386
signature
@@ -2415,56 +2426,44 @@ def _get_cythonized_result(
24152426
output: Dict[base.OutputKey, np.ndarray] = {}
24162427
base_func = getattr(libgroupby, how)
24172428

2418-
if how == "group_quantile":
2419-
values = self._obj_with_exclusions._values
2420-
result_sz = ngroups if aggregate else len(values)
2421-
2422-
vals, inferences = pre_processing(values)
2423-
if self._obj_with_exclusions.ndim == 1:
2424-
width = 1
2425-
vals = np.reshape(vals, (-1, 1))
2426-
else:
2427-
width = len(self._obj_with_exclusions.columns)
2428-
result = np.zeros((result_sz, width), dtype=cython_dtype)
2429-
counts = np.zeros(self.ngroups, dtype=np.int64)
2430-
mask = isna(vals).view(np.uint8)
2431-
2432-
func = partial(base_func, result, counts, vals, labels, -1, mask)
2433-
func(**kwargs) # Call func to modify indexer values in place
2434-
result = post_processing(result, inferences)
2435-
2436-
if self._obj_with_exclusions.ndim == 1:
2437-
key = base.OutputKey(label=self._obj_with_exclusions.name, position=0)
2438-
output[key] = result[:, 0]
2439-
else:
2440-
for idx, name in enumerate(self._obj_with_exclusions.columns):
2441-
key = base.OutputKey(label=name, position=idx)
2442-
output[key] = result[:, idx]
2443-
2444-
if aggregate:
2445-
return self._wrap_aggregated_output(output)
2446-
else:
2447-
return self._wrap_transformed_output(output)
2448-
24492429
for idx, obj in enumerate(self._iterate_slices()):
24502430
name = obj.name
24512431
values = obj._values
24522432

2433+
if numeric_only and not is_numeric_dtype(values):
2434+
continue
2435+
24532436
if aggregate:
24542437
result_sz = ngroups
24552438
else:
24562439
result_sz = len(values)
24572440

2458-
result = np.zeros(result_sz, dtype=cython_dtype)
2459-
func = partial(base_func, result, labels)
2441+
if needs_2d:
2442+
result = np.zeros((result_sz, 1), dtype=cython_dtype)
2443+
else:
2444+
result = np.zeros(result_sz, dtype=cython_dtype)
2445+
func = partial(base_func, result)
2446+
24602447
inferences = None
24612448

2449+
if needs_counts:
2450+
counts = np.zeros(self.ngroups, dtype=np.int64)
2451+
func = partial(func, counts)
2452+
24622453
if needs_values:
24632454
vals = values
24642455
if pre_processing:
24652456
vals, inferences = pre_processing(vals)
2457+
if needs_2d:
2458+
vals = vals.reshape((-1, 1))
2459+
vals = vals.astype(cython_dtype, copy=False)
24662460
func = partial(func, vals)
24672461

2462+
func = partial(func, labels)
2463+
2464+
if min_count is not None:
2465+
func = partial(func, min_count)
2466+
24682467
if needs_mask:
24692468
mask = isna(values).view(np.uint8)
24702469
func = partial(func, mask)
@@ -2474,6 +2473,9 @@ def _get_cythonized_result(
24742473

24752474
func(**kwargs) # Call func to modify indexer values in place
24762475

2476+
if needs_2d:
2477+
result = result.reshape(-1)
2478+
24772479
if result_is_index:
24782480
result = algorithms.take_nd(values, result)
24792481

@@ -2524,6 +2526,7 @@ def shift(self, periods=1, freq=None, axis=0, fill_value=None):
25242526

25252527
return self._get_cythonized_result(
25262528
"group_shift_indexer",
2529+
numeric_only=False,
25272530
cython_dtype=np.dtype(np.int64),
25282531
needs_ngroups=True,
25292532
result_is_index=True,

0 commit comments

Comments
 (0)