Skip to content

Commit 53ae9d6

Browse files
committed
CLN: Unify signatures in _libs.groupby
1 parent 3fd150c commit 53ae9d6

File tree

2 files changed

+84
-43
lines changed

2 files changed

+84
-43
lines changed

pandas/_libs/groupby.pyx

+49-42
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,12 @@ group_ohlc_float64 = _group_ohlc['double']
714714

715715
@cython.boundscheck(False)
716716
@cython.wraparound(False)
717-
def group_quantile(ndarray[float64_t] out,
718-
ndarray[int64_t] labels,
719-
numeric[:] values,
720-
ndarray[uint8_t] mask,
717+
def group_quantile(floating[:, :] out,
718+
int64_t[:] counts,
719+
floating[:, :] values,
720+
const int64_t[:] labels,
721+
Py_ssize_t min_count,
722+
const uint8_t[:, :] mask,
721723
float64_t q,
722724
object interpolation):
723725
"""
@@ -740,12 +742,12 @@ def group_quantile(ndarray[float64_t] out,
740742
provided `out` parameter.
741743
"""
742744
cdef:
743-
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
745+
Py_ssize_t i, N=len(labels), K, ngroups, grp_sz=0, non_na_sz
744746
Py_ssize_t grp_start=0, idx=0
745747
int64_t lab
746748
uint8_t interp
747749
float64_t q_idx, frac, val, next_val
748-
ndarray[int64_t] counts, non_na_counts, sort_arr
750+
int64_t[:, :] non_na_counts, sort_arrs
749751

750752
assert values.shape[0] == N
751753

@@ -761,59 +763,64 @@ def group_quantile(ndarray[float64_t] out,
761763
}
762764
interp = inter_methods[interpolation]
763765

764-
counts = np.zeros_like(out, dtype=np.int64)
765766
non_na_counts = np.zeros_like(out, dtype=np.int64)
767+
sort_arrs = np.empty_like(values, dtype=np.int64)
766768
ngroups = len(counts)
767769

770+
N, K = (<object>values).shape
771+
768772
# First figure out the size of every group
769773
with nogil:
770774
for i in range(N):
771775
lab = labels[i]
772776
if lab == -1: # NA group label
773777
continue
774-
775778
counts[lab] += 1
776-
if not mask[i]:
777-
non_na_counts[lab] += 1
779+
for j in range(K):
780+
if not mask[i, j]:
781+
non_na_counts[lab, j] += 1
778782

779-
# Get an index of values sorted by labels and then values
780-
order = (values, labels)
781-
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
783+
for j in range(K):
784+
order = (values[:, j], labels)
785+
r = np.lexsort(order).astype(np.int64, copy=False)
786+
# TODO: Need better way to assign r to column j
787+
for i in range(N):
788+
sort_arrs[i, j] = r[i]
782789

783790
with nogil:
784791
for i in range(ngroups):
785792
# Figure out how many group elements there are
786793
grp_sz = counts[i]
787-
non_na_sz = non_na_counts[i]
788-
789-
if non_na_sz == 0:
790-
out[i] = NaN
791-
else:
792-
# Calculate where to retrieve the desired value
793-
# Casting to int will intentionally truncate result
794-
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
795-
796-
val = values[sort_arr[idx]]
797-
# If requested quantile falls evenly on a particular index
798-
# then write that index's value out. Otherwise interpolate
799-
q_idx = q * (non_na_sz - 1)
800-
frac = q_idx % 1
801-
802-
if frac == 0.0 or interp == INTERPOLATION_LOWER:
803-
out[i] = val
794+
for j in range(K):
795+
non_na_sz = non_na_counts[i, j]
796+
if non_na_sz == 0:
797+
out[i, j] = NaN
804798
else:
805-
next_val = values[sort_arr[idx + 1]]
806-
if interp == INTERPOLATION_LINEAR:
807-
out[i] = val + (next_val - val) * frac
808-
elif interp == INTERPOLATION_HIGHER:
809-
out[i] = next_val
810-
elif interp == INTERPOLATION_MIDPOINT:
811-
out[i] = (val + next_val) / 2.0
812-
elif interp == INTERPOLATION_NEAREST:
813-
if frac > .5 or (frac == .5 and q > .5): # Always OK?
814-
out[i] = next_val
815-
else:
816-
out[i] = val
799+
# Calculate where to retrieve the desired value
800+
# Casting to int will intentionally truncate result
801+
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
802+
803+
val = values[sort_arrs[idx, j], j]
804+
# If requested quantile falls evenly on a particular index
805+
# then write that index's value out. Otherwise interpolate
806+
q_idx = q * (non_na_sz - 1)
807+
frac = q_idx % 1
808+
809+
if frac == 0.0 or interp == INTERPOLATION_LOWER:
810+
out[i, j] = val
811+
else:
812+
next_val = values[sort_arrs[idx + 1, j], j]
813+
if interp == INTERPOLATION_LINEAR:
814+
out[i, j] = val + (next_val - val) * frac
815+
elif interp == INTERPOLATION_HIGHER:
816+
out[i, j] = next_val
817+
elif interp == INTERPOLATION_MIDPOINT:
818+
out[i, j] = (val + next_val) / 2.0
819+
elif interp == INTERPOLATION_NEAREST:
820+
if frac > .5 or (frac == .5 and q > .5): # Always OK?
821+
out[i, j] = next_val
822+
else:
823+
out[i, j] = val
817824

818825
# Increment the index reference in sorted_arr for the next group
819826
grp_start += grp_sz

pandas/core/groupby/groupby.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,9 @@ def pre_processor(vals: np.ndarray) -> Tuple[np.ndarray, Optional[Type]]:
20392039
inference = "datetime64[ns]"
20402040
vals = np.asarray(vals).astype(np.float)
20412041

2042+
if vals.dtype != np.dtype(np.float64):
2043+
vals = vals.astype(np.float64)
2044+
20422045
return vals, inference
20432046

20442047
def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
@@ -2396,7 +2399,7 @@ def _get_cythonized_result(
23962399
if result_is_index and aggregate:
23972400
raise ValueError("'result_is_index' and 'aggregate' cannot both be True!")
23982401
if post_processing:
2399-
if not callable(pre_processing):
2402+
if not callable(post_processing):
24002403
raise ValueError("'post_processing' must be a callable!")
24012404
if pre_processing:
24022405
if not callable(pre_processing):
@@ -2412,6 +2415,37 @@ def _get_cythonized_result(
24122415
output: Dict[base.OutputKey, np.ndarray] = {}
24132416
base_func = getattr(libgroupby, how)
24142417

2418+
if how == "group_quantile":
2419+
values = self._obj_with_exclusions._values
2420+
result_sz = ngroups if aggregate else len(values)
2421+
2422+
vals, inferences = pre_processing(values)
2423+
if self._obj_with_exclusions.ndim == 1:
2424+
width = 1
2425+
vals = np.reshape(vals, (-1, 1))
2426+
else:
2427+
width = len(self._obj_with_exclusions.columns)
2428+
result = np.zeros((result_sz, width), dtype=cython_dtype)
2429+
counts = np.zeros(self.ngroups, dtype=np.int64)
2430+
mask = isna(vals).view(np.uint8)
2431+
2432+
func = partial(base_func, result, counts, vals, labels, -1, mask)
2433+
func(**kwargs) # Call func to modify indexer values in place
2434+
result = post_processing(result, inferences)
2435+
2436+
if self._obj_with_exclusions.ndim == 1:
2437+
key = base.OutputKey(label=self._obj_with_exclusions.name, position=0)
2438+
output[key] = result[:, 0]
2439+
else:
2440+
for idx, name in enumerate(self._obj_with_exclusions.columns):
2441+
key = base.OutputKey(label=name, position=idx)
2442+
output[key] = result[:, idx]
2443+
2444+
if aggregate:
2445+
return self._wrap_aggregated_output(output)
2446+
else:
2447+
return self._wrap_transformed_output(output)
2448+
24152449
for idx, obj in enumerate(self._iterate_slices()):
24162450
name = obj.name
24172451
values = obj._values

0 commit comments

Comments (0)