diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index f1e1ebcaca1c4..dfb329b3ddbdf 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -387,7 +387,7 @@ def _verify_integrity( return new_codes @classmethod - def from_arrays(cls, arrays, sortorder=None, names=lib.no_default): + def from_arrays(cls, arrays, sortorder=None, names=lib.no_default) -> "MultiIndex": """ Convert arrays to MultiIndex. diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index c8d5eecf0e496..ea5916eff3afa 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -1,7 +1,18 @@ -from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) import numpy as np +from pandas._typing import Label from pandas.util._decorators import Appender, Substitution from pandas.core.dtypes.cast import maybe_downcast_to_dtype @@ -424,19 +435,22 @@ def _convert_by(by): @Substitution("\ndata : DataFrame") @Appender(_shared_docs["pivot"], indents=1) -def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFrame": +def pivot( + data: "DataFrame", + index: Optional[Union[Label, Sequence[Label]]] = None, + columns: Optional[Union[Label, Sequence[Label]]] = None, + values: Optional[Union[Label, Sequence[Label]]] = None, +) -> "DataFrame": if columns is None: raise TypeError("pivot() missing 1 required argument: 'columns'") - columns = columns if is_list_like(columns) else [columns] + + columns = com.convert_to_list_like(columns) if values is None: - cols: List[str] = [] - if index is None: - pass - elif is_list_like(index): - cols = list(index) + if index is not None: + cols = com.convert_to_list_like(index) else: - cols = [index] + cols = [] cols.extend(columns) append = index is None @@ -444,10 +458,9 @@ def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFram else: if index is None: index = [Series(data.index, name=data.index.name)] - elif is_list_like(index): - index = [data[idx] for idx in index] else: - index = [data[index]] + index = com.convert_to_list_like(index) + index = [data[idx] for idx in index] data_columns = [data[col] for col in columns] index.extend(data_columns) @@ -455,6 +468,7 @@ def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFram if is_list_like(values) and not isinstance(values, tuple): # Exclude tuple because it is seen as a single column name + values = cast(Sequence[Label], values) indexed = data._constructor( data[values]._values, index=index, columns=values )