@@ -714,10 +714,12 @@ group_ohlc_float64 = _group_ohlc['double']
714
714
715
715
@ cython.boundscheck (False )
716
716
@ cython.wraparound (False )
717
- def group_quantile (ndarray[float64_t] out ,
718
- ndarray[int64_t] labels ,
719
- numeric[:] values ,
720
- ndarray[uint8_t] mask ,
717
+ def group_quantile (floating[:, :] out ,
718
+ int64_t[:] counts ,
719
+ floating[:, :] values ,
720
+ const int64_t[:] labels ,
721
+ Py_ssize_t min_count ,
722
+ const uint8_t[:, :] mask ,
721
723
float64_t q ,
722
724
object interpolation ):
723
725
"""
@@ -740,12 +742,12 @@ def group_quantile(ndarray[float64_t] out,
740
742
provided `out` parameter.
741
743
"""
742
744
cdef:
743
- Py_ssize_t i, N= len (labels), ngroups, grp_sz, non_na_sz
745
+ Py_ssize_t i, N= len (labels), K, ngroups, grp_sz= 0 , non_na_sz
744
746
Py_ssize_t grp_start= 0 , idx= 0
745
747
int64_t lab
746
748
uint8_t interp
747
749
float64_t q_idx, frac, val, next_val
748
- ndarray[ int64_t] counts, non_na_counts, sort_arr
750
+ int64_t[:, :] non_na_counts, sort_arrs
749
751
750
752
assert values.shape[0 ] == N
751
753
@@ -761,59 +763,64 @@ def group_quantile(ndarray[float64_t] out,
761
763
}
762
764
interp = inter_methods[interpolation]
763
765
764
- counts = np.zeros_like(out, dtype = np.int64)
765
766
non_na_counts = np.zeros_like(out, dtype = np.int64)
767
+ sort_arrs = np.empty_like(values, dtype = np.int64)
766
768
ngroups = len (counts)
767
769
770
+ N, K = (< object > values).shape
771
+
768
772
# First figure out the size of every group
769
773
with nogil:
770
774
for i in range (N):
771
775
lab = labels[i]
772
776
if lab == - 1 : # NA group label
773
777
continue
774
-
775
778
counts[lab] += 1
776
- if not mask[i]:
777
- non_na_counts[lab] += 1
779
+ for j in range (K):
780
+ if not mask[i, j]:
781
+ non_na_counts[lab, j] += 1
778
782
779
- # Get an index of values sorted by labels and then values
780
- order = (values, labels)
781
- sort_arr = np.lexsort(order).astype(np.int64, copy = False )
783
+ for j in range (K):
784
+ order = (values[:, j], labels)
785
+ r = np.lexsort(order).astype(np.int64, copy = False )
786
+ # TODO: Need better way to assign r to column j
787
+ for i in range (N):
788
+ sort_arrs[i, j] = r[i]
782
789
783
790
with nogil:
784
791
for i in range (ngroups):
785
792
# Figure out how many group elements there are
786
793
grp_sz = counts[i]
787
- non_na_sz = non_na_counts[i]
788
-
789
- if non_na_sz == 0 :
790
- out[i] = NaN
791
- else :
792
- # Calculate where to retrieve the desired value
793
- # Casting to int will intentionally truncate result
794
- idx = grp_start + < int64_t> (q * < float64_t> (non_na_sz - 1 ))
795
-
796
- val = values[sort_arr[idx]]
797
- # If requested quantile falls evenly on a particular index
798
- # then write that index's value out. Otherwise interpolate
799
- q_idx = q * (non_na_sz - 1 )
800
- frac = q_idx % 1
801
-
802
- if frac == 0.0 or interp == INTERPOLATION_LOWER:
803
- out[i] = val
794
+ for j in range (K):
795
+ non_na_sz = non_na_counts[i, j]
796
+ if non_na_sz == 0 :
797
+ out[i, j] = NaN
804
798
else :
805
- next_val = values[sort_arr[idx + 1 ]]
806
- if interp == INTERPOLATION_LINEAR:
807
- out[i] = val + (next_val - val) * frac
808
- elif interp == INTERPOLATION_HIGHER:
809
- out[i] = next_val
810
- elif interp == INTERPOLATION_MIDPOINT:
811
- out[i] = (val + next_val) / 2.0
812
- elif interp == INTERPOLATION_NEAREST:
813
- if frac > .5 or (frac == .5 and q > .5 ): # Always OK?
814
- out[i] = next_val
815
- else :
816
- out[i] = val
799
+ # Calculate where to retrieve the desired value
800
+ # Casting to int will intentionally truncate result
801
+ idx = grp_start + < int64_t> (q * < float64_t> (non_na_sz - 1 ))
802
+
803
+ val = values[sort_arrs[idx, j], j]
804
+ # If requested quantile falls evenly on a particular index
805
+ # then write that index's value out. Otherwise interpolate
806
+ q_idx = q * (non_na_sz - 1 )
807
+ frac = q_idx % 1
808
+
809
+ if frac == 0.0 or interp == INTERPOLATION_LOWER:
810
+ out[i, j] = val
811
+ else :
812
+ next_val = values[sort_arrs[idx + 1 , j], j]
813
+ if interp == INTERPOLATION_LINEAR:
814
+ out[i, j] = val + (next_val - val) * frac
815
+ elif interp == INTERPOLATION_HIGHER:
816
+ out[i, j] = next_val
817
+ elif interp == INTERPOLATION_MIDPOINT:
818
+ out[i, j] = (val + next_val) / 2.0
819
+ elif interp == INTERPOLATION_NEAREST:
820
+ if frac > .5 or (frac == .5 and q > .5 ): # Always OK?
821
+ out[i, j] = next_val
822
+ else :
823
+ out[i, j] = val
817
824
818
825
# Increment the index reference in sorted_arr for the next group
819
826
grp_start += grp_sz
0 commit comments