Skip to content

Commit 18f4fc9

Browse files
committed
use lexsort to sort values and mask, so that +/- inf values are placed before nan
1 parent cf2c967 commit 18f4fc9

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

pandas/_libs/algos_rank_helper.pxi.in

+15-19
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,24 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
4444

4545
{{if dtype == 'object'}}
4646
ndarray sorted_data, values
47-
ndarray[np.uint8_t, cast=True] sorted_namask
48-
{{elif dtype != 'uint64'}}
49-
ndarray[{{ctype}}] sorted_data, values
50-
ndarray[np.uint8_t, cast=True] sorted_namask
5147
{{else}}
5248
ndarray[{{ctype}}] sorted_data, values
5349
{{endif}}
5450

5551
ndarray[float64_t] ranks
5652
ndarray[int64_t] argsorted
53+
ndarray[np.uint8_t, cast=True] sorted_mask
5754

5855
{{if dtype == 'uint64'}}
5956
{{ctype}} val
6057
{{else}}
61-
{{ctype}} val, nan_value, isnan
58+
{{ctype}} val, nan_value
6259
{{endif}}
6360

6461
float64_t sum_ranks = 0
6562
int tiebreak = 0
6663
bint keep_na = 0
64+
bint isnan
6765
float count = 0.0
6866
tiebreak = tiebreakers[ties_method]
6967

@@ -95,14 +93,16 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
9593
{{endif}}
9694

9795
np.putmask(values, mask, nan_value)
96+
{{else}}
97+
mask = np.zeros(shape=len(values), dtype=bool)
9898
{{endif}}
9999

100100
n = len(values)
101101
ranks = np.empty(n, dtype='f8')
102102

103103
{{if dtype == 'object'}}
104104
try:
105-
_as = values.argsort()
105+
_as = np.lexsort(keys=(mask, values))
106106
except TypeError:
107107
if not retry:
108108
raise
@@ -116,40 +116,37 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
116116
{{else}}
117117
if tiebreak == TIEBREAK_FIRST:
118118
# need to use a stable sort here
119-
_as = values.argsort(kind='mergesort')
119+
_as = np.lexsort(keys=(mask, values))
120120
if not ascending:
121121
tiebreak = TIEBREAK_FIRST_DESCENDING
122122
else:
123-
_as = values.argsort()
123+
_as = np.lexsort(keys=(mask, values))
124124
{{endif}}
125125

126126
if not ascending:
127127
_as = _as[::-1]
128128

129129
sorted_data = values.take(_as)
130130
# need to distinguish between pos/neg nan and real nan when keep_na is true
131-
{{if dtype != 'uint64'}}
132-
sorted_namask = mask.take(_as)
133-
sorted_namask = sorted_namask.astype(np.bool)
134-
{{endif}}
131+
sorted_mask = mask.take(_as)
135132
argsorted = _as.astype('i8')
136133

137134
{{if dtype == 'object'}}
138135
for i in range(n):
139136
sum_ranks += i + 1
140137
dups += 1
141-
isnan = sorted_namask[i]
138+
isnan = sorted_mask[i]
142139
val = util.get_value_at(sorted_data, i)
143140

144141
if isnan and keep_na:
145142
ranks[argsorted[i]] = nan
146-
sum_ranks = dups = 0
147143
continue
148144

149145
count += 1.0
150146

151147
if (i == n - 1 or
152-
are_diff(util.get_value_at(sorted_data, i + 1), val)):
148+
are_diff(util.get_value_at(sorted_data, i + 1), val) or
149+
sorted_mask[i + 1]):
153150
if tiebreak == TIEBREAK_AVERAGE:
154151
for j in range(i - dups + 1, i + 1):
155152
ranks[argsorted[j]] = sum_ranks / dups
@@ -178,16 +175,15 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
178175
val = sorted_data[i]
179176

180177
{{if dtype != 'uint64'}}
181-
isnan = sorted_namask[i]
182-
if isnan and keep_na:
178+
isnan = sorted_mask[i]
179+
if isnan and keep_na:
183180
ranks[argsorted[i]] = nan
184-
sum_ranks = dups = 0
185181
continue
186182
{{endif}}
187183

188184
count += 1.0
189185

190-
if i == n - 1 or sorted_data[i + 1] != val:
186+
if i == n - 1 or sorted_data[i + 1] != val or sorted_mask[i + 1]:
191187
if tiebreak == TIEBREAK_AVERAGE:
192188
for j in range(i - dups + 1, i + 1):
193189
ranks[argsorted[j]] = sum_ranks / dups

0 commit comments

Comments
 (0)