Skip to content

Commit 5dd1820

Browse files
authored
gh-624 : Added index parameter to to_dict and added axis argument to Series.add_suffix(), DataFrame.add_suffix(), Series.add_prefix() and DataFrame.add_prefix() (#638)
* added arguments * changed axis parameters * req changes and created a different overload and corrected index args * creating overloads * Update frame.pyi * corrected diff overload args * added the tests * corrected the tests * Update test_series.py
1 parent c03e23c commit 5dd1820

File tree

4 files changed

+99
-14
lines changed

4 files changed

+99
-14
lines changed

Diff for: pandas-stubs/core/frame.pyi

+41-12
Original file line numberDiff line numberDiff line change
@@ -268,32 +268,61 @@ class DataFrame(NDFrame, OpsMixin):
268268
@overload
269269
def to_dict(
270270
self,
271-
orient: Literal["dict", "list", "series", "split", "tight", "index"],
271+
orient: Literal["records"],
272272
into: Mapping | type[Mapping],
273+
index: Literal[True] = ...,
274+
) -> list[Mapping[Hashable, Any]]: ...
275+
@overload
276+
def to_dict(
277+
self,
278+
orient: Literal["records"],
279+
into: None = ...,
280+
index: Literal[True] = ...,
281+
) -> list[dict[Hashable, Any]]: ...
282+
@overload
283+
def to_dict(
284+
self,
285+
orient: Literal["dict", "list", "series", "index"],
286+
into: Mapping | type[Mapping],
287+
index: Literal[True] = ...,
273288
) -> Mapping[Hashable, Any]: ...
274289
@overload
275290
def to_dict(
276291
self,
277-
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
278-
*,
292+
orient: Literal["split", "tight"],
279293
into: Mapping | type[Mapping],
294+
index: bool = ...,
280295
) -> Mapping[Hashable, Any]: ...
281296
@overload
282297
def to_dict(
283298
self,
284-
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
285-
into: None = ...,
286-
) -> dict[Hashable, Any]: ...
299+
orient: Literal["dict", "list", "series", "index"] = ...,
300+
*,
301+
into: Mapping | type[Mapping],
302+
index: Literal[True] = ...,
303+
) -> Mapping[Hashable, Any]: ...
287304
@overload
288305
def to_dict(
289306
self,
290-
orient: Literal["records"],
307+
orient: Literal["split", "tight"] = ...,
308+
*,
291309
into: Mapping | type[Mapping],
292-
) -> list[Mapping[Hashable, Any]]: ...
310+
index: bool = ...,
311+
) -> Mapping[Hashable, Any]: ...
293312
@overload
294313
def to_dict(
295-
self, orient: Literal["records"], into: None = ...
296-
) -> list[dict[Hashable, Any]]: ...
314+
self,
315+
orient: Literal["dict", "list", "series", "index"] = ...,
316+
into: None = ...,
317+
index: Literal[True] = ...,
318+
) -> dict[Hashable, Any]: ...
319+
@overload
320+
def to_dict(
321+
self,
322+
orient: Literal["split", "tight"] = ...,
323+
into: None = ...,
324+
index: bool = ...,
325+
) -> dict[Hashable, Any]: ...
297326
def to_gbq(
298327
self,
299328
destination_table: str,
@@ -1400,8 +1429,8 @@ class DataFrame(NDFrame, OpsMixin):
14001429
level: Level | None = ...,
14011430
fill_value: float | None = ...,
14021431
) -> DataFrame: ...
1403-
def add_prefix(self, prefix: _str) -> DataFrame: ...
1404-
def add_suffix(self, suffix: _str) -> DataFrame: ...
1432+
def add_prefix(self, prefix: _str, axis: Axis | None = None) -> DataFrame: ...
1433+
def add_suffix(self, suffix: _str, axis: Axis | None = None) -> DataFrame: ...
14051434
@overload
14061435
def all(
14071436
self,

Diff for: pandas-stubs/core/series.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -1024,8 +1024,8 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
10241024
def pop(self, item: Hashable) -> S1: ...
10251025
def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ...
10261026
def __abs__(self) -> Series[S1]: ...
1027-
def add_prefix(self, prefix: _str) -> Series[S1]: ...
1028-
def add_suffix(self, suffix: _str) -> Series[S1]: ...
1027+
def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
1028+
def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
10291029
def reindex(
10301030
self,
10311031
index: Axes | None = ...,

Diff for: tests/test_frame.py

+44
Original file line numberDiff line numberDiff line change
@@ -2526,3 +2526,47 @@ def test_loc_returns_series() -> None:
25262526
df1 = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40])
25272527
df2 = df1.loc[10, :]
25282528
check(assert_type(df2, Union[pd.Series, pd.DataFrame]), pd.Series)
2529+
2530+
2531+
def test_to_dict_index() -> None:
2532+
df = pd.DataFrame({"a": [1, 2], "b": [9, 10]})
2533+
check(
2534+
assert_type(
2535+
df.to_dict(orient="records", index=True), List[Dict[Hashable, Any]]
2536+
),
2537+
list,
2538+
)
2539+
check(assert_type(df.to_dict(orient="dict", index=True), Dict[Hashable, Any]), dict)
2540+
check(
2541+
assert_type(df.to_dict(orient="series", index=True), Dict[Hashable, Any]), dict
2542+
)
2543+
check(
2544+
assert_type(df.to_dict(orient="index", index=True), Dict[Hashable, Any]), dict
2545+
)
2546+
check(
2547+
assert_type(df.to_dict(orient="split", index=True), Dict[Hashable, Any]), dict
2548+
)
2549+
check(
2550+
assert_type(df.to_dict(orient="tight", index=True), Dict[Hashable, Any]), dict
2551+
)
2552+
check(
2553+
assert_type(df.to_dict(orient="tight", index=False), Dict[Hashable, Any]), dict
2554+
)
2555+
check(
2556+
assert_type(df.to_dict(orient="split", index=False), Dict[Hashable, Any]), dict
2557+
)
2558+
if TYPE_CHECKING_INVALID_USAGE:
2559+
check(assert_type(df.to_dict(orient="records", index=False), List[Dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
2560+
check(assert_type(df.to_dict(orient="dict", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
2561+
check(assert_type(df.to_dict(orient="series", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
2562+
check(assert_type(df.to_dict(orient="index", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
2563+
2564+
2565+
def test_suffix_prefix_index() -> None:
2566+
df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]})
2567+
check(assert_type(df.add_suffix("_col", axis=1), pd.DataFrame), pd.DataFrame)
2568+
check(assert_type(df.add_suffix("_col", axis="index"), pd.DataFrame), pd.DataFrame)
2569+
check(assert_type(df.add_prefix("_col", axis="index"), pd.DataFrame), pd.DataFrame)
2570+
check(
2571+
assert_type(df.add_prefix("_col", axis="columns"), pd.DataFrame), pd.DataFrame
2572+
)

Diff for: tests/test_series.py

+12
Original file line numberDiff line numberDiff line change
@@ -1813,3 +1813,15 @@ def test_types_apply_set() -> None:
18131813
{"list1": [1, 2, 3], "list2": ["a", "b", "c"], "list3": [True, False, True]}
18141814
)
18151815
check(assert_type(series_of_lists.apply(lambda x: set(x)), pd.Series), pd.Series)
1816+
1817+
1818+
def test_prefix_summix_axis() -> None:
1819+
s = pd.Series([1, 2, 3, 4])
1820+
check(assert_type(s.add_suffix("_item", axis=0), pd.Series), pd.Series)
1821+
check(assert_type(s.add_suffix("_item", axis="index"), pd.Series), pd.Series)
1822+
check(assert_type(s.add_prefix("_item", axis=0), pd.Series), pd.Series)
1823+
check(assert_type(s.add_prefix("_item", axis="index"), pd.Series), pd.Series)
1824+
1825+
if TYPE_CHECKING_INVALID_USAGE:
1826+
check(assert_type(s.add_prefix("_item", axis=1), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
1827+
check(assert_type(s.add_suffix("_item", axis="columns"), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]

0 commit comments

Comments
 (0)