
Commit 9d97ffc

Expose spark_column from Series/Index to make it easier to work with Spark Columns. (#1438)
This PR exposes the `spark_column` property on Series/Index so that users who are familiar with Spark functions can work with the underlying Spark Columns directly. E.g.:

```py
>>> kdf = ks.DataFrame({'a': [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], 'b': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]})
>>> from pyspark.sql import functions as F
>>> kdf['greatest'] = F.greatest(kdf.a.spark_column, kdf.b.spark_column)
>>> kdf['least'] = F.least(kdf.a.spark_column, kdf.b.spark_column)
>>> kdf
     a    b  greatest  least
0  1.0  1.0       1.0    1.0
1  1.0  2.0       2.0    1.0
2  1.0  3.0       3.0    1.0
3  2.0  4.0       4.0    2.0
4  2.0  5.0       5.0    2.0
5  2.0  6.0       6.0    2.0
```
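The same pattern works with any Spark function that accepts Columns; a minimal sketch, assuming the same `kdf` and `F` import as above (the `abs_diff` column name is purely illustrative):

```py
>>> # 'abs_diff' is a hypothetical column name chosen for this sketch
>>> kdf['abs_diff'] = F.abs(kdf.a.spark_column - kdf.b.spark_column)
```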
1 parent 6cb1c02 commit 9d97ffc


13 files changed: +147 −114 lines


databricks/koalas/base.py

Lines changed: 30 additions & 22 deletions
```diff
@@ -83,9 +83,9 @@ def wrapper(self, *args):
         cols = [arg for arg in args if isinstance(arg, IndexOpsMixin)]
         if all(self._kdf is col._kdf for col in cols):
             # Same DataFrame anchors
-            args = [arg._scol if isinstance(arg, IndexOpsMixin) else arg for arg in args]
-            scol = f(self._scol, *args)
-            scol = booleanize_null(self._scol, scol, f)
+            args = [arg.spark_column if isinstance(arg, IndexOpsMixin) else arg for arg in args]
+            scol = f(self.spark_column, *args)
+            scol = booleanize_null(self.spark_column, scol, f)
 
             return self._with_new_scol(scol)
         else:
@@ -154,7 +154,13 @@ def __init__(self, internal: _InternalFrame, kdf):
         self._kdf = kdf
 
     @property
-    def _scol(self):
+    def spark_column(self):
+        """
+        Spark Column object representing the Series/Index.
+
+        .. note:: This Spark Column object is strictly stick to its base DataFrame the Series/Index
+            was derived from.
+        """
         return self._internal.spark_column
 
     # arithmetic operators
@@ -202,7 +208,7 @@ def mod(left, right):
     def __radd__(self, other):
         # Handle 'literal' + df['col']
         if isinstance(self.spark_type, StringType) and isinstance(other, str):
-            return self._with_new_scol(F.concat(F.lit(other), self._scol))
+            return self._with_new_scol(F.concat(F.lit(other), self.spark_column))
         else:
             return _column_op(spark.Column.__radd__)(self, other)
@@ -336,8 +342,8 @@ def hasnans(self):
         >>> ks.Series([1, 2, 3]).rename("a").to_frame().set_index("a").index.hasnans
         False
         """
-        sdf = self._internal._sdf.select(self._scol)
-        col = self._scol
+        sdf = self._internal._sdf.select(self.spark_column)
+        col = self.spark_column
 
         ret = sdf.select(F.max(col.isNull() | F.isnan(col))).collect()[0][0]
         return ret
@@ -517,7 +523,7 @@ def _is_monotonic(self, order):
                     "__partition_id"
                 ),  # Make sure we use the same partition id in the whole job.
                 F.col(NATURAL_ORDER_COLUMN_NAME),
-                self._scol.alias("__origin"),
+                self.spark_column.alias("__origin"),
             )
             .select(
                 F.col("__partition_id"),
@@ -635,7 +641,7 @@ def astype(self, dtype):
         spark_type = as_spark_type(dtype)
         if not spark_type:
             raise ValueError("Type {} not understood".format(dtype))
-        return self._with_new_scol(self._scol.cast(spark_type))
+        return self._with_new_scol(self.spark_column.cast(spark_type))
 
     def isin(self, values):
         """
@@ -687,7 +693,7 @@ def isin(self, values):
                 " to isin(), you passed a [{values_type}]".format(values_type=type(values).__name__)
             )
 
-        return self._with_new_scol(self._scol.isin(list(values))).rename(self.name)
+        return self._with_new_scol(self.spark_column.isin(list(values))).rename(self.name)
 
     def isnull(self):
         """
@@ -721,9 +727,11 @@ def isnull(self):
         if isinstance(self, MultiIndex):
             raise NotImplementedError("isna is not defined for MultiIndex")
         if isinstance(self.spark_type, (FloatType, DoubleType)):
-            return self._with_new_scol(self._scol.isNull() | F.isnan(self._scol)).rename(self.name)
+            return self._with_new_scol(
+                self.spark_column.isNull() | F.isnan(self.spark_column)
+            ).rename(self.name)
         else:
-            return self._with_new_scol(self._scol.isNull()).rename(self.name)
+            return self._with_new_scol(self.spark_column.isNull()).rename(self.name)
 
     isna = isnull
 
@@ -819,7 +827,7 @@ def all(self, axis: Union[int, str] = 0) -> bool:
         if axis != 0:
             raise NotImplementedError('axis should be either 0 or "index" currently.')
 
-        sdf = self._internal._sdf.select(self._scol)
+        sdf = self._internal._sdf.select(self.spark_column)
         col = scol_for(sdf, sdf.columns[0])
 
         # Note that we're ignoring `None`s here for now.
@@ -882,7 +890,7 @@ def any(self, axis: Union[int, str] = 0) -> bool:
         if axis != 0:
             raise NotImplementedError('axis should be either 0 or "index" currently.')
 
-        sdf = self._internal._sdf.select(self._scol)
+        sdf = self._internal._sdf.select(self.spark_column)
         col = scol_for(sdf, sdf.columns[0])
 
         # Note that we're ignoring `None`s here for now.
@@ -949,7 +957,7 @@ def _shift(self, periods, fill_value, part_cols=()):
         if not isinstance(periods, int):
             raise ValueError("periods should be an int; however, got [%s]" % type(periods))
 
-        col = self._scol
+        col = self.spark_column
         window = (
             Window.partitionBy(*part_cols)
             .orderBy(NATURAL_ORDER_COLUMN_NAME)
@@ -1115,9 +1123,9 @@ def value_counts(self, normalize=False, sort=True, ascending=False, bins=None, d
             raise NotImplementedError("value_counts currently does not support bins")
 
         if dropna:
-            sdf_dropna = self._internal._sdf.select(self._scol).dropna()
+            sdf_dropna = self._internal._sdf.select(self.spark_column).dropna()
         else:
-            sdf_dropna = self._internal._sdf.select(self._scol)
+            sdf_dropna = self._internal._sdf.select(self.spark_column)
         index_name = SPARK_DEFAULT_INDEX_NAME
         column_name = self._internal.data_spark_column_names[0]
         sdf = sdf_dropna.groupby(scol_for(sdf_dropna, column_name).alias(index_name)).count()
@@ -1207,13 +1215,13 @@ def _nunique(self, dropna=True, approx=False, rsd=0.05):
         colname = self._internal.data_spark_column_names[0]
         count_fn = partial(F.approx_count_distinct, rsd=rsd) if approx else F.countDistinct
         if dropna:
-            return count_fn(self._scol).alias(colname)
+            return count_fn(self.spark_column).alias(colname)
         else:
             return (
-                count_fn(self._scol)
-                + F.when(F.count(F.when(self._scol.isNull(), 1).otherwise(None)) >= 1, 1).otherwise(
-                    0
-                )
+                count_fn(self.spark_column)
+                + F.when(
+                    F.count(F.when(self.spark_column.isNull(), 1).otherwise(None)) >= 1, 1
+                ).otherwise(0)
             ).alias(colname)
 
     def take(self, indices):
```
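The hunks above also introduce the public `spark_column` property itself. A minimal usage sketch, assuming Koalas is imported as `ks` with a running Spark session (the `log_b` column name is illustrative); per the docstring note, the returned Column stays bound to the DataFrame the Series/Index was derived from, so it should only be combined with Columns from that same DataFrame:

```py
import databricks.koalas as ks
from pyspark.sql import functions as F

kdf = ks.DataFrame({'a': [1.0, 2.0, 3.0], 'b': [10.0, 100.0, 1000.0]})

# Series.spark_column returns the underlying pyspark.sql.Column, so plain
# Spark SQL functions can be mixed into a Koalas workflow.
kdf['log_b'] = F.log10(kdf.b.spark_column)  # 'log_b' is an illustrative name

# The Column is bound to its anchor DataFrame; combine it only with Columns
# derived from the same kdf, as the docstring note above points out.
```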

databricks/koalas/frame.py

Lines changed: 17 additions & 10 deletions
```diff
@@ -2858,10 +2858,10 @@ def where(self, cond, other=np.nan):
         for label in self._internal.column_labels:
             data_spark_columns.append(
                 F.when(
-                    kdf[tmp_cond_col_name(name_like_string(label))]._scol,
+                    kdf[tmp_cond_col_name(name_like_string(label))].spark_column,
                     kdf._internal.spark_column_for(label),
                 )
-                .otherwise(kdf[tmp_other_col_name(name_like_string(label))]._scol)
+                .otherwise(kdf[tmp_other_col_name(name_like_string(label))].spark_column)
                 .alias(kdf._internal.spark_column_name_for(label))
             )
 
@@ -3715,7 +3715,7 @@ def round(self, decimals=0):
         def op(kser):
             label = kser._internal.column_labels[0]
             if label in decimals:
-                return F.round(kser._scol, decimals[label]).alias(
+                return F.round(kser.spark_column, decimals[label]).alias(
                     kser._internal.data_spark_column_names[0]
                 )
             else:
@@ -4541,7 +4541,11 @@ def _assign(self, kwargs):
 
         pairs = {
             (k if isinstance(k, tuple) else (k,)): (
-                v._scol if isinstance(v, Series) else v if isinstance(v, spark.Column) else F.lit(v)
+                v.spark_column
+                if isinstance(v, Series)
+                else v
+                if isinstance(v, spark.Column)
+                else F.lit(v)
             )
             for k, v in kwargs.items()
         }
@@ -4842,7 +4846,10 @@ def dropna(self, axis=0, how="any", thresh=None, subset=None, inplace=False):
 
         cnt = reduce(
             lambda x, y: x + y,
-            [F.when(self._kser_for(label).notna()._scol, 1).otherwise(0) for label in labels],
+            [
+                F.when(self._kser_for(label).notna().spark_column, 1).otherwise(0)
+                for label in labels
+            ],
             F.lit(0),
         )
         if thresh is not None:
@@ -5315,7 +5322,7 @@ def clip(self, lower: Union[float, int] = None, upper: Union[float, int] = None)
 
         def op(kser):
             if isinstance(kser.spark_type, numeric_types):
-                scol = kser._scol
+                scol = kser.spark_column
                 if lower is not None:
                     scol = F.when(scol < lower, lower).otherwise(scol)
                 if upper is not None:
@@ -6374,7 +6381,7 @@ def sort_values(
                     "The column %s is not unique. For a multi-index, the label must be a tuple "
                     "with elements corresponding to each level." % name_like_string(colname)
                 )
-            new_by.append(ser._scol)
+            new_by.append(ser.spark_column)
 
         return self._sort(by=new_by, ascending=ascending, inplace=inplace, na_position=na_position)
 
@@ -8036,7 +8043,7 @@ def _reindex_index(self, index):
         index_column = self._internal.index_spark_column_names[0]
 
         kser = ks.Series(list(index))
-        labels = kser._internal._sdf.select(kser._scol.alias(index_column))
+        labels = kser._internal._sdf.select(kser.spark_column.alias(index_column))
 
         joined_df = self._sdf.drop(NATURAL_ORDER_COLUMN_NAME).join(
             labels, on=index_column, how="right"
@@ -9275,8 +9282,8 @@ def pct_change(self, periods=1):
         window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-periods, -periods)
 
         def op(kser):
-            prev_row = F.lag(kser._scol, periods).over(window)
-            return ((kser._scol - prev_row) / prev_row).alias(
+            prev_row = F.lag(kser.spark_column, periods).over(window)
+            return ((kser.spark_column - prev_row) / prev_row).alias(
                 kser._internal.data_spark_column_names[0]
             )
```
databricks/koalas/generic.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1395,7 +1395,7 @@ def abs(self):
         """
         # TODO: The first example above should not have "Name: 0".
         return self._apply_series_op(
-            lambda kser: kser._with_new_scol(F.abs(kser._scol)).rename(kser.name)
+            lambda kser: kser._with_new_scol(F.abs(kser.spark_column)).rename(kser.name)
         )
 
     # TODO: by argument only support the grouping name and as_index only for now. Documentation
```

databricks/koalas/groupby.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1912,7 +1912,7 @@ def __init__(
     ):
         self._kdf = kdf
         self._groupkeys = by
-        self._groupkeys_scols = [s._scol for s in self._groupkeys]
+        self._groupkeys_scols = [s.spark_column for s in self._groupkeys]
         self._as_index = as_index
         self._should_drop_index = should_drop_index
         self._have_agg_columns = True
@@ -1925,7 +1925,7 @@ def __init__(
             ]
             self._have_agg_columns = False
         self._agg_columns = [kdf[label] for label in agg_columns]
-        self._agg_columns_scols = [s._scol for s in self._agg_columns]
+        self._agg_columns_scols = [s.spark_column for s in self._agg_columns]
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(_MissingPandasLikeDataFrameGroupBy, item):
```

databricks/koalas/indexes.py

Lines changed: 18 additions & 14 deletions
```diff
@@ -142,7 +142,7 @@ def _summary(self, name=None):
         String with a summarized representation of the index
         """
         head, tail, total_count = self._kdf._sdf.select(
-            F.first(self._scol), F.last(self._scol), F.count(F.expr("*"))
+            F.first(self.spark_column), F.last(self.spark_column), F.count(F.expr("*"))
         ).first()
 
         if total_count > 0:
@@ -440,7 +440,7 @@ def has_duplicates(self) -> bool:
         >>> kdf.index.has_duplicates
         True
         """
-        df = self._kdf._sdf.select(self._scol)
+        df = self._kdf._sdf.select(self.spark_column)
         col = df.columns[0]
 
         return df.select(F.count(col) != F.countDistinct(col)).first()[0]
@@ -554,7 +554,7 @@ def rename(
             )
 
         idx = kdf.index
-        idx._internal._scol = self._scol
+        idx._internal = idx._internal.copy(spark_column=self.spark_column)
         if inplace:
             self._internal = idx._internal
         else:
@@ -664,7 +664,7 @@ def to_series(self, name: Union[str, Tuple[str, ...]] = None) -> Series:
         Name: 0, dtype: object
         """
         kdf = self._kdf
-        scol = self._scol
+        scol = self.spark_column
         if name is not None:
             scol = scol.alias(name_like_string(name))
         column_labels = [None] if len(kdf._internal.index_map) > 1 else kdf._internal.index_names
@@ -731,7 +731,7 @@ def to_frame(self, index=True, name=None) -> DataFrame:
             name = self._internal.index_names[0]
         elif isinstance(name, str):
             name = (name,)
-        scol = self._scol.alias(name_like_string(name))
+        scol = self.spark_column.alias(name_like_string(name))
 
         sdf = self._internal.spark_frame.select(scol, NATURAL_ORDER_COLUMN_NAME)
 
@@ -1370,7 +1370,7 @@ def argmax(self):
         >>> kidx.argmax()
         4
         """
-        sdf = self._internal.spark_frame.select(self._scol)
+        sdf = self._internal.spark_frame.select(self.spark_column)
         sequence_col = verify_temp_column_name(sdf, "__distributed_sequence_column__")
         sdf = _InternalFrame.attach_distributed_sequence_column(sdf, column_name=sequence_col)
         # spark_frame here looks like below
@@ -1388,7 +1388,7 @@ def argmax(self):
         # |                1|              9|
         # +-----------------+---------------+
 
-        return sdf.orderBy(self._scol.desc(), F.col(sequence_col).asc()).first()[0]
+        return sdf.orderBy(self.spark_column.desc(), F.col(sequence_col).asc()).first()[0]
 
     def argmin(self):
         """
@@ -1411,11 +1411,11 @@ def argmin(self):
         >>> kidx.argmin()
         7
         """
-        sdf = self._internal.spark_frame.select(self._scol)
+        sdf = self._internal.spark_frame.select(self.spark_column)
         sequence_col = verify_temp_column_name(sdf, "__distributed_sequence_column__")
         sdf = _InternalFrame.attach_distributed_sequence_column(sdf, column_name=sequence_col)
 
-        return sdf.orderBy(self._scol.asc(), F.col(sequence_col).asc()).first()[0]
+        return sdf.orderBy(self.spark_column.asc(), F.col(sequence_col).asc()).first()[0]
 
     def set_names(self, names, level=None, inplace=False):
         """
@@ -1689,9 +1689,9 @@ def asof(self, label):
         """
         sdf = self._internal._sdf
         if self.is_monotonic_increasing:
-            sdf = sdf.where(self._scol <= label).select(F.max(self._scol))
+            sdf = sdf.where(self.spark_column <= label).select(F.max(self.spark_column))
         elif self.is_monotonic_decreasing:
-            sdf = sdf.where(self._scol >= label).select(F.min(self._scol))
+            sdf = sdf.where(self.spark_column >= label).select(F.min(self.spark_column))
         else:
             raise ValueError("index must be monotonic increasing or decreasing")
         result = sdf.head()[0]
@@ -1780,7 +1780,11 @@ def __repr__(self):
         if max_display_count is None:
             return repr(self.to_pandas())
 
-        pindex = self._kdf.head(max_display_count + 1).index._with_new_scol(self._scol).to_pandas()
+        pindex = (
+            self._kdf.head(max_display_count + 1)
+            .index._with_new_scol(self.spark_column)
+            .to_pandas()
+        )
 
         pindex_length = len(pindex)
         repr_string = repr(pindex[:max_display_count])
@@ -2072,7 +2076,7 @@ def _is_monotonic(self, order):
         return self._is_monotonic_decreasing().all()
 
     def _is_monotonic_increasing(self):
-        scol = self._scol
+        scol = self.spark_column
         window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-1, -1)
         prev = F.lag(scol, 1).over(window)
 
@@ -2108,7 +2112,7 @@ def _comparator_for_monotonic_decreasing(data_type):
         return compare_null_first
 
     def _is_monotonic_decreasing(self):
-        scol = self._scol
+        scol = self.spark_column
         window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-1, -1)
         prev = F.lag(scol, 1).over(window)
```
