2525
2626import numpy as np
2727import pandas as pd
28- from pandas .api .types import CategoricalDtype , is_datetime64_dtype , is_datetime64tz_dtype
28+ from pandas .api .types import CategoricalDtype # noqa: F401
2929from pyspark import sql as spark
3030from pyspark ._globals import _NoValue , _NoValueType
3131from pyspark .sql import functions as F , Window
3232from pyspark .sql .functions import PandasUDFType , pandas_udf
33- from pyspark .sql .types import BooleanType , DataType , StructField , StructType , LongType
33+ from pyspark .sql .types import ( # noqa: F401
34+ BooleanType ,
35+ DataType ,
36+ IntegralType ,
37+ LongType ,
38+ StructField ,
39+ StructType ,
40+ StringType ,
41+ )
3442
3543# For running doctests and reference resolution in PyCharm.
3644from pyspark import pandas as ps # noqa: F401
3947 # This is required in old Python 3.5 to prevent circular reference.
4048 from pyspark .pandas .series import Series # noqa: F401 (SPARK-34943)
4149from pyspark .pandas .config import get_option
50+ from pyspark .pandas .data_type_ops .base import DataTypeOps
4251from pyspark .pandas .typedef import (
4352 Dtype ,
4453 as_spark_type ,
@@ -951,11 +960,11 @@ def arguments_for_restore_index(self) -> Dict:
951960 for col , dtype in zip (self .index_spark_column_names , self .index_dtypes )
952961 if isinstance (dtype , extension_dtypes )
953962 }
954- categorical_dtypes = {
955- col : dtype
956- for col , dtype in zip ( self .index_spark_column_names , self . index_dtypes )
957- if isinstance ( dtype , CategoricalDtype )
958- }
963+ dtypes = [ dtype for dtype in self . index_dtypes ]
964+ spark_types = [
965+ self . spark_frame . select ( scol ). schema [ 0 ]. dataType for scol in self .index_spark_columns
966+ ]
967+
959968 for spark_column , column_name , dtype in zip (
960969 self .data_spark_columns , self .data_spark_column_names , self .data_dtypes
961970 ):
@@ -969,8 +978,8 @@ def arguments_for_restore_index(self) -> Dict:
969978 column_names .append (column_name )
970979 if isinstance (dtype , extension_dtypes ):
971980 ext_dtypes [column_name ] = dtype
972- elif isinstance (dtype , CategoricalDtype ):
973- categorical_dtypes [ column_name ] = dtype
981+ dtypes . append (dtype )
982+ spark_types . append ( self . spark_frame . select ( spark_column ). schema [ 0 ]. dataType )
974983
975984 return dict (
976985 index_columns = self .index_spark_column_names ,
@@ -979,7 +988,8 @@ def arguments_for_restore_index(self) -> Dict:
979988 column_labels = self .column_labels ,
980989 column_label_names = self .column_label_names ,
981990 ext_dtypes = ext_dtypes ,
982- categorical_dtypes = categorical_dtypes ,
991+ dtypes = dtypes ,
992+ spark_types = spark_types ,
983993 )
984994
985995 @staticmethod
@@ -991,8 +1001,9 @@ def restore_index(
9911001 data_columns : List [str ],
9921002 column_labels : List [Tuple ],
9931003 column_label_names : List [Tuple ],
1004+ dtypes : List [Dtype ],
1005+ spark_types : List [DataType ],
9941006 ext_dtypes : Dict [str , Dtype ] = None ,
995- categorical_dtypes : Dict [str , CategoricalDtype ] = None
9961007 ) -> pd .DataFrame :
9971008 """
9981009 Restore pandas DataFrame indices using the metadata.
@@ -1003,10 +1014,12 @@ def restore_index(
10031014 :param data_columns: the original column names for data columns.
10041015 :param column_labels: the column labels after restored.
10051016 :param column_label_names: the column label names after restored.
1017+ :param dtypes: the dtypes after restored.
1018+ :param spark_types: the spark_types.
10061019 :param ext_dtypes: the map from the original column names to extension data types.
1007- :param categorical_dtypes: the map from the original column names to categorical types.
10081020 :return: the restored pandas DataFrame
10091021
1022+ >>> from numpy import dtype
10101023 >>> pdf = pd.DataFrame({"index": [10, 20, 30], "a": ['a', 'b', 'c'], "b": [0, 2, 1]})
10111024 >>> InternalFrame.restore_index(
10121025 ... pdf,
@@ -1015,8 +1028,9 @@ def restore_index(
10151028 ... data_columns=["a", "b", "index"],
10161029 ... column_labels=[("x",), ("y",), ("z",)],
10171030 ... column_label_names=[("lv1",)],
1018- ... ext_dtypes=None,
1019- ... categorical_dtypes={"b": CategoricalDtype(categories=["i", "j", "k"])}
1031+ ... dtypes=[dtype('int64'), dtype('object'),
1032+ ... CategoricalDtype(categories=["i", "j", "k"]), dtype('int64')],
1033+ ... spark_types=[LongType(), StringType(), StringType(), LongType()]
10201034 ... ) # doctest: +NORMALIZE_WHITESPACE
10211035 lv1 x y z
10221036 idx
@@ -1027,11 +1041,8 @@ def restore_index(
10271041 if ext_dtypes is not None and len (ext_dtypes ) > 0 :
10281042 pdf = pdf .astype (ext_dtypes , copy = True )
10291043
1030- if categorical_dtypes is not None :
1031- for col , dtype in categorical_dtypes .items ():
1032- pdf [col ] = pd .Categorical .from_codes (
1033- pdf [col ], categories = dtype .categories , ordered = dtype .ordered
1034- )
1044+ for col , expected_dtype , spark_type in zip (pdf .columns , dtypes , spark_types ):
1045+ pdf [col ] = DataTypeOps (expected_dtype , spark_type ).restore (pdf [col ])
10351046
10361047 append = False
10371048 for index_field in index_columns :
@@ -1071,7 +1082,7 @@ def with_new_sdf(
10711082 * ,
10721083 index_dtypes : Optional [List [Dtype ]] = None ,
10731084 data_columns : Optional [List [str ]] = None ,
1074- data_dtypes : Optional [List [Dtype ]] = None
1085+ data_dtypes : Optional [List [Dtype ]] = None ,
10751086 ) -> "InternalFrame" :
10761087 """Copy the immutable InternalFrame with the updates by the specified Spark DataFrame.
10771088
@@ -1121,7 +1132,7 @@ def with_new_columns(
11211132 column_labels : Optional [List [Tuple ]] = None ,
11221133 data_dtypes : Optional [List [Dtype ]] = None ,
11231134 column_label_names : Union [Optional [List [Optional [Tuple ]]], _NoValueType ] = _NoValue ,
1124- keep_order : bool = True
1135+ keep_order : bool = True ,
11251136 ) -> "InternalFrame" :
11261137 """
11271138 Copy the immutable InternalFrame with the updates by the specified Spark Columns or Series.
@@ -1225,7 +1236,7 @@ def with_new_spark_column(
12251236 scol : spark .Column ,
12261237 * ,
12271238 dtype : Optional [Dtype ] = None ,
1228- keep_order : bool = True
1239+ keep_order : bool = True ,
12291240 ) -> "InternalFrame" :
12301241 """
12311242 Copy the immutable InternalFrame with the updates by the specified Spark Column.
@@ -1273,7 +1284,7 @@ def copy(
12731284 column_labels : Union [Optional [List [Tuple ]], _NoValueType ] = _NoValue ,
12741285 data_spark_columns : Union [Optional [List [spark .Column ]], _NoValueType ] = _NoValue ,
12751286 data_dtypes : Union [Optional [List [Dtype ]], _NoValueType ] = _NoValue ,
1276- column_label_names : Union [Optional [List [Optional [Tuple ]]], _NoValueType ] = _NoValue
1287+ column_label_names : Union [Optional [List [Optional [Tuple ]]], _NoValueType ] = _NoValue ,
12771288 ) -> "InternalFrame" :
12781289 """Copy the immutable InternalFrame.
12791290
@@ -1423,13 +1434,9 @@ def prepare_pandas_frame(
14231434 index_dtypes = list (reset_index .dtypes )[:index_nlevels ]
14241435 data_dtypes = list (reset_index .dtypes )[index_nlevels :]
14251436
1426- for name , col in reset_index .iteritems ():
1427- dt = col .dtype
1428- if is_datetime64_dtype (dt ) or is_datetime64tz_dtype (dt ):
1429- continue
1430- elif isinstance (dt , CategoricalDtype ):
1431- col = col .cat .codes
1432- reset_index [name ] = col .replace ({np .nan : None })
1437+ for col , dtype in zip (reset_index .columns , reset_index .dtypes ):
1438+ spark_type = infer_pd_series_spark_type (reset_index [col ], dtype )
1439+ reset_index [col ] = DataTypeOps (dtype , spark_type ).prepare (reset_index [col ])
14331440
14341441 return reset_index , index_columns , index_dtypes , data_columns , data_dtypes
14351442
0 commit comments