 from distutils.version import LooseVersion
 from functools import reduce
 from io import BytesIO
+import json

 import numpy as np
 import pandas as pd
 from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype, is_list_like
+import pyarrow as pa
+import pyarrow.parquet as pq
 import pyspark
 from pyspark import sql as spark
 from pyspark.sql import functions as F
+from pyspark.sql.functions import pandas_udf, PandasUDFType
 from pyspark.sql.types import (
     ByteType,
     ShortType,
@@ -625,7 +629,7 @@ def read_spark_io(
     return DataFrame(InternalFrame(spark_frame=sdf, index_map=index_map))


-def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
+def read_parquet(path, columns=None, index_col=None, pandas_metadata=False, **options) -> DataFrame:
629633 """Load a parquet object from the file path, returning a DataFrame.
630634
631635 Parameters
@@ -636,6 +640,8 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
         If not None, only these columns will be read from the file.
     index_col : str or list of str, optional, default: None
         Index column of table in Spark.
+    pandas_metadata : bool, default: False
+        If True, try to respect the metadata if the Parquet file was written by pandas.
     options : dict
         All other options passed directly into Spark's data source.
@@ -672,6 +678,44 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
     if columns is not None:
         columns = list(columns)

+    index_names = None
+
+    if index_col is None and pandas_metadata:
+        if LooseVersion(pyspark.__version__) < LooseVersion("3.0.0"):
+            raise ValueError("pandas_metadata is not supported with Spark < 3.0.")
+
+        # Try to read pandas metadata
+
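+        # The first file under the path is loaded as raw bytes with the
+        # binaryFile data source (available from Spark 3.0, hence the version
+        # check above); a scalar pandas UDF then parses the Parquet footer on
+        # an executor instead of collecting the file contents to the driver.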
+        @pandas_udf("index_col array<string>, index_names array<string>", PandasUDFType.SCALAR)
+        def read_index_metadata(pser):
+            binary = pser.iloc[0]
+            metadata = pq.ParquetFile(pa.BufferReader(binary)).metadata.metadata
+            if b"pandas" in metadata:
+                pandas_metadata = json.loads(metadata[b"pandas"].decode("utf8"))
+                if all(isinstance(col, str) for col in pandas_metadata["index_columns"]):
+                    index_col = []
+                    index_names = []
+                    for col in pandas_metadata["index_columns"]:
+                        index_col.append(col)
+                        for column in pandas_metadata["columns"]:
+                            if column["field_name"] == col:
+                                index_names.append(column["name"])
+                                break
+                        else:
+                            index_names.append(None)
+                    return pd.DataFrame({"index_col": [index_col], "index_names": [index_names]})
+            return pd.DataFrame({"index_col": [None], "index_names": [None]})
+
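+        # Run the UDF over that single file and unpack the resulting struct
+        # column into index_col and index_names on the driver.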
+        index_col, index_names = (
+            default_session()
+            .read.format("binaryFile")
+            .load(path)
+            .limit(1)
+            .select(read_index_metadata("content").alias("index_metadata"))
+            .select("index_metadata.*")
+            .head()
+        )
+
     kdf = read_spark_io(path=path, format="parquet", options=options, index_col=index_col)

     if columns is not None:
@@ -681,7 +725,10 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
         else:
             sdf = default_session().createDataFrame([], schema=StructType())
             index_map = _get_index_map(sdf, index_col)
-            return DataFrame(InternalFrame(spark_frame=sdf, index_map=index_map))
+            kdf = DataFrame(InternalFrame(spark_frame=sdf, index_map=index_map))
+
+    if index_names is not None:
+        kdf.index.names = index_names

     return kdf

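For reference, a minimal sketch of the intended usage, assuming a Koalas build that includes this change and Spark 3.0+ (the file path below is only illustrative):

    import pandas as pd
    import databricks.koalas as ks

    # pandas embeds index metadata (index_columns, field names) in the
    # Parquet footer when writing.
    pdf = pd.DataFrame({"a": [1, 2, 3]}, index=pd.Index(["x", "y", "z"], name="idx"))
    pdf.to_parquet("/tmp/pandas_indexed.parquet")

    # With pandas_metadata=True, read_parquet should recover the index
    # column and its name from that footer instead of attaching a default
    # sequential index.
    kdf = ks.read_parquet("/tmp/pandas_indexed.parquet", pandas_metadata=True)
    print(kdf.index.names)  # expected: ['idx']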