Skip to content

CLN: Unify signatures in _libs.groupby #34372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 18, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 49 additions & 42 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,12 @@ group_ohlc_float64 = _group_ohlc['double']

@cython.boundscheck(False)
@cython.wraparound(False)
def group_quantile(ndarray[float64_t] out,
ndarray[int64_t] labels,
numeric[:] values,
ndarray[uint8_t] mask,
def group_quantile(floating[:, :] out,
int64_t[:] counts,
floating[:, :] values,
const int64_t[:] labels,
Py_ssize_t min_count,
const uint8_t[:, :] mask,
float64_t q,
object interpolation):
"""
Expand All @@ -740,12 +742,12 @@ def group_quantile(ndarray[float64_t] out,
provided `out` parameter.
"""
cdef:
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
Py_ssize_t i, N=len(labels), K, ngroups, grp_sz=0, non_na_sz
Py_ssize_t grp_start=0, idx=0
int64_t lab
uint8_t interp
float64_t q_idx, frac, val, next_val
ndarray[int64_t] counts, non_na_counts, sort_arr
int64_t[:, :] non_na_counts, sort_arrs

assert values.shape[0] == N

Expand All @@ -761,59 +763,64 @@ def group_quantile(ndarray[float64_t] out,
}
interp = inter_methods[interpolation]

counts = np.zeros_like(out, dtype=np.int64)
non_na_counts = np.zeros_like(out, dtype=np.int64)
sort_arrs = np.empty_like(values, dtype=np.int64)
ngroups = len(counts)

N, K = (<object>values).shape

# First figure out the size of every group
with nogil:
for i in range(N):
lab = labels[i]
if lab == -1: # NA group label
continue

counts[lab] += 1
if not mask[i]:
non_na_counts[lab] += 1
for j in range(K):
if not mask[i, j]:
non_na_counts[lab, j] += 1

# Get an index of values sorted by labels and then values
order = (values, labels)
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
for j in range(K):
order = (values[:, j], labels)
r = np.lexsort(order).astype(np.int64, copy=False)
# TODO: Need better way to assign r to column j
for i in range(N):
sort_arrs[i, j] = r[i]

with nogil:
for i in range(ngroups):
# Figure out how many group elements there are
grp_sz = counts[i]
non_na_sz = non_na_counts[i]

if non_na_sz == 0:
out[i] = NaN
else:
# Calculate where to retrieve the desired value
# Casting to int will intentionally truncate result
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))

val = values[sort_arr[idx]]
# If requested quantile falls evenly on a particular index
# then write that index's value out. Otherwise interpolate
q_idx = q * (non_na_sz - 1)
frac = q_idx % 1

if frac == 0.0 or interp == INTERPOLATION_LOWER:
out[i] = val
for j in range(K):
non_na_sz = non_na_counts[i, j]
if non_na_sz == 0:
out[i, j] = NaN
else:
next_val = values[sort_arr[idx + 1]]
if interp == INTERPOLATION_LINEAR:
out[i] = val + (next_val - val) * frac
elif interp == INTERPOLATION_HIGHER:
out[i] = next_val
elif interp == INTERPOLATION_MIDPOINT:
out[i] = (val + next_val) / 2.0
elif interp == INTERPOLATION_NEAREST:
if frac > .5 or (frac == .5 and q > .5): # Always OK?
out[i] = next_val
else:
out[i] = val
# Calculate where to retrieve the desired value
# Casting to int will intentionally truncate result
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))

val = values[sort_arrs[idx, j], j]
# If requested quantile falls evenly on a particular index
# then write that index's value out. Otherwise interpolate
q_idx = q * (non_na_sz - 1)
frac = q_idx % 1

if frac == 0.0 or interp == INTERPOLATION_LOWER:
out[i, j] = val
else:
next_val = values[sort_arrs[idx + 1, j], j]
if interp == INTERPOLATION_LINEAR:
out[i, j] = val + (next_val - val) * frac
elif interp == INTERPOLATION_HIGHER:
out[i, j] = next_val
elif interp == INTERPOLATION_MIDPOINT:
out[i, j] = (val + next_val) / 2.0
elif interp == INTERPOLATION_NEAREST:
if frac > .5 or (frac == .5 and q > .5): # Always OK?
out[i, j] = next_val
else:
out[i, j] = val

# Increment the index reference in sorted_arr for the next group
grp_start += grp_sz
Expand Down
36 changes: 35 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,6 +2039,9 @@ def pre_processor(vals: np.ndarray) -> Tuple[np.ndarray, Optional[Type]]:
inference = "datetime64[ns]"
vals = np.asarray(vals).astype(np.float)

if vals.dtype != np.dtype(np.float64):
vals = vals.astype(np.float64)

return vals, inference

def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
Expand Down Expand Up @@ -2396,7 +2399,7 @@ def _get_cythonized_result(
if result_is_index and aggregate:
raise ValueError("'result_is_index' and 'aggregate' cannot both be True!")
if post_processing:
if not callable(pre_processing):
if not callable(post_processing):
raise ValueError("'post_processing' must be a callable!")
if pre_processing:
if not callable(pre_processing):
Expand All @@ -2412,6 +2415,37 @@ def _get_cythonized_result(
output: Dict[base.OutputKey, np.ndarray] = {}
base_func = getattr(libgroupby, how)

if how == "group_quantile":
values = self._obj_with_exclusions._values
result_sz = ngroups if aggregate else len(values)

vals, inferences = pre_processing(values)
if self._obj_with_exclusions.ndim == 1:
width = 1
vals = np.reshape(vals, (-1, 1))
else:
width = len(self._obj_with_exclusions.columns)
result = np.zeros((result_sz, width), dtype=cython_dtype)
counts = np.zeros(self.ngroups, dtype=np.int64)
mask = isna(vals).view(np.uint8)

func = partial(base_func, result, counts, vals, labels, -1, mask)
func(**kwargs) # Call func to modify indexer values in place
result = post_processing(result, inferences)

if self._obj_with_exclusions.ndim == 1:
key = base.OutputKey(label=self._obj_with_exclusions.name, position=0)
output[key] = result[:, 0]
else:
for idx, name in enumerate(self._obj_with_exclusions.columns):
key = base.OutputKey(label=name, position=idx)
output[key] = result[:, idx]

if aggregate:
return self._wrap_aggregated_output(output)
else:
return self._wrap_transformed_output(output)

for idx, obj in enumerate(self._iterate_slices()):
name = obj.name
values = obj._values
Expand Down