Skip to content

Commit 9b3c301

Browse files
authored
TYP: misc return types (#57285)
1 parent 99e3afe commit 9b3c301

File tree

18 files changed

+67
-37
lines changed

18 files changed

+67
-37
lines changed

pandas/_typing.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@
4646

4747
from pandas.core.dtypes.dtypes import ExtensionDtype
4848

49-
from pandas import Interval
49+
from pandas import (
50+
DatetimeIndex,
51+
Interval,
52+
PeriodIndex,
53+
TimedeltaIndex,
54+
)
5055
from pandas.arrays import (
5156
DatetimeArray,
5257
TimedeltaArray,
@@ -190,6 +195,7 @@ def __reversed__(self) -> Iterator[_T_co]:
190195
NDFrameT = TypeVar("NDFrameT", bound="NDFrame")
191196

192197
IndexT = TypeVar("IndexT", bound="Index")
198+
FreqIndexT = TypeVar("FreqIndexT", "DatetimeIndex", "PeriodIndex", "TimedeltaIndex")
193199
NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")
194200

195201
AxisInt = int

pandas/core/_numba/extensions.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from contextlib import contextmanager
1414
import operator
15+
from typing import TYPE_CHECKING
1516

1617
import numba
1718
from numba import types
@@ -40,6 +41,9 @@
4041
from pandas.core.internals import SingleBlockManager
4142
from pandas.core.series import Series
4243

44+
if TYPE_CHECKING:
45+
from pandas._typing import Self
46+
4347

4448
# Helper function to hack around fact that Index casts numpy string dtype to object
4549
#
@@ -84,7 +88,7 @@ def key(self):
8488
def as_array(self):
8589
return types.Array(self.dtype, 1, self.layout)
8690

87-
def copy(self, dtype=None, ndim: int = 1, layout=None):
91+
def copy(self, dtype=None, ndim: int = 1, layout=None) -> Self:
8892
assert ndim == 1
8993
if dtype is None:
9094
dtype = self.dtype
@@ -114,7 +118,7 @@ def key(self):
114118
def as_array(self):
115119
return self.values
116120

117-
def copy(self, dtype=None, ndim: int = 1, layout: str = "C"):
121+
def copy(self, dtype=None, ndim: int = 1, layout: str = "C") -> Self:
118122
assert ndim == 1
119123
assert layout == "C"
120124
if dtype is None:
@@ -123,7 +127,7 @@ def copy(self, dtype=None, ndim: int = 1, layout: str = "C"):
123127

124128

125129
@typeof_impl.register(Index)
126-
def typeof_index(val, c):
130+
def typeof_index(val, c) -> IndexType:
127131
"""
128132
This will assume that only strings are in object dtype
129133
index.
@@ -136,7 +140,7 @@ def typeof_index(val, c):
136140

137141

138142
@typeof_impl.register(Series)
139-
def typeof_series(val, c):
143+
def typeof_series(val, c) -> SeriesType:
140144
index = typeof_impl(val.index, c)
141145
arrty = typeof_impl(val.values, c)
142146
namety = typeof_impl(val.name, c)
@@ -532,7 +536,7 @@ def key(self):
532536

533537

534538
@typeof_impl.register(_iLocIndexer)
535-
def typeof_iloc(val, c):
539+
def typeof_iloc(val, c) -> IlocType:
536540
objtype = typeof_impl(val.obj, c)
537541
return IlocType(objtype)
538542

pandas/core/arrays/_mixins.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def searchsorted(
250250
return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)
251251

252252
@doc(ExtensionArray.shift)
253-
def shift(self, periods: int = 1, fill_value=None):
253+
def shift(self, periods: int = 1, fill_value=None) -> Self:
254254
# NB: shift is always along axis=0
255255
axis = 0
256256
fill_value = self._validate_scalar(fill_value)

pandas/core/arrays/arrow/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def floordiv_compat(
204204
from pandas.core.arrays.timedeltas import TimedeltaArray
205205

206206

207-
def get_unit_from_pa_dtype(pa_dtype):
207+
def get_unit_from_pa_dtype(pa_dtype) -> str:
208208
# https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
209209
if pa_version_under11p0:
210210
unit = str(pa_dtype).split("[", 1)[-1][:-1]
@@ -1966,7 +1966,7 @@ def _rank(
19661966
na_option: str = "keep",
19671967
ascending: bool = True,
19681968
pct: bool = False,
1969-
):
1969+
) -> Self:
19701970
"""
19711971
See Series.rank.__doc__.
19721972
"""

pandas/core/dtypes/dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2337,7 +2337,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
23372337
except NotImplementedError:
23382338
return None
23392339

2340-
def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
2340+
def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ArrowExtensionArray:
23412341
"""
23422342
Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray.
23432343
"""

pandas/core/dtypes/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
# define abstract base classes to enable isinstance type checking on our
3535
# objects
36-
def create_pandas_abc_type(name, attr, comp):
36+
def create_pandas_abc_type(name, attr, comp) -> type:
3737
def _check(inst) -> bool:
3838
return getattr(inst, attr, "_typ") in comp
3939

pandas/core/groupby/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1788,7 +1788,7 @@ def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFram
17881788

17891789
return path, res
17901790

1791-
def filter(self, func, dropna: bool = True, *args, **kwargs):
1791+
def filter(self, func, dropna: bool = True, *args, **kwargs) -> DataFrame:
17921792
"""
17931793
Filter elements from groups that don't satisfy a criterion.
17941794

pandas/core/indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1750,7 +1750,7 @@ def _get_slice_axis(self, slice_obj: slice, axis: AxisInt):
17501750
labels._validate_positional_slice(slice_obj)
17511751
return self.obj._slice(slice_obj, axis=axis)
17521752

1753-
def _convert_to_indexer(self, key, axis: AxisInt):
1753+
def _convert_to_indexer(self, key: T, axis: AxisInt) -> T:
17541754
"""
17551755
Much simpler as we only have to deal with our valid types.
17561756
"""

pandas/core/interchange/from_dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def from_dataframe(df, allow_copy: bool = True) -> pd.DataFrame:
7575
)
7676

7777

78-
def _from_dataframe(df: DataFrameXchg, allow_copy: bool = True):
78+
def _from_dataframe(df: DataFrameXchg, allow_copy: bool = True) -> pd.DataFrame:
7979
"""
8080
Build a ``pd.DataFrame`` from the DataFrame interchange object.
8181

pandas/core/internals/blocks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2537,7 +2537,7 @@ def get_block_type(dtype: DtypeObj) -> type[Block]:
25372537

25382538
def new_block_2d(
25392539
values: ArrayLike, placement: BlockPlacement, refs: BlockValuesRefs | None = None
2540-
):
2540+
) -> Block:
25412541
# new_block specialized to case with
25422542
# ndim=2
25432543
# isinstance(placement, BlockPlacement)

pandas/core/internals/managers.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Any,
1111
Callable,
1212
Literal,
13+
NoReturn,
1314
cast,
1415
final,
1516
)
@@ -2349,7 +2350,7 @@ def raise_construction_error(
23492350
block_shape: Shape,
23502351
axes: list[Index],
23512352
e: ValueError | None = None,
2352-
):
2353+
) -> NoReturn:
23532354
"""raise a helpful message about our construction"""
23542355
passed = tuple(map(int, [tot_items] + list(block_shape)))
23552356
# Correcting the user facing error message during dataframe construction

pandas/core/ops/array_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def comp_method_OBJECT_ARRAY(op, x, y):
130130
return result.reshape(x.shape)
131131

132132

133-
def _masked_arith_op(x: np.ndarray, y, op):
133+
def _masked_arith_op(x: np.ndarray, y, op) -> np.ndarray:
134134
"""
135135
If the given arithmetic operation fails, attempt it again on
136136
only the non-null elements of the input array(s).

pandas/core/ops/mask_ops.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,25 @@
33
"""
44
from __future__ import annotations
55

6+
from typing import TYPE_CHECKING
7+
68
import numpy as np
79

810
from pandas._libs import (
911
lib,
1012
missing as libmissing,
1113
)
1214

15+
if TYPE_CHECKING:
16+
from pandas._typing import npt
17+
1318

1419
def kleene_or(
1520
left: bool | np.ndarray | libmissing.NAType,
1621
right: bool | np.ndarray | libmissing.NAType,
1722
left_mask: np.ndarray | None,
1823
right_mask: np.ndarray | None,
19-
):
24+
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
2025
"""
2126
Boolean ``or`` using Kleene logic.
2227
@@ -78,7 +83,7 @@ def kleene_xor(
7883
right: bool | np.ndarray | libmissing.NAType,
7984
left_mask: np.ndarray | None,
8085
right_mask: np.ndarray | None,
81-
):
86+
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
8287
"""
8388
Boolean ``xor`` using Kleene logic.
8489
@@ -131,7 +136,7 @@ def kleene_and(
131136
right: bool | libmissing.NAType | np.ndarray,
132137
left_mask: np.ndarray | None,
133138
right_mask: np.ndarray | None,
134-
):
139+
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
135140
"""
136141
Boolean ``and`` using Kleene logic.
137142

pandas/core/resample.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
AnyArrayLike,
102102
Axis,
103103
Concatenate,
104+
FreqIndexT,
104105
Frequency,
105106
IndexLabel,
106107
InterpolateOptions,
@@ -1690,7 +1691,7 @@ class DatetimeIndexResampler(Resampler):
16901691
ax: DatetimeIndex
16911692

16921693
@property
1693-
def _resampler_for_grouping(self):
1694+
def _resampler_for_grouping(self) -> type[DatetimeIndexResamplerGroupby]:
16941695
return DatetimeIndexResamplerGroupby
16951696

16961697
def _get_binner_for_time(self):
@@ -2483,17 +2484,28 @@ def _set_grouper(
24832484
return obj, ax, indexer
24842485

24852486

2487+
@overload
24862488
def _take_new_index(
2487-
obj: NDFrameT,
2489+
obj: DataFrame, indexer: npt.NDArray[np.intp], new_index: Index
2490+
) -> DataFrame:
2491+
...
2492+
2493+
2494+
@overload
2495+
def _take_new_index(
2496+
obj: Series, indexer: npt.NDArray[np.intp], new_index: Index
2497+
) -> Series:
2498+
...
2499+
2500+
2501+
def _take_new_index(
2502+
obj: DataFrame | Series,
24882503
indexer: npt.NDArray[np.intp],
24892504
new_index: Index,
2490-
) -> NDFrameT:
2505+
) -> DataFrame | Series:
24912506
if isinstance(obj, ABCSeries):
24922507
new_values = algos.take_nd(obj._values, indexer)
2493-
# error: Incompatible return value type (got "Series", expected "NDFrameT")
2494-
return obj._constructor( # type: ignore[return-value]
2495-
new_values, index=new_index, name=obj.name
2496-
)
2508+
return obj._constructor(new_values, index=new_index, name=obj.name)
24972509
elif isinstance(obj, ABCDataFrame):
24982510
new_mgr = obj._mgr.reindex_indexer(new_axis=new_index, indexer=indexer, axis=1)
24992511
return obj._constructor_from_mgr(new_mgr, axes=new_mgr.axes)
@@ -2788,7 +2800,7 @@ def asfreq(
27882800
return new_obj
27892801

27902802

2791-
def _asfreq_compat(index: DatetimeIndex | PeriodIndex | TimedeltaIndex, freq):
2803+
def _asfreq_compat(index: FreqIndexT, freq) -> FreqIndexT:
27922804
"""
27932805
Helper to mimic asfreq on (empty) DatetimeIndex and TimedeltaIndex.
27942806
@@ -2806,7 +2818,6 @@ def _asfreq_compat(index: DatetimeIndex | PeriodIndex | TimedeltaIndex, freq):
28062818
raise ValueError(
28072819
"Can only set arbitrary freq for empty DatetimeIndex or TimedeltaIndex"
28082820
)
2809-
new_index: Index
28102821
if isinstance(index, PeriodIndex):
28112822
new_index = index.asfreq(freq=freq)
28122823
elif isinstance(index, DatetimeIndex):

pandas/core/reshape/pivot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ def _normalize(
830830
return table
831831

832832

833-
def _get_names(arrs, names, prefix: str = "row"):
833+
def _get_names(arrs, names, prefix: str = "row") -> list:
834834
if names is None:
835835
names = []
836836
for i, arr in enumerate(arrs):

pandas/core/reshape/tile.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def _format_labels(
548548
precision: int,
549549
right: bool = True,
550550
include_lowest: bool = False,
551-
):
551+
) -> IntervalIndex:
552552
"""based on the dtype, return our labels"""
553553
closed: IntervalLeftRight = "right" if right else "left"
554554

pandas/core/strings/base.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from collections.abc import Sequence
1414
import re
1515

16-
from pandas._typing import Scalar
17-
18-
from pandas import Series
16+
from pandas._typing import (
17+
Scalar,
18+
Self,
19+
)
1920

2021

2122
class BaseStringArrayMethods(abc.ABC):
@@ -240,11 +241,11 @@ def _str_rstrip(self, to_strip=None):
240241
pass
241242

242243
@abc.abstractmethod
243-
def _str_removeprefix(self, prefix: str) -> Series:
244+
def _str_removeprefix(self, prefix: str) -> Self:
244245
pass
245246

246247
@abc.abstractmethod
247-
def _str_removesuffix(self, suffix: str) -> Series:
248+
def _str_removesuffix(self, suffix: str) -> Self:
248249
pass
249250

250251
@abc.abstractmethod

pandas/core/tools/datetimes.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,9 @@ def to_datetime(
11341134
}
11351135

11361136

1137-
def _assemble_from_unit_mappings(arg, errors: DateTimeErrorChoices, utc: bool):
1137+
def _assemble_from_unit_mappings(
1138+
arg, errors: DateTimeErrorChoices, utc: bool
1139+
) -> Series:
11381140
"""
11391141
assemble the unit specified fields from the arg (DataFrame)
11401142
Return a Series for actual parsing

0 commit comments

Comments
 (0)