Skip to content

Commit 091e73d

Browse files
authored
Remove one more cast to numpy (#434)
1 parent 8dac463 commit 091e73d

File tree

1 file changed

+18 additions, -14 deletions (lines changed)

1 file changed

+18 additions, -14 deletions (lines changed)

flox/core.py

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
TypedDict,
2525
TypeVar,
2626
Union,
27+
cast,
2728
overload,
2829
)
2930

@@ -843,7 +844,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
843844
return offset, size
844845

845846

846-
def _factorize_single(by, expect, *, sort: bool, reindex: bool):
847+
def _factorize_single(by, expect, *, sort: bool, reindex: bool) -> tuple[pd.Index, np.ndarray]:
847848
flat = by.reshape(-1)
848849
if isinstance(expect, pd.RangeIndex):
849850
# idx is a view of the original `by` array
@@ -852,7 +853,7 @@ def _factorize_single(by, expect, *, sort: bool, reindex: bool):
852853
# this is important in shared-memory parallelism with dask
853854
# TODO: figure out how to avoid this
854855
idx = flat.copy()
855-
found_groups = np.array(expect)
856+
found_groups = cast(pd.Index, expect)
856857
# TODO: fix by using masked integers
857858
idx[idx > expect[-1]] = -1
858859

@@ -875,7 +876,7 @@ def _factorize_single(by, expect, *, sort: bool, reindex: bool):
875876
idx[~within_bins] = -1
876877
else:
877878
idx = np.zeros_like(flat, dtype=np.intp) - 1
878-
found_groups = np.array(expect)
879+
found_groups = cast(pd.Index, expect)
879880
else:
880881
if expect is not None and reindex:
881882
sorter = np.argsort(expect)
@@ -890,7 +891,7 @@ def _factorize_single(by, expect, *, sort: bool, reindex: bool):
890891
idx[mask] = -1
891892
else:
892893
idx, groups = pd.factorize(flat, sort=sort)
893-
found_groups = np.array(groups)
894+
found_groups = cast(pd.Index, groups)
894895

895896
return (found_groups, idx.reshape(by.shape))
896897

@@ -913,7 +914,7 @@ def factorize_(
913914
expected_groups: T_ExpectIndexOptTuple | None = None,
914915
reindex: bool = False,
915916
sort: bool = True,
916-
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]: ...
917+
) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, None]: ...
917918

918919

919920
@overload
@@ -925,7 +926,7 @@ def factorize_(
925926
reindex: bool = False,
926927
sort: bool = True,
927928
fastpath: Literal[False] = False,
928-
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps]: ...
929+
) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps]: ...
929930

930931

931932
@overload
@@ -937,7 +938,7 @@ def factorize_(
937938
reindex: bool = False,
938939
sort: bool = True,
939940
fastpath: bool = False,
940-
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]: ...
941+
) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]: ...
941942

942943

943944
def factorize_(
@@ -948,7 +949,7 @@ def factorize_(
948949
reindex: bool = False,
949950
sort: bool = True,
950951
fastpath: bool = False,
951-
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]:
952+
) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]:
952953
"""
953954
Returns an array of integer codes for groups (and associated data)
954955
by wrapping pd.cut and pd.factorize (depending on isbin).
@@ -971,7 +972,7 @@ def factorize_(
971972
_factorize_single(groupvar, expect, sort=sort, reindex=reindex)
972973
for groupvar, expect in zip(by, expected_groups)
973974
)
974-
found_groups = [r[0] for r in results]
975+
found_groups = tuple(r[0] for r in results)
975976
factorized = [r[1] for r in results]
976977

977978
grp_shape = tuple(len(grp) for grp in found_groups)
@@ -982,7 +983,7 @@ def factorize_(
982983
(group_idx,) = factorized
983984

984985
if fastpath:
985-
return group_idx, tuple(found_groups), grp_shape, ngroups, ngroups, None
986+
return group_idx, found_groups, grp_shape, ngroups, ngroups, None
986987

987988
if len(axes) == 1 and by[0].ndim > 1:
988989
# Not reducing along all dimensions of by
@@ -1178,7 +1179,7 @@ def chunk_reduce(
11781179
results: IntermediateDict = {"groups": [], "intermediates": []}
11791180
if reindex and expected_groups is not None:
11801181
# TODO: what happens with binning here?
1181-
results["groups"] = expected_groups.to_numpy()
1182+
results["groups"] = expected_groups
11821183
else:
11831184
if empty:
11841185
results["groups"] = np.array([np.nan])
@@ -1307,7 +1308,7 @@ def _finalize_results(
13071308
fill_value=fill_value,
13081309
array_type=reindex.array_type,
13091310
)
1310-
finalized["groups"] = expected_groups.to_numpy()
1311+
finalized["groups"] = expected_groups
13111312
else:
13121313
finalized["groups"] = squeezed["groups"]
13131314

@@ -2272,7 +2273,7 @@ def _factorize_multiple(
22722273
expected_groups: T_ExpectIndexOptTuple,
22732274
any_by_dask: bool,
22742275
sort: bool = True,
2275-
) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
2276+
) -> tuple[tuple[np.ndarray], tuple[pd.Index, ...], tuple[int, ...]]:
22762277
kwargs: FactorizeKwargs = dict(
22772278
axes=(), # always (), we offset later if necessary.
22782279
fastpath=True,
@@ -2293,7 +2294,7 @@ def _factorize_multiple(
22932294
raise ValueError("Please provide expected_groups when grouping by a dask array.")
22942295

22952296
found_groups = tuple(
2296-
pd.unique(by_.reshape(-1)) if expect is None else expect.to_numpy()
2297+
pd.Index(pd.unique(by_.reshape(-1))) if expect is None else expect
22972298
for by_, expect in zip(by, expected_groups)
22982299
)
22992300
grp_shape = tuple(map(len, found_groups))
@@ -2883,6 +2884,9 @@ def groupby_reduce(
28832884
result = asdelta + offset
28842885
result[nanmask] = np.timedelta64("NaT")
28852886

2887+
groups = map(
2888+
lambda g: g.to_numpy() if isinstance(g, pd.Index) and not isinstance(g, pd.RangeIndex) else g, groups
2889+
)
28862890
return (result, *groups)
28872891

28882892

0 commit comments

Comments
 (0)