Skip to content

Commit 66df1b0

Browse files
authored
Try to read pandas metadata in read_parquet if index_col is None. (#1695)
If a parquet file is stored by pandas, there is metadata that describes the index columns. We can read it and use it as `index_col` if it's not specified. Resolves #1645.
1 parent faf687f commit 66df1b0

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

databricks/koalas/namespace.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,17 @@
2323
from distutils.version import LooseVersion
2424
from functools import reduce
2525
from io import BytesIO
26+
import json
2627

2728
import numpy as np
2829
import pandas as pd
2930
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype, is_list_like
31+
import pyarrow as pa
32+
import pyarrow.parquet as pq
3033
import pyspark
3134
from pyspark import sql as spark
3235
from pyspark.sql import functions as F
36+
from pyspark.sql.functions import pandas_udf, PandasUDFType
3337
from pyspark.sql.types import (
3438
ByteType,
3539
ShortType,
@@ -625,7 +629,7 @@ def read_spark_io(
625629
return DataFrame(InternalFrame(spark_frame=sdf, index_map=index_map))
626630

627631

628-
def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
632+
def read_parquet(path, columns=None, index_col=None, pandas_metadata=False, **options) -> DataFrame:
629633
"""Load a parquet object from the file path, returning a DataFrame.
630634
631635
Parameters
@@ -636,6 +640,8 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
636640
If not None, only these columns will be read from the file.
637641
index_col : str or list of str, optional, default: None
638642
Index column of table in Spark.
643+
pandas_metadata : bool, default: False
644+
If True, try to respect the metadata if the Parquet file is written from pandas.
639645
options : dict
640646
All other options passed directly into Spark's data source.
641647
@@ -672,6 +678,44 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
672678
if columns is not None:
673679
columns = list(columns)
674680

681+
index_names = None
682+
683+
if index_col is None and pandas_metadata:
684+
if LooseVersion(pyspark.__version__) < LooseVersion("3.0.0"):
685+
raise ValueError("pandas_metadata is not supported with Spark < 3.0.")
686+
687+
# Try to read pandas metadata
688+
689+
@pandas_udf("index_col array<string>, index_names array<string>", PandasUDFType.SCALAR)
690+
def read_index_metadata(pser):
691+
binary = pser.iloc[0]
692+
metadata = pq.ParquetFile(pa.BufferReader(binary)).metadata.metadata
693+
if b"pandas" in metadata:
694+
pandas_metadata = json.loads(metadata[b"pandas"].decode("utf8"))
695+
if all(isinstance(col, str) for col in pandas_metadata["index_columns"]):
696+
index_col = []
697+
index_names = []
698+
for col in pandas_metadata["index_columns"]:
699+
index_col.append(col)
700+
for column in pandas_metadata["columns"]:
701+
if column["field_name"] == col:
702+
index_names.append(column["name"])
703+
break
704+
else:
705+
index_names.append(None)
706+
return pd.DataFrame({"index_col": [index_col], "index_names": [index_names]})
707+
return pd.DataFrame({"index_col": [None], "index_names": [None]})
708+
709+
index_col, index_names = (
710+
default_session()
711+
.read.format("binaryFile")
712+
.load(path)
713+
.limit(1)
714+
.select(read_index_metadata("content").alias("index_metadata"))
715+
.select("index_metadata.*")
716+
.head()
717+
)
718+
675719
kdf = read_spark_io(path=path, format="parquet", options=options, index_col=index_col)
676720

677721
if columns is not None:
@@ -681,7 +725,10 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
681725
else:
682726
sdf = default_session().createDataFrame([], schema=StructType())
683727
index_map = _get_index_map(sdf, index_col)
684-
return DataFrame(InternalFrame(spark_frame=sdf, index_map=index_map))
728+
kdf = DataFrame(InternalFrame(spark_frame=sdf, index_map=index_map))
729+
730+
if index_names is not None:
731+
kdf.index.names = index_names
685732

686733
return kdf
687734

databricks/koalas/tests/test_dataframe_spark_io.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#
1616

1717
from distutils.version import LooseVersion
18+
import unittest
1819

1920
import numpy as np
2021
import pandas as pd
@@ -92,6 +93,33 @@ def check(columns, expected):
9293
expected_idx.sort_values(by="f").to_spark().toPandas(),
9394
)
9495

96+
@unittest.skipIf(
97+
LooseVersion(pyspark.__version__) < LooseVersion("3.0.0"),
98+
"The test only works with Spark>=3.0",
99+
)
100+
def test_parquet_read_with_pandas_metadata(self):
101+
with self.temp_dir() as tmp:
102+
expected1 = self.test_pdf
103+
104+
path1 = "{}/file1.parquet".format(tmp)
105+
expected1.to_parquet(path1)
106+
107+
self.assert_eq(ks.read_parquet(path1, pandas_metadata=True), expected1)
108+
109+
expected2 = expected1.reset_index()
110+
111+
path2 = "{}/file2.parquet".format(tmp)
112+
expected2.to_parquet(path2)
113+
114+
self.assert_eq(ks.read_parquet(path2, pandas_metadata=True), expected2)
115+
116+
expected3 = expected2.set_index("index", append=True)
117+
118+
path3 = "{}/file3.parquet".format(tmp)
119+
expected3.to_parquet(path3)
120+
121+
self.assert_eq(ks.read_parquet(path3, pandas_metadata=True), expected3)
122+
95123
def test_parquet_write(self):
96124
with self.temp_dir() as tmp:
97125
pdf = self.test_pdf

0 commit comments

Comments
 (0)