
Commit 04a8d2c

xinrong-meng authored and ueshin committed
[SPARK-35343][PYTHON] Make the conversion from/to pandas data-type-based for non-ExtensionDtypes
### What changes were proposed in this pull request?

Make the conversion from/to pandas (for non-ExtensionDtype) data-type-based.

NOTE: Ops class per ExtensionDtype and its data-type-based from/to pandas will be implemented in a separate PR as https://issues.apache.org/jira/browse/SPARK-35614.

### Why are the changes needed?

The conversion from/to pandas includes logic for checking data types and behaving accordingly. That makes the code hard to change or maintain. Since we have introduced the Ops class per non-ExtensionDtype data type, we ought to make the conversion from/to pandas data-type-based for non-ExtensionDtypes.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

Closes #32592 from xinrong-databricks/datatypeop_pd_conversion.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
1 parent 6c3b7f9 commit 04a8d2c
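For context, the shape of the change: every Spark data type gets a DataTypeOps subclass, and the generic conversion paths ask that class how to prepare a column (from_pandas) or restore it (to_pandas) instead of branching on dtypes inline. A minimal self-contained sketch of the idea, using simplified stand-in names rather than the real pyspark classes:

```python
import numpy as np
import pandas as pd


# Simplified stand-ins for DataTypeOps and CategoricalOps; illustrative only.
class BaseOpsSketch:
    def prepare(self, col: pd.Series) -> pd.Series:
        # Default from_pandas step: Spark expects None rather than NaN.
        return col.replace({np.nan: None})

    def restore(self, col: pd.Series) -> pd.Series:
        # Default to_pandas step: nothing to undo.
        return col


class CategoricalOpsSketch(BaseOpsSketch):
    def __init__(self, dtype: pd.CategoricalDtype):
        self.dtype = dtype

    def prepare(self, col: pd.Series) -> pd.Series:
        return col.cat.codes  # ship integer codes to Spark

    def restore(self, col: pd.Series) -> pd.Series:
        return pd.Series(
            pd.Categorical.from_codes(
                col, categories=self.dtype.categories, ordered=self.dtype.ordered
            )
        )


dtype = pd.CategoricalDtype(["a", "b", "c"])
ops = CategoricalOpsSketch(dtype)
col = pd.Series(["a", "c", "b"], dtype=dtype)
codes = ops.prepare(col)        # 0, 2, 1: what would cross into Spark
roundtrip = ops.restore(codes)  # categorical values again
assert list(roundtrip) == list(col)
```

The diffs below add exactly this pair of hooks to the real Ops hierarchy.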

File tree

19 files changed: +481 −37 lines changed


dev/sparktestsupport/modules.py

Lines changed: 2 additions & 0 deletions
@@ -615,8 +615,10 @@ def __hash__(self):
         "pyspark.pandas.tests.data_type_ops.test_complex_ops",
         "pyspark.pandas.tests.data_type_ops.test_date_ops",
         "pyspark.pandas.tests.data_type_ops.test_datetime_ops",
+        "pyspark.pandas.tests.data_type_ops.test_null_ops",
         "pyspark.pandas.tests.data_type_ops.test_num_ops",
         "pyspark.pandas.tests.data_type_ops.test_string_ops",
+        "pyspark.pandas.tests.data_type_ops.test_udt_ops",
         "pyspark.pandas.tests.indexes.test_category",
         "pyspark.pandas.tests.plot.test_frame_plot",
         "pyspark.pandas.tests.plot.test_frame_plot_matplotlib",

python/pyspark/pandas/data_type_ops/base.py

Lines changed: 23 additions & 3 deletions
@@ -16,9 +16,11 @@
 #

 import numbers
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta
 from typing import Any, TYPE_CHECKING, Union

+import numpy as np
+import pandas as pd
 from pandas.api.types import CategoricalDtype

 from pyspark.sql.types import (
@@ -30,14 +32,15 @@
     FractionalType,
     IntegralType,
     MapType,
+    NullType,
     NumericType,
     StringType,
     StructType,
     TimestampType,
+    UserDefinedType,
 )

 import pyspark.sql.types as types
-from pyspark.pandas.base import IndexOpsMixin
 from pyspark.pandas.typedef import Dtype

 if TYPE_CHECKING:
@@ -47,6 +50,8 @@

 def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool = True) -> bool:
     """Check whether the operand is valid for arithmetic operations against numerics."""
+    from pyspark.pandas.base import IndexOpsMixin
+
     if isinstance(operand, numbers.Number) and not isinstance(operand, bool):
         return True
     elif isinstance(operand, IndexOpsMixin):
@@ -66,6 +71,8 @@ def transform_boolean_operand_to_numeric(operand: Any, spark_type: types.DataTyp
     Return the transformed operand if the operand is a boolean IndexOpsMixin,
     otherwise return the original operand.
     """
+    from pyspark.pandas.base import IndexOpsMixin
+
     if isinstance(operand, IndexOpsMixin) and isinstance(operand.spark.data_type, BooleanType):
         return operand.spark.transform(lambda scol: scol.cast(spark_type))
     else:
@@ -82,11 +89,13 @@ def __new__(cls, dtype: Dtype, spark_type: DataType):
         from pyspark.pandas.data_type_ops.complex_ops import ArrayOps, MapOps, StructOps
         from pyspark.pandas.data_type_ops.date_ops import DateOps
         from pyspark.pandas.data_type_ops.datetime_ops import DatetimeOps
+        from pyspark.pandas.data_type_ops.null_ops import NullOps
         from pyspark.pandas.data_type_ops.num_ops import (
             IntegralOps,
             FractionalOps,
         )
         from pyspark.pandas.data_type_ops.string_ops import StringOps
+        from pyspark.pandas.data_type_ops.udt_ops import UDTOps

         if isinstance(dtype, CategoricalDtype):
             return object.__new__(CategoricalOps)
@@ -110,6 +119,10 @@ def __new__(cls, dtype: Dtype, spark_type: DataType):
             return object.__new__(MapOps)
         elif isinstance(spark_type, StructType):
             return object.__new__(StructOps)
+        elif isinstance(spark_type, NullType):
+            return object.__new__(NullOps)
+        elif isinstance(spark_type, UserDefinedType):
+            return object.__new__(UDTOps)
         else:
             raise TypeError("Type %s was not understood." % dtype)

@@ -118,7 +131,6 @@ def __init__(self, dtype: Dtype, spark_type: DataType):
         self.spark_type = spark_type

     @property
-    @abstractmethod
     def pretty_name(self) -> str:
         raise NotImplementedError()

@@ -163,3 +175,11 @@ def rmod(self, left, right) -> Union["Series", "Index"]:

     def rpow(self, left, right) -> Union["Series", "Index"]:
         raise TypeError("Exponentiation can not be applied to %s." % self.pretty_name)
+
+    def restore(self, col: pd.Series) -> pd.Series:
+        """Restore column when to_pandas."""
+        return col
+
+    def prepare(self, col: pd.Series) -> pd.Series:
+        """Prepare column when from_pandas."""
+        return col.replace({np.nan: None})
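The restore/prepare pair added above is the fallback every non-ExtensionDtype Ops class inherits: restore is a no-op, and prepare normalizes missing values into the None that Spark expects. A quick plain-pandas illustration of what that replace does:

```python
import numpy as np
import pandas as pd

s = pd.Series(["x", np.nan, "y"])
print(s.replace({np.nan: None}).tolist())  # ['x', None, 'y']
```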

python/pyspark/pandas/data_type_ops/categorical_ops.py

Lines changed: 12 additions & 0 deletions
@@ -15,6 +15,8 @@
 # limitations under the License.
 #

+import pandas as pd
+
 from pyspark.pandas.data_type_ops.base import DataTypeOps


@@ -26,3 +28,13 @@ class CategoricalOps(DataTypeOps):
     @property
     def pretty_name(self) -> str:
         return "categoricals"
+
+    def restore(self, col: pd.Series) -> pd.Series:
+        """Restore column when to_pandas."""
+        return pd.Categorical.from_codes(
+            col, categories=self.dtype.categories, ordered=self.dtype.ordered
+        )
+
+    def prepare(self, col: pd.Series) -> pd.Series:
+        """Prepare column when from_pandas."""
+        return col.cat.codes
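Only the integer codes cross the Spark boundary; the round trip is lossless because the original CategoricalDtype (categories and orderedness) travels in metadata as self.dtype. A plain-pandas illustration of the pair:

```python
import pandas as pd

dtype = pd.CategoricalDtype(categories=["i", "j", "k"], ordered=True)
col = pd.Series(["j", "i", "k"], dtype=dtype)

codes = col.cat.codes  # 1, 0, 2 -- what prepare() hands to Spark
restored = pd.Categorical.from_codes(
    codes, categories=dtype.categories, ordered=dtype.ordered
)  # what restore() rebuilds on the way back
assert list(restored) == list(col)
```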

python/pyspark/pandas/data_type_ops/datetime_ops.py

Lines changed: 4 additions & 0 deletions
@@ -74,3 +74,7 @@ def rsub(self, left, right) -> Union["Series", "Index"]:
             )
         else:
             raise TypeError("datetime subtraction can only be applied to datetime series.")
+
+    def prepare(self, col):
+        """Prepare column when from_pandas."""
+        return col
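This override opts datetime columns out of the inherited NaN-to-None replacement, matching the old prepare_pandas_frame code that skipped datetime64/datetime64tz dtypes: replacing missing values in a datetime64 column can coerce it to object and lose the NaT representation (exact behavior varies across pandas versions). A rough illustration:

```python
import numpy as np
import pandas as pd

s = pd.Series([pd.Timestamp("2021-01-01"), pd.NaT])
print(s.dtype)  # datetime64[ns]
# What the inherited prepare() would do; on many pandas versions this
# upcasts the column to object:
print(s.replace({np.nan: None}).dtype)
# DatetimeOps.prepare instead returns the column unchanged.
```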
python/pyspark/pandas/data_type_ops/null_ops.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.pandas.data_type_ops.base import DataTypeOps
+
+
+class NullOps(DataTypeOps):
+    """
+    The class for binary operations of pandas-on-Spark objects with Spark type: NullType.
+    """
+
+    @property
+    def pretty_name(self) -> str:
+        return "nulls"
python/pyspark/pandas/data_type_ops/udt_ops.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.pandas.data_type_ops.base import DataTypeOps
+
+
+class UDTOps(DataTypeOps):
+    """
+    The class for binary operations of pandas-on-Spark objects with Spark type:
+    UserDefinedType or its subclasses.
+    """
+
+    @property
+    def pretty_name(self) -> str:
+        return "user defined types"

python/pyspark/pandas/internal.py

Lines changed: 37 additions & 30 deletions
@@ -25,12 +25,20 @@

 import numpy as np
 import pandas as pd
-from pandas.api.types import CategoricalDtype, is_datetime64_dtype, is_datetime64tz_dtype
+from pandas.api.types import CategoricalDtype  # noqa: F401
 from pyspark import sql as spark
 from pyspark._globals import _NoValue, _NoValueType
 from pyspark.sql import functions as F, Window
 from pyspark.sql.functions import PandasUDFType, pandas_udf
-from pyspark.sql.types import BooleanType, DataType, StructField, StructType, LongType
+from pyspark.sql.types import (  # noqa: F401
+    BooleanType,
+    DataType,
+    IntegralType,
+    LongType,
+    StructField,
+    StructType,
+    StringType,
+)

 # For running doctests and reference resolution in PyCharm.
 from pyspark import pandas as ps  # noqa: F401
@@ -39,6 +47,7 @@
 # This is required in old Python 3.5 to prevent circular reference.
 from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
 from pyspark.pandas.config import get_option
+from pyspark.pandas.data_type_ops.base import DataTypeOps
 from pyspark.pandas.typedef import (
     Dtype,
     as_spark_type,
@@ -951,11 +960,11 @@ def arguments_for_restore_index(self) -> Dict:
             for col, dtype in zip(self.index_spark_column_names, self.index_dtypes)
             if isinstance(dtype, extension_dtypes)
         }
-        categorical_dtypes = {
-            col: dtype
-            for col, dtype in zip(self.index_spark_column_names, self.index_dtypes)
-            if isinstance(dtype, CategoricalDtype)
-        }
+        dtypes = [dtype for dtype in self.index_dtypes]
+        spark_types = [
+            self.spark_frame.select(scol).schema[0].dataType for scol in self.index_spark_columns
+        ]
+
         for spark_column, column_name, dtype in zip(
             self.data_spark_columns, self.data_spark_column_names, self.data_dtypes
         ):
@@ -969,8 +978,8 @@ def arguments_for_restore_index(self) -> Dict:
             column_names.append(column_name)
             if isinstance(dtype, extension_dtypes):
                 ext_dtypes[column_name] = dtype
-            elif isinstance(dtype, CategoricalDtype):
-                categorical_dtypes[column_name] = dtype
+            dtypes.append(dtype)
+            spark_types.append(self.spark_frame.select(spark_column).schema[0].dataType)

         return dict(
             index_columns=self.index_spark_column_names,
@@ -979,7 +988,8 @@ def arguments_for_restore_index(self) -> Dict:
             column_labels=self.column_labels,
             column_label_names=self.column_label_names,
             ext_dtypes=ext_dtypes,
-            categorical_dtypes=categorical_dtypes,
+            dtypes=dtypes,
+            spark_types=spark_types,
         )

     @staticmethod
@@ -991,8 +1001,9 @@ def restore_index(
         data_columns: List[str],
         column_labels: List[Tuple],
         column_label_names: List[Tuple],
+        dtypes: List[Dtype],
+        spark_types: List[DataType],
         ext_dtypes: Dict[str, Dtype] = None,
-        categorical_dtypes: Dict[str, CategoricalDtype] = None
     ) -> pd.DataFrame:
         """
         Restore pandas DataFrame indices using the metadata.
@@ -1003,10 +1014,12 @@ def restore_index(
         :param data_columns: the original column names for data columns.
         :param column_labels: the column labels after restored.
         :param column_label_names: the column label names after restored.
+        :param dtypes: the dtypes after restored.
+        :param spark_types: the spark_types.
         :param ext_dtypes: the map from the original column names to extension data types.
-        :param categorical_dtypes: the map from the original column names to categorical types.
         :return: the restored pandas DataFrame

+        >>> from numpy import dtype
         >>> pdf = pd.DataFrame({"index": [10, 20, 30], "a": ['a', 'b', 'c'], "b": [0, 2, 1]})
         >>> InternalFrame.restore_index(
         ...     pdf,
@@ -1015,8 +1028,9 @@ def restore_index(
         ...     data_columns=["a", "b", "index"],
         ...     column_labels=[("x",), ("y",), ("z",)],
         ...     column_label_names=[("lv1",)],
-        ...     ext_dtypes=None,
-        ...     categorical_dtypes={"b": CategoricalDtype(categories=["i", "j", "k"])}
+        ...     dtypes=[dtype('int64'), dtype('object'),
+        ...             CategoricalDtype(categories=["i", "j", "k"]), dtype('int64')],
+        ...     spark_types=[LongType(), StringType(), StringType(), LongType()]
         ... )  # doctest: +NORMALIZE_WHITESPACE
         lv1  x  y  z
         idx
@@ -1027,11 +1041,8 @@ def restore_index(
         if ext_dtypes is not None and len(ext_dtypes) > 0:
             pdf = pdf.astype(ext_dtypes, copy=True)

-        if categorical_dtypes is not None:
-            for col, dtype in categorical_dtypes.items():
-                pdf[col] = pd.Categorical.from_codes(
-                    pdf[col], categories=dtype.categories, ordered=dtype.ordered
-                )
+        for col, expected_dtype, spark_type in zip(pdf.columns, dtypes, spark_types):
+            pdf[col] = DataTypeOps(expected_dtype, spark_type).restore(pdf[col])

         append = False
         for index_field in index_columns:
@@ -1071,7 +1082,7 @@ def with_new_sdf(
         *,
         index_dtypes: Optional[List[Dtype]] = None,
         data_columns: Optional[List[str]] = None,
-        data_dtypes: Optional[List[Dtype]] = None
+        data_dtypes: Optional[List[Dtype]] = None,
     ) -> "InternalFrame":
         """Copy the immutable InternalFrame with the updates by the specified Spark DataFrame.

@@ -1121,7 +1132,7 @@ def with_new_columns(
         column_labels: Optional[List[Tuple]] = None,
         data_dtypes: Optional[List[Dtype]] = None,
         column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue,
-        keep_order: bool = True
+        keep_order: bool = True,
     ) -> "InternalFrame":
         """
         Copy the immutable InternalFrame with the updates by the specified Spark Columns or Series.
@@ -1225,7 +1236,7 @@ def with_new_spark_column(
         scol: spark.Column,
         *,
         dtype: Optional[Dtype] = None,
-        keep_order: bool = True
+        keep_order: bool = True,
     ) -> "InternalFrame":
         """
         Copy the immutable InternalFrame with the updates by the specified Spark Column.
@@ -1273,7 +1284,7 @@ def copy(
         column_labels: Union[Optional[List[Tuple]], _NoValueType] = _NoValue,
         data_spark_columns: Union[Optional[List[spark.Column]], _NoValueType] = _NoValue,
         data_dtypes: Union[Optional[List[Dtype]], _NoValueType] = _NoValue,
-        column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue
+        column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue,
     ) -> "InternalFrame":
         """Copy the immutable InternalFrame.

@@ -1423,13 +1434,9 @@ def prepare_pandas_frame(
         index_dtypes = list(reset_index.dtypes)[:index_nlevels]
         data_dtypes = list(reset_index.dtypes)[index_nlevels:]

-        for name, col in reset_index.iteritems():
-            dt = col.dtype
-            if is_datetime64_dtype(dt) or is_datetime64tz_dtype(dt):
-                continue
-            elif isinstance(dt, CategoricalDtype):
-                col = col.cat.codes
-            reset_index[name] = col.replace({np.nan: None})
+        for col, dtype in zip(reset_index.columns, reset_index.dtypes):
+            spark_type = infer_pd_series_spark_type(reset_index[col], dtype)
+            reset_index[col] = DataTypeOps(dtype, spark_type).prepare(reset_index[col])

         return reset_index, index_columns, index_dtypes, data_columns, data_dtypes

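The rewritten prepare_pandas_frame loop relies on reset_index.columns and reset_index.dtypes lining up positionally: a Spark type is inferred per column, and the DataTypeOps dispatch picks the right prepare. A condensed sketch of that flow with stand-in dispatch logic (the real code goes through infer_pd_series_spark_type and the Ops classes):

```python
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype, is_datetime64_dtype


def prepare(col: pd.Series, dtype) -> pd.Series:
    """Stand-in for DataTypeOps(dtype, spark_type).prepare(col); illustrative only."""
    if isinstance(dtype, CategoricalDtype):
        return col.cat.codes  # CategoricalOps.prepare
    if is_datetime64_dtype(dtype):
        return col  # DatetimeOps.prepare
    return col.replace({np.nan: None})  # DataTypeOps.prepare (default)


pdf = pd.DataFrame(
    {
        "a": [1.0, np.nan],
        "b": pd.Categorical(["x", "y"]),
        "c": pd.to_datetime(["2021-01-01", None]),
    }
)
for col, dtype in zip(pdf.columns, pdf.dtypes):
    pdf[col] = prepare(pdf[col], dtype)
```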
python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py

Lines changed: 7 additions & 0 deletions
@@ -122,6 +122,13 @@ def test_rpow(self):
         self.assertRaises(TypeError, lambda: "x" ** self.psser)
         self.assertRaises(TypeError, lambda: 1 ** self.psser)

+    def test_from_to_pandas(self):
+        data = [b"1", b"2", b"3"]
+        pser = pd.Series(data)
+        psser = ps.Series(data)
+        self.assert_eq(pser, psser.to_pandas())
+        self.assert_eq(ps.from_pandas(pser), psser)
+

 if __name__ == "__main__":
     import unittest

python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py

Lines changed: 7 additions & 0 deletions
@@ -229,6 +229,13 @@ def test_rmod(self):
         self.assertRaises(TypeError, lambda: datetime.date(1994, 1, 1) % self.psser)
         self.assertRaises(TypeError, lambda: True % self.psser)

+    def test_from_to_pandas(self):
+        data = [True, True, False]
+        pser = pd.Series(data)
+        psser = ps.Series(data)
+        self.assert_eq(pser, psser.to_pandas())
+        self.assert_eq(ps.from_pandas(pser), psser)
+

 if __name__ == "__main__":
     import unittest
