Skip to content

Commit 991fc72

Browse files
ENH: added df/series.sort_values(key=...) and df/series.sort_index(key=...) functionality
1 parent eddd9f0 commit 991fc72

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

pandas/core/frame.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4731,6 +4731,7 @@ def sort_values(
47314731
inplace=False,
47324732
kind="quicksort",
47334733
na_position="last",
4734+
key = None
47344735
):
47354736
inplace = validate_bool_kwarg(inplace, "inplace")
47364737
axis = self._get_axis_number(axis)
@@ -4744,7 +4745,12 @@ def sort_values(
47444745
if len(by) > 1:
47454746
from pandas.core.sorting import lexsort_indexer
47464747

4747-
keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
4748+
if key is not None:
4749+
key_func = np.vectorize(key)
4750+
keys = [key_func(self._get_label_or_level_values(x, axis=axis)) for x in by]
4751+
else:
4752+
keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
4753+
47484754
indexer = lexsort_indexer(keys, orders=ascending, na_position=na_position)
47494755
indexer = ensure_platform_int(indexer)
47504756
else:
@@ -4753,6 +4759,10 @@ def sort_values(
47534759
by = by[0]
47544760
k = self._get_label_or_level_values(by, axis=axis)
47554761

4762+
if key is not None:
4763+
key_func = np.vectorize(key)
4764+
k = key_func(k)
4765+
47564766
if isinstance(ascending, (tuple, list)):
47574767
ascending = ascending[0]
47584768

@@ -4781,6 +4791,7 @@ def sort_index(
47814791
na_position="last",
47824792
sort_remaining=True,
47834793
by=None,
4794+
key=None
47844795
):
47854796

47864797
# TODO: this can be combined with Series.sort_index impl as
@@ -4801,21 +4812,23 @@ def sort_index(
48014812

48024813
axis = self._get_axis_number(axis)
48034814
labels = self._get_axis(axis)
4804-
4815+
if key is not None:
4816+
labels = labels.map(key)
4817+
48054818
# make sure that the axis is lexsorted to start
48064819
# if not we need to reconstruct to get the correct indexer
48074820
labels = labels._sort_levels_monotonic()
48084821
if level is not None:
4809-
48104822
new_axis, indexer = labels.sortlevel(
48114823
level, ascending=ascending, sort_remaining=sort_remaining
48124824
)
48134825

48144826
elif isinstance(labels, ABCMultiIndex):
48154827
from pandas.core.sorting import lexsort_indexer
48164828

4829+
codes = labels._get_codes_for_sorting()
48174830
indexer = lexsort_indexer(
4818-
labels._get_codes_for_sorting(),
4831+
codes,
48194832
orders=ascending,
48204833
na_position=na_position,
48214834
)

pandas/core/series.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2890,6 +2890,7 @@ def sort_values(
28902890
inplace=False,
28912891
kind="quicksort",
28922892
na_position="last",
2893+
key=None
28932894
):
28942895
"""
28952896
Sort by the values.
@@ -3007,6 +3008,10 @@ def sort_values(
30073008
)
30083009

30093010
def _try_kind_sort(arr):
3011+
if key is not None:
3012+
key_func = np.vectorize(key)
3013+
arr = key_func(arr)
3014+
30103015
# easier to ask forgiveness than permission
30113016
try:
30123017
# if kind==mergesort, it can fail for object dtype
@@ -3066,6 +3071,7 @@ def sort_index(
30663071
kind="quicksort",
30673072
na_position="last",
30683073
sort_remaining=True,
3074+
key=None
30693075
):
30703076
"""
30713077
Sort Series by index labels.
@@ -3183,17 +3189,22 @@ def sort_index(
31833189
# Validate the axis parameter
31843190
self._get_axis_number(axis)
31853191
index = self.index
3186-
3192+
if key is not None:
3193+
index = index.map(key)
3194+
31873195
if level is not None:
31883196
new_index, indexer = index.sortlevel(
31893197
level, ascending=ascending, sort_remaining=sort_remaining
31903198
)
3199+
31913200
elif isinstance(index, MultiIndex):
31923201
from pandas.core.sorting import lexsort_indexer
31933202

31943203
labels = index._sort_levels_monotonic()
3204+
codes = labels._get_codes_for_sorting()
3205+
31953206
indexer = lexsort_indexer(
3196-
labels._get_codes_for_sorting(),
3207+
codes,
31973208
orders=ascending,
31983209
na_position=na_position,
31993210
)
@@ -3210,6 +3221,10 @@ def sort_index(
32103221
else:
32113222
return self.copy()
32123223

3224+
if key is not None:
3225+
key_func = np.vectorize(key)
3226+
index = key_func(index)
3227+
32133228
indexer = nargsort(
32143229
index, kind=kind, ascending=ascending, na_position=na_position
32153230
)

0 commit comments

Comments
 (0)