Commit e1908ec

Add spark namespace in DataFrame, Series, Index and MultiIndex
1 parent 8530f1b · commit e1908ec

22 files changed: +1176 −756 lines
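At a glance, the change moves Spark-specific attributes of Series/Index (and, per the commit title, DataFrame and MultiIndex) under a dedicated `spark` accessor, keeping the old flat attributes as deprecated shims. A minimal usage sketch, based only on names that appear in this diff and assuming a working Spark session with koalas at this commit; the accessor class itself, SparkIndexOpsMethods, lives in databricks/koalas/spark.py and is not part of this excerpt:

import databricks.koalas as ks

kser = ks.Series([1, 2, 3], name="x")

scol = kser.spark.column   # underlying pyspark Column (was kser.spark_column)
stype = kser.spark.type    # underlying Spark DataType (used as self.spark.type throughout the diff)

# The old attribute still works in this commit but now emits a FutureWarning
# pointing at the new accessor:
old = kser.spark_column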

databricks/koalas/base.py

Lines changed: 67 additions & 65 deletions
@@ -20,12 +20,13 @@
 from collections import OrderedDict
 from functools import wraps, partial
 from typing import Union, Callable, Any
+import warnings

 import numpy as np
 import pandas as pd
 from pandas.api.types import is_list_like
-from pyspark import sql as spark
-from pyspark.sql import functions as F, Window
+from pandas.core.accessor import CachedAccessor
+from pyspark.sql import functions as F, Window, Column
 from pyspark.sql.types import DateType, DoubleType, FloatType, LongType, StringType, TimestampType

 from databricks import koalas as ks  # For running doctests and reference resolution in PyCharm.
@@ -35,6 +36,7 @@
     NATURAL_ORDER_COLUMN_NAME,
     SPARK_DEFAULT_INDEX_NAME,
 )
+from databricks.koalas.spark import SparkIndexOpsMethods
 from databricks.koalas.typedef import spark_type_to_pandas_dtype
 from databricks.koalas.utils import align_diff_series, same_anchor, scol_for, validate_axis
 from databricks.koalas.frame import DataFrame
@@ -45,19 +47,19 @@ def booleanize_null(left_scol, scol, f):
     Booleanize Null in Spark Column
     """
     comp_ops = [
-        getattr(spark.Column, "__{}__".format(comp_op))
+        getattr(Column, "__{}__".format(comp_op))
         for comp_op in ["eq", "ne", "lt", "le", "ge", "gt"]
     ]

     if f in comp_ops:
         # if `f` is "!=", fill null with True otherwise False
-        filler = f == spark.Column.__ne__
+        filler = f == Column.__ne__
         scol = F.when(scol.isNull(), filler).otherwise(scol)

-    elif f == spark.Column.__or__:
+    elif f == Column.__or__:
         scol = F.when(left_scol.isNull() | scol.isNull(), False).otherwise(scol)

-    elif f == spark.Column.__and__:
+    elif f == Column.__and__:
         scol = F.when(scol.isNull(), False).otherwise(scol)

     return scol
@@ -83,9 +85,9 @@ def wrapper(self, *args):
         cols = [arg for arg in args if isinstance(arg, IndexOpsMixin)]
         if all(same_anchor(self, col) for col in cols):
             # Same DataFrame anchors
-            args = [arg.spark_column if isinstance(arg, IndexOpsMixin) else arg for arg in args]
-            scol = f(self.spark_column, *args)
-            scol = booleanize_null(self.spark_column, scol, f)
+            args = [arg.spark.column if isinstance(arg, IndexOpsMixin) else arg for arg in args]
+            scol = f(self.spark.column, *args)
+            scol = booleanize_null(self.spark.column, scol, f)

             return self._with_new_scol(scol)
         else:
@@ -107,7 +109,7 @@ def wrapper(self, *args):
         new_args = []
         for arg in args:
             # TODO: This is a quick hack to support NumPy type. We should revisit this.
-            if isinstance(self.spark_type, LongType) and isinstance(arg, np.timedelta64):
+            if isinstance(self.spark.type, LongType) and isinstance(arg, np.timedelta64):
                 new_args.append(float(arg / np.timedelta64(1, "s")))
             else:
                 new_args.append(arg)
@@ -121,13 +123,10 @@ class IndexOpsMixin(object):

     Assuming there are following attributes or properties and function.

-    :ivar _scol: Spark Column instance
-    :type _scol: pyspark.Column
     :ivar _kdf: Parent's Koalas DataFrame
     :type _kdf: ks.DataFrame
-
-    :ivar spark_type: Spark data type
-    :type spark_type: spark.types.DataType
+    :ivar spark: Spark-related features
+    :type spark: SparkIndexOpsMethods
     """

     def __init__(self, internal: InternalFrame, kdf):
@@ -136,47 +135,50 @@ def __init__(self, internal: InternalFrame, kdf):
         self._internal = internal  # type: InternalFrame
         self._kdf = kdf

+    spark = CachedAccessor("spark", SparkIndexOpsMethods)
+
     @property
     def spark_column(self):
-        """
-        Spark Column object representing the Series/Index.
+        warnings.warn(
+            "Series.spark_column is deprecated as of Series.spark.column. "
+            "Please use the API instead.",
+            FutureWarning,
+        )
+        return self.spark.column

-        .. note:: This Spark Column object is strictly stick to its base DataFrame the Series/Index
-            was derived from.
-        """
-        return self._internal.spark_column
+    spark_column.__doc__ = SparkIndexOpsMethods.column.__doc__

     # arithmetic operators
-    __neg__ = column_op(spark.Column.__neg__)
+    __neg__ = column_op(Column.__neg__)

     def __add__(self, other):
-        if isinstance(self.spark_type, StringType):
+        if isinstance(self.spark.type, StringType):
             # Concatenate string columns
-            if isinstance(other, IndexOpsMixin) and isinstance(other.spark_type, StringType):
+            if isinstance(other, IndexOpsMixin) and isinstance(other.spark.type, StringType):
                 return column_op(F.concat)(self, other)
             # Handle df['col'] + 'literal'
             elif isinstance(other, str):
                 return column_op(F.concat)(self, F.lit(other))
             else:
                 raise TypeError("string addition can only be applied to string series or literals.")
         else:
-            return column_op(spark.Column.__add__)(self, other)
+            return column_op(Column.__add__)(self, other)

     def __sub__(self, other):
         # Note that timestamp subtraction casts arguments to integer. This is to mimic Pandas's
         # behaviors. Pandas returns 'timedelta64[ns]' from 'datetime64[ns]'s subtraction.
-        if isinstance(other, IndexOpsMixin) and isinstance(self.spark_type, TimestampType):
-            if not isinstance(other.spark_type, TimestampType):
+        if isinstance(other, IndexOpsMixin) and isinstance(self.spark.type, TimestampType):
+            if not isinstance(other.spark.type, TimestampType):
                 raise TypeError("datetime subtraction can only be applied to datetime series.")
             return self.astype("bigint") - other.astype("bigint")
-        elif isinstance(other, IndexOpsMixin) and isinstance(self.spark_type, DateType):
-            if not isinstance(other.spark_type, DateType):
+        elif isinstance(other, IndexOpsMixin) and isinstance(self.spark.type, DateType):
+            if not isinstance(other.spark.type, DateType):
                 raise TypeError("date subtraction can only be applied to date series.")
             return column_op(F.datediff)(self, other)
         else:
-            return column_op(spark.Column.__sub__)(self, other)
+            return column_op(Column.__sub__)(self, other)

-    __mul__ = column_op(spark.Column.__mul__)
+    __mul__ = column_op(Column.__mul__)

     def __truediv__(self, other):
         """
@@ -213,13 +215,13 @@ def mod(left, right):

     def __radd__(self, other):
         # Handle 'literal' + df['col']
-        if isinstance(self.spark_type, StringType) and isinstance(other, str):
-            return self._with_new_scol(F.concat(F.lit(other), self.spark_column))
+        if isinstance(self.spark.type, StringType) and isinstance(other, str):
+            return self._with_new_scol(F.concat(F.lit(other), self.spark.column))
         else:
-            return column_op(spark.Column.__radd__)(self, other)
+            return column_op(Column.__radd__)(self, other)

-    __rsub__ = column_op(spark.Column.__rsub__)
-    __rmul__ = column_op(spark.Column.__rmul__)
+    __rsub__ = column_op(Column.__rsub__)
+    __rmul__ = column_op(Column.__rmul__)

     def __rtruediv__(self, other):
         def rtruediv(left, right):
@@ -274,24 +276,24 @@ def rmod(left, right):

         return column_op(rmod)(self, other)

-    __pow__ = column_op(spark.Column.__pow__)
-    __rpow__ = column_op(spark.Column.__rpow__)
+    __pow__ = column_op(Column.__pow__)
+    __rpow__ = column_op(Column.__rpow__)

     # comparison operators
-    __eq__ = column_op(spark.Column.__eq__)
-    __ne__ = column_op(spark.Column.__ne__)
-    __lt__ = column_op(spark.Column.__lt__)
-    __le__ = column_op(spark.Column.__le__)
-    __ge__ = column_op(spark.Column.__ge__)
-    __gt__ = column_op(spark.Column.__gt__)
+    __eq__ = column_op(Column.__eq__)
+    __ne__ = column_op(Column.__ne__)
+    __lt__ = column_op(Column.__lt__)
+    __le__ = column_op(Column.__le__)
+    __ge__ = column_op(Column.__ge__)
+    __gt__ = column_op(Column.__gt__)

     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators
-    __and__ = column_op(spark.Column.__and__)
-    __or__ = column_op(spark.Column.__or__)
-    __invert__ = column_op(spark.Column.__invert__)
-    __rand__ = column_op(spark.Column.__rand__)
-    __ror__ = column_op(spark.Column.__ror__)
+    __and__ = column_op(Column.__and__)
+    __or__ = column_op(Column.__or__)
+    __invert__ = column_op(Column.__invert__)
+    __rand__ = column_op(Column.__rand__)
+    __ror__ = column_op(Column.__ror__)

     # NDArray Compat
     def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any):
@@ -333,7 +335,7 @@ def dtype(self):
         >>> s.rename("a").to_frame().set_index("a").index.dtype
         dtype('<M8[ns]')
         """
-        return spark_type_to_pandas_dtype(self.spark_type)
+        return spark_type_to_pandas_dtype(self.spark.type)

     @property
     def empty(self):
@@ -371,8 +373,8 @@ def hasnans(self):
         >>> ks.Series([1, 2, 3]).rename("a").to_frame().set_index("a").index.hasnans
         False
         """
-        sdf = self._internal._sdf.select(self.spark_column)
-        col = self.spark_column
+        sdf = self._internal._sdf.select(self.spark.column)
+        col = self.spark.column

         ret = sdf.select(F.max(col.isNull() | F.isnan(col))).collect()[0][0]
         return ret
@@ -552,7 +554,7 @@ def _is_monotonic(self, order):
                     "__partition_id"
                 ),  # Make sure we use the same partition id in the whole job.
                 F.col(NATURAL_ORDER_COLUMN_NAME),
-                self.spark_column.alias("__origin"),
+                self.spark.column.alias("__origin"),
             )
             .select(
                 F.col("__partition_id"),
@@ -670,7 +672,7 @@ def astype(self, dtype):
         spark_type = as_spark_type(dtype)
         if not spark_type:
             raise ValueError("Type {} not understood".format(dtype))
-        return self._with_new_scol(self.spark_column.cast(spark_type))
+        return self._with_new_scol(self.spark.column.cast(spark_type))

     def isin(self, values):
         """
@@ -722,7 +724,7 @@ def isin(self, values):
                 " to isin(), you passed a [{values_type}]".format(values_type=type(values).__name__)
             )

-        return self._with_new_scol(self.spark_column.isin(list(values))).rename(self.name)
+        return self._with_new_scol(self.spark.column.isin(list(values))).rename(self.name)

     def isnull(self):
         """
@@ -757,10 +759,10 @@ def isnull(self):
             raise NotImplementedError("isna is not defined for MultiIndex")
         if isinstance(self.spark_type, (FloatType, DoubleType)):
             return self._with_new_scol(
-                self.spark_column.isNull() | F.isnan(self.spark_column)
+                self.spark.column.isNull() | F.isnan(self.spark.column)
             ).rename(self.name)
         else:
-            return self._with_new_scol(self.spark_column.isNull()).rename(self.name)
+            return self._with_new_scol(self.spark.column.isNull()).rename(self.name)

     isna = isnull
@@ -856,7 +858,7 @@ def all(self, axis: Union[int, str] = 0) -> bool:
         if axis != 0:
             raise NotImplementedError('axis should be either 0 or "index" currently.')

-        sdf = self._internal._sdf.select(self.spark_column)
+        sdf = self._internal._sdf.select(self.spark.column)
         col = scol_for(sdf, sdf.columns[0])

         # Note that we're ignoring `None`s here for now.
@@ -919,7 +921,7 @@ def any(self, axis: Union[int, str] = 0) -> bool:
         if axis != 0:
             raise NotImplementedError('axis should be either 0 or "index" currently.')

-        sdf = self._internal._sdf.select(self.spark_column)
+        sdf = self._internal._sdf.select(self.spark.column)
         col = scol_for(sdf, sdf.columns[0])

         # Note that we're ignoring `None`s here for now.
@@ -986,7 +988,7 @@ def _shift(self, periods, fill_value, part_cols=()):
         if not isinstance(periods, int):
             raise ValueError("periods should be an int; however, got [%s]" % type(periods))

-        col = self.spark_column
+        col = self.spark.column
         window = (
             Window.partitionBy(*part_cols)
             .orderBy(NATURAL_ORDER_COLUMN_NAME)
@@ -1152,9 +1154,9 @@ def value_counts(self, normalize=False, sort=True, ascending=False, bins=None, d
             raise NotImplementedError("value_counts currently does not support bins")

         if dropna:
-            sdf_dropna = self._internal._sdf.select(self.spark_column).dropna()
+            sdf_dropna = self._internal._sdf.select(self.spark.column).dropna()
         else:
-            sdf_dropna = self._internal._sdf.select(self.spark_column)
+            sdf_dropna = self._internal._sdf.select(self.spark.column)
         index_name = SPARK_DEFAULT_INDEX_NAME
         column_name = self._internal.data_spark_column_names[0]
         sdf = sdf_dropna.groupby(scol_for(sdf_dropna, column_name).alias(index_name)).count()
@@ -1244,12 +1246,12 @@ def _nunique(self, dropna=True, approx=False, rsd=0.05):
         colname = self._internal.data_spark_column_names[0]
         count_fn = partial(F.approx_count_distinct, rsd=rsd) if approx else F.countDistinct
         if dropna:
-            return count_fn(self.spark_column).alias(colname)
+            return count_fn(self.spark.column).alias(colname)
         else:
             return (
-                count_fn(self.spark_column)
+                count_fn(self.spark.column)
                 + F.when(
-                    F.count(F.when(self.spark_column.isNull(), 1).otherwise(None)) >= 1, 1
+                    F.count(F.when(self.spark.column.isNull(), 1).otherwise(None)) >= 1, 1
                 ).otherwise(0)
             ).alias(colname)

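The core of the change is the one-line registration in IndexOpsMixin: spark = CachedAccessor("spark", SparkIndexOpsMethods). CachedAccessor is pandas' lazy-accessor descriptor: on first attribute access it instantiates the accessor class with the owning object and caches that instance on the object. A minimal, self-contained sketch of the mechanism follows; the names DemoSparkMethods and DemoSeries are made up for illustration and are not koalas APIs, and the real SparkIndexOpsMethods body is in databricks/koalas/spark.py, outside this excerpt.

from pandas.core.accessor import CachedAccessor


class DemoSparkMethods:
    """Stands in for SparkIndexOpsMethods: a thin namespace over its owner."""

    def __init__(self, data):
        # CachedAccessor passes the owning object here on first access.
        self._data = data

    @property
    def column(self):
        return self._data._column  # whatever the owner stores internally


class DemoSeries:
    # Non-data descriptor: the first `obj.spark` lookup builds
    # DemoSparkMethods(obj) and caches it on the instance, so the accessor
    # is constructed at most once per object.
    spark = CachedAccessor("spark", DemoSparkMethods)

    def __init__(self, column):
        self._column = column


s = DemoSeries(column="a pyspark Column would live here")
assert s.spark.column == "a pyspark Column would live here"
assert s.spark is s.spark  # cached: same accessor instance on repeated access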
databricks/koalas/datetimes.py

Lines changed: 2 additions & 2 deletions
@@ -34,8 +34,8 @@ class DatetimeMethods(object):
     """Date/Time methods for Koalas Series"""

     def __init__(self, series: "ks.Series"):
-        if not isinstance(series.spark_type, (DateType, TimestampType)):
-            raise ValueError("Cannot call DatetimeMethods on type {}".format(series.spark_type))
+        if not isinstance(series.spark.type, (DateType, TimestampType)):
+            raise ValueError("Cannot call DatetimeMethods on type {}".format(series.spark.type))
         self._data = series

     # Properties

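The same accessor is what datetimes.py now consults for its type check. A short sketch of the resulting behaviour, assuming the `dt` accessor constructs DatetimeMethods lazily on first access (as pandas-style accessors do) and a working Spark session; the error text is the one in the diff above.

import pandas as pd
import databricks.koalas as ks

s_ts = ks.from_pandas(pd.Series(pd.to_datetime(["2020-01-01", "2020-01-02"])))
print(s_ts.dt.year)   # fine: s_ts.spark.type is a TimestampType

s_int = ks.Series([1, 2, 3])
try:
    s_int.dt          # DatetimeMethods now checks series.spark.type
except ValueError as err:
    print(err)        # "Cannot call DatetimeMethods on type LongType"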