
Commit af069aa

typing

1 parent fd72963
3 files changed: +17 lines, −11 lines

flox/aggregations.py

Lines changed: 7 additions & 8 deletions
@@ -593,9 +593,6 @@ class Scan:
     mode: T_ScanBinaryOpMode = "apply_binary_op"
 
 
-HANDLED_FUNCTIONS = {}
-
-
 def concatenate(arrays: Sequence[AlignedArrays], axis=-1, out=None) -> AlignedArrays:
     group_idx = np.concatenate([a.group_idx for a in arrays], axis=axis)
     array = np.concatenate([a.array for a in arrays], axis=axis)
@@ -642,16 +639,16 @@ def __post_init__(self):
         assert (self.state is not None) or (self.result is not None)
 
 
-def scan_binary_op(
-    left_state: AlignedArrays, right_state: AlignedArrays, *, agg: Scan
-) -> AlignedArrays:
+def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) -> ScanState:
     from .core import reindex_
 
     assert left_state.state is not None
     left = left_state.state
     right = right_state.result if right_state.result is not None else right_state.state
+    assert right is not None
 
     if agg.mode == "apply_binary_op":
+        assert agg.binary_op is not None
         # Implements groupby binary operation.
         reindexed = reindex_(
             left.array,
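
Note on the added asserts: mypy cannot see the dataclass invariant that at least one of `state`/`result` is set, so `assert ... is not None` is used to narrow the `Optional` fields before they are used. A minimal standalone sketch of the pattern (the `State` class and `combine` function below are illustrative stand-ins, not flox's actual definitions):

    from dataclasses import dataclass
    from typing import Callable, Optional

    import numpy as np

    @dataclass
    class State:
        # Mirrors the shape of ScanState: either field may be None.
        state: Optional[np.ndarray] = None
        result: Optional[np.ndarray] = None

    def combine(left: State, right: State, binary_op: Optional[Callable]) -> np.ndarray:
        # Without the asserts, mypy reports errors like
        #     Item "None" of "Optional[ndarray]" has no attribute ...
        # since it cannot prove the fields are populated at this point.
        assert left.state is not None
        value = right.result if right.result is not None else right.state
        assert value is not None  # narrows Optional[ndarray] -> ndarray
        assert binary_op is not None
        return binary_op(left.state, value)

    print(combine(State(state=np.ones(3)), State(result=np.arange(3.0)), np.add))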
@@ -718,7 +715,7 @@ def scan_binary_op(
 # cumprod = Scan("cumprod", binary_op=np.multiply, preop="prod", scan="cumprod")
 
 
-AGGREGATIONS = {
+AGGREGATIONS: dict[str, Aggregation | Scan] = {
     "any": any_,
     "all": all_,
     "count": count,
@@ -769,7 +766,9 @@ def _initialize_aggregation(
         try:
             # TODO: need better interface
             # we set dtype, fillvalue on reduction later. so deepcopy now
-            agg = copy.deepcopy(AGGREGATIONS[func])
+            agg_ = copy.deepcopy(AGGREGATIONS[func])
+            assert isinstance(agg_, Aggregation)
+            agg = agg_
         except KeyError:
             raise NotImplementedError(f"Reduction {func!r} not implemented yet")
     elif isinstance(func, Aggregation):
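
With `AGGREGATIONS` now annotated as `dict[str, Aggregation | Scan]`, a lookup yields the union type, so `_initialize_aggregation` narrows the result with `assert isinstance` before use; binding to a fresh name (`agg_`) first keeps `agg` itself typed as plain `Aggregation`. A simplified sketch of the idea (the classes and table below are stand-ins, not flox's real definitions):

    from __future__ import annotations

    from dataclasses import dataclass

    @dataclass
    class Aggregation:
        name: str

    @dataclass
    class Scan:
        name: str

    # Indexing this table gives `Aggregation | Scan` under the new annotation.
    TABLE: dict[str, Aggregation | Scan] = {
        "sum": Aggregation("sum"),
        "cumsum": Scan("cumsum"),
    }

    def initialize(func: str) -> Aggregation:
        entry = TABLE[func]
        # isinstance() narrows the union; returning `entry` without it
        # would fail mypy with an incompatible-return-type error.
        assert isinstance(entry, Aggregation)
        return entry

    print(initialize("sum").name)  # -> sum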

flox/core.py

Lines changed: 4 additions & 3 deletions
@@ -814,7 +814,7 @@ def factorize_(
             idx = sorter[(idx,)]
             idx[mask] = -1
         else:
-            idx, groups = pd.factorize(flat, sort=sort)  # type: ignore[arg-type]
+            idx, groups = pd.factorize(flat, sort=sort)
 
         found_groups.append(np.array(groups))
         factorized.append(idx.reshape(groupvar.shape))
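
For reference, `pd.factorize` returns an integer code per element plus the array of unique values, which is what `factorize_` accumulates here; presumably newer pandas stubs now accept the ndarray argument, making the `type: ignore[arg-type]` unnecessary. A quick standalone illustration (arbitrary example values):

    import numpy as np
    import pandas as pd

    flat = np.array(["b", "a", "b", "c"])
    # codes labels each element by its group; groups holds the uniques.
    codes, groups = pd.factorize(flat, sort=True)
    print(codes)   # [1 0 1 2]
    print(groups)  # ['a' 'b' 'c']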
@@ -2717,7 +2717,7 @@ def groupby_scan(
     agg.dtype = array.dtype
 
     if not has_dask:
-        (single_axis,) = axis_
+        (single_axis,) = axis_  # type: ignore[misc]
         final_state = chunk_scan(
             AlignedArrays(array=array, group_idx=by_), axis=single_axis, agg=agg, dtype=agg.dtype
         )
@@ -2765,6 +2765,7 @@ def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays:
 
 
 def extract_array(block: ScanState) -> np.ndarray:
+    assert block.result is not None
     return block.result.array
 
 
@@ -2787,7 +2788,7 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
 
     scan_ = partial(chunk_scan, agg=agg)
     # dask tokenizing error workaround
-    scan_.__name__ = scan_.func.__name__
+    scan_.__name__ = scan_.func.__name__  # type: ignore[attr-defined]
 
     # 2. Run the scan
     accumulated = scan(
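
The `attr-defined` ignore is needed because `functools.partial` objects support attribute assignment at runtime but typeshed does not declare `__name__` on them; the commit keeps the runtime workaround for the dask tokenizing error noted in the comment and just silences mypy. A standalone sketch of the pattern (`chunk_scan` below is a stand-in):

    from functools import partial

    def chunk_scan(block, *, agg=None):
        # Stand-in for the real per-chunk scan function.
        return block

    scan_ = partial(chunk_scan, agg="cumsum")
    # partial objects have no __name__ by default, which trips up code
    # (here, dask tokenization) that expects one; the assignment works at
    # runtime even though the type stubs do not know about the attribute.
    scan_.__name__ = scan_.func.__name__  # type: ignore[attr-defined]
    print(scan_.__name__)  # -> chunk_scan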

tests/test_core.py

Lines changed: 6 additions & 0 deletions
@@ -1833,3 +1833,9 @@ def test_nanlen_string(dtype, engine):
 #     func="cumsum",
 #     axis=-1,
 # )
+
+# numpy_array, group_idx = (
+#     array([1.6777218e07, 1.0000000e00, 0.0000000e00], dtype=float32),
+#     array([0, 1, 1]),
+# )
+# groupby_scan(numpy_array, group_idx, axis=-1, func="nancumsum")
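
If it is ever uncommented, the note above corresponds to something like the following runnable snippet (assuming `groupby_scan` is imported from `flox.core`, the module changed above; the values are copied verbatim):

    import numpy as np

    from flox.core import groupby_scan

    numpy_array = np.array([1.6777218e07, 1.0000000e00, 0.0000000e00], dtype=np.float32)
    group_idx = np.array([0, 1, 1])
    print(groupby_scan(numpy_array, group_idx, axis=-1, func="nancumsum"))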
