
Commit d9fb2c1

Address comments
1 parent efd8d5e commit d9fb2c1

13 files changed: +55 -46 lines changed


databricks/koalas/base.py

Lines changed: 9 additions & 9 deletions

@@ -109,7 +109,7 @@ def wrapper(self, *args):
         new_args = []
         for arg in args:
             # TODO: This is a quick hack to support NumPy type. We should revisit this.
-            if isinstance(self.spark.type, LongType) and isinstance(arg, np.timedelta64):
+            if isinstance(self.spark.data_type, LongType) and isinstance(arg, np.timedelta64):
                 new_args.append(float(arg / np.timedelta64(1, "s")))
             else:
                 new_args.append(arg)
@@ -152,9 +152,9 @@ def spark_column(self):
     __neg__ = column_op(Column.__neg__)

     def __add__(self, other):
-        if isinstance(self.spark.type, StringType):
+        if isinstance(self.spark.data_type, StringType):
             # Concatenate string columns
-            if isinstance(other, IndexOpsMixin) and isinstance(other.spark.type, StringType):
+            if isinstance(other, IndexOpsMixin) and isinstance(other.spark.data_type, StringType):
                 return column_op(F.concat)(self, other)
             # Handle df['col'] + 'literal'
             elif isinstance(other, str):
@@ -167,12 +167,12 @@ def __add__(self, other):
     def __sub__(self, other):
         # Note that timestamp subtraction casts arguments to integer. This is to mimic Pandas's
         # behaviors. Pandas returns 'timedelta64[ns]' from 'datetime64[ns]'s subtraction.
-        if isinstance(other, IndexOpsMixin) and isinstance(self.spark.type, TimestampType):
-            if not isinstance(other.spark.type, TimestampType):
+        if isinstance(other, IndexOpsMixin) and isinstance(self.spark.data_type, TimestampType):
+            if not isinstance(other.spark.data_type, TimestampType):
                 raise TypeError("datetime subtraction can only be applied to datetime series.")
             return self.astype("bigint") - other.astype("bigint")
-        elif isinstance(other, IndexOpsMixin) and isinstance(self.spark.type, DateType):
-            if not isinstance(other.spark.type, DateType):
+        elif isinstance(other, IndexOpsMixin) and isinstance(self.spark.data_type, DateType):
+            if not isinstance(other.spark.data_type, DateType):
                 raise TypeError("date subtraction can only be applied to date series.")
             return column_op(F.datediff)(self, other)
         else:
@@ -215,7 +215,7 @@ def mod(left, right):

     def __radd__(self, other):
         # Handle 'literal' + df['col']
-        if isinstance(self.spark.type, StringType) and isinstance(other, str):
+        if isinstance(self.spark.data_type, StringType) and isinstance(other, str):
             return self._with_new_scol(F.concat(F.lit(other), self.spark.column))
         else:
             return column_op(Column.__radd__)(self, other)
@@ -335,7 +335,7 @@ def dtype(self):
         >>> s.rename("a").to_frame().set_index("a").index.dtype
         dtype('<M8[ns]')
         """
-        return spark_type_to_pandas_dtype(self.spark.type)
+        return spark_type_to_pandas_dtype(self.spark.data_type)

     @property
     def empty(self):
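
A quick way to exercise the renamed accessor through the operators touched above; a minimal sketch, not part of the commit, assuming a default local Spark session started by Koalas:

    import databricks.koalas as ks
    from pyspark.sql.types import StringType

    s = ks.Series(["a", "b", "c"])
    # __add__ now consults spark.data_type; string Series + string literal concatenates.
    assert isinstance(s.spark.data_type, StringType)
    print((s + "!").to_pandas())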

databricks/koalas/datetimes.py

Lines changed: 4 additions & 2 deletions

@@ -34,8 +34,10 @@ class DatetimeMethods(object):
     """Date/Time methods for Koalas Series"""

     def __init__(self, series: "ks.Series"):
-        if not isinstance(series.spark.type, (DateType, TimestampType)):
-            raise ValueError("Cannot call DatetimeMethods on type {}".format(series.spark.type))
+        if not isinstance(series.spark.data_type, (DateType, TimestampType)):
+            raise ValueError(
+                "Cannot call DatetimeMethods on type {}".format(series.spark.data_type)
+            )
         self._data = series

     # Properties
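
The accessor guard is visible from user code; a hedged sketch, not part of the commit:

    import databricks.koalas as ks

    s = ks.Series([1, 2, 3])  # LongType column, not a date/timestamp
    try:
        s.dt  # constructing DatetimeMethods validates series.spark.data_type
    except ValueError as e:
        print(e)  # Cannot call DatetimeMethods on type LongType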

databricks/koalas/frame.py

Lines changed: 1 addition & 1 deletion

@@ -2742,7 +2742,7 @@ def pandas_frame_func(f):
            kser = kdf_or_kser
            pudf = pandas_udf(
                func if should_by_pass else pandas_series_func(func),
-               returnType=kser.spark.type,
+               returnType=kser.spark.data_type,
                functionType=PandasUDFType.SCALAR,
            )
            columns = self._internal.spark_columns
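
The pattern above passes a Series' Spark data type straight to pandas_udf as the return type. A hypothetical illustration of the same call shape (identity_udf is not a function in the codebase):

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    def identity_udf(kser):
        # kser is assumed to be a Koalas Series; its Spark data type becomes the UDF return type.
        return pandas_udf(
            lambda s: s, returnType=kser.spark.data_type, functionType=PandasUDFType.SCALAR
        )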

databricks/koalas/groupby.py

Lines changed: 2 additions & 2 deletions

@@ -2067,7 +2067,7 @@ def _reduce_for_stat_function(self, sfun, only_numeric, should_include_groupkeys
        if len(agg_columns) > 0:
            stat_exprs = []
            for kser, c in zip(agg_columns, agg_columns_scols):
-               spark_type = kser.spark.type
+               spark_type = kser.spark.data_type
                name = kser._internal.data_spark_column_names[0]
                label = kser._internal.column_labels[0]
                # TODO: we should have a function that takes dataframes and converts the numeric
@@ -2330,7 +2330,7 @@ def describe(self):

        """
        for col in self._agg_columns:
-           if isinstance(col.spark.type, StringType):
+           if isinstance(col.spark.data_type, StringType):
                raise NotImplementedError(
                    "DataFrameGroupBy.describe() doesn't support for string type for now"
                )
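
The describe() guard is observable from user code; a hedged sketch, not part of the commit:

    import databricks.koalas as ks

    kdf = ks.DataFrame({"key": [1, 1, 2], "val": ["x", "y", "z"]})
    try:
        kdf.groupby("key").describe()  # "val" is a StringType aggregation column
    except NotImplementedError as e:
        print(e)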

databricks/koalas/indexes.py

Lines changed: 1 addition & 1 deletion

@@ -420,7 +420,7 @@ def values(self):
    @property
    def spark_type(self):
        """ Returns the data type as defined by Spark, as a Spark DataType object."""
-       return self.to_series().spark.type
+       return self.to_series().spark.data_type

    @property
    def has_duplicates(self) -> bool:
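
Index.spark_type keeps its public name and simply delegates to the renamed accessor; a minimal sketch, not part of the commit:

    import databricks.koalas as ks

    kdf = ks.DataFrame({"a": [1, 2, 3]}, index=[10, 20, 30])
    print(kdf.index.spark_type)  # LongType, resolved via to_series().spark.data_type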

databricks/koalas/indexing.py

Lines changed: 4 additions & 4 deletions

@@ -834,7 +834,7 @@ def _NotImplemented(description):
    def _select_rows_by_series(
        self, rows_sel: "Series"
    ) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]:
-       assert isinstance(rows_sel.spark.type, BooleanType), rows_sel.spark.type
+       assert isinstance(rows_sel.spark.data_type, BooleanType), rows_sel.spark.data_type
        return rows_sel.spark.column, None, None

    def _select_rows_by_spark_column(
@@ -855,7 +855,7 @@ def _select_rows_by_slice(
            sdf = self._internal.spark_frame
            index = self._kdf_or_kser.index
            index_column = index.to_series()
-           index_data_type = index_column.spark.type
+           index_data_type = index_column.spark.data_type
            start = rows_sel.start
            stop = rows_sel.stop

@@ -912,7 +912,7 @@ def _select_rows_by_slice(
            return reduce(lambda x, y: x & y, cond), None, None
        else:
            index = self._kdf_or_kser.index
-           index_data_type = [f.dataType for f in index.to_series().spark.type]
+           index_data_type = [f.dataType for f in index.to_series().spark.data_type]

            start = rows_sel.start
            if start is not None:
@@ -974,7 +974,7 @@ def _select_rows_by_iterable(
            return F.lit(False), None, None
        elif len(self._internal.index_spark_column_names) == 1:
            index_column = self._kdf_or_kser.index.to_series()
-           index_data_type = index_column.spark.type
+           index_data_type = index_column.spark.data_type
            if len(rows_sel) == 1:
                return (
                    index_column.spark.column == F.lit(rows_sel[0]).cast(index_data_type),
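
These selectors sit behind .loc; a hedged usage sketch of the boolean-Series path touched in the first hunk, not part of the commit:

    import databricks.koalas as ks

    kdf = ks.DataFrame({"a": [1, 2, 3, 4]})
    # A boolean Series passed to .loc goes through _select_rows_by_series, which
    # now asserts on rows_sel.spark.data_type.
    print(kdf.loc[kdf.a > 2].to_pandas())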

databricks/koalas/internal.py

Lines changed: 1 addition & 1 deletion

@@ -889,7 +889,7 @@ def with_filter(self, pred: Union[spark.Column, "Series"]):
        from databricks.koalas.series import Series

        if isinstance(pred, Series):
-           assert isinstance(pred.spark.type, BooleanType), pred.spark.type
+           assert isinstance(pred.spark.data_type, BooleanType), pred.spark.data_type
            pred = pred.spark.column
        else:
            spark_type = self.spark_frame.select(pred).schema[0].dataType
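
with_filter is what plain boolean indexing lands on; a minimal sketch, not part of the commit:

    import databricks.koalas as ks

    kdf = ks.DataFrame({"a": [1, 2, 3]})
    # Filtering with a boolean Series routes through InternalFrame.with_filter,
    # which now asserts on pred.spark.data_type.
    print(kdf[kdf.a >= 2].to_pandas())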

databricks/koalas/series.py

Lines changed: 15 additions & 18 deletions

@@ -303,10 +303,6 @@
 str_type = str


-class SparkMethods(object):
-    pass
-
-
 class Series(Frame, IndexOpsMixin, Generic[T]):
     """
     Koalas Series that corresponds to Pandas Series logically. This holds Spark Column
@@ -399,11 +395,11 @@ def axes(self):
     @property
     def spark_type(self):
         warnings.warn(
-            "Series.spark_type is deprecated as of Series.spark.type. "
+            "Series.spark_type is deprecated as of Series.spark.data_type. "
             "Please use the API instead.",
             FutureWarning,
         )
-        return self.spark.type
+        return self.spark.data_type

     spark_type.__doc__ = SparkIndexOpsMethods.type.__doc__

@@ -924,7 +920,7 @@ def map(self, arg):
         if isinstance(arg, dict):
             is_start = True
             # In case dictionary is empty.
-            current = F.when(F.lit(False), F.lit(None).cast(self.spark.type))
+            current = F.when(F.lit(False), F.lit(None).cast(self.spark.data_type))

             for to_replace, value in arg.items():
                 if is_start:
@@ -938,7 +934,7 @@ def map(self, arg):
                 del arg[np._NoValue]  # Remove in case it's set in defaultdict.
                 current = current.otherwise(F.lit(tmp_val))
             else:
-                current = current.otherwise(F.lit(None).cast(self.spark.type))
+                current = current.otherwise(F.lit(None).cast(self.spark.data_type))
             return self._with_new_scol(current).rename(self.name)
         else:
             return self.apply(arg)
@@ -980,11 +976,11 @@ def astype(self, dtype) -> "Series":
         if not spark_type:
             raise ValueError("Type {} not understood".format(dtype))
         if isinstance(spark_type, BooleanType):
-            if isinstance(self.spark.type, StringType):
+            if isinstance(self.spark.data_type, StringType):
                 scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
                     F.length(self.spark.column) > 0
                 )
-            elif isinstance(self.spark.type, (FloatType, DoubleType)):
+            elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
                 scol = F.when(
                     self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
                 ).otherwise(self.spark.column.cast(spark_type))
@@ -1745,7 +1741,7 @@ def clip(self, lower: Union[float, int] = None, upper: Union[float, int] = None)
         if lower is None and upper is None:
             return self

-        if isinstance(self.spark.type, NumericType):
+        if isinstance(self.spark.data_type, NumericType):
             scol = self.spark.column
             if lower is not None:
                 scol = F.when(scol < lower, lower).otherwise(scol)
@@ -2714,7 +2710,7 @@ def apply(self, func, args=(), **kwds):
             pser = self.head(limit)._to_internal_pandas()
             transformed = pser.apply(func, *args, **kwds)
             kser = Series(transformed)
-            return self._transform_batch(apply_each, kser.spark.type)
+            return self._transform_batch(apply_each, kser.spark.data_type)
         else:
             sig_return = infer_return_type(func)
             if not isinstance(sig_return, ScalarType):
@@ -3021,7 +3017,7 @@ def _transform_batch(self, func, return_schema):
             pser = self.head(limit)._to_internal_pandas()
             transformed = pser.transform(func)
             kser = Series(transformed)
-            spark_return_type = kser.spark.type
+            spark_return_type = kser.spark.data_type
         else:
             spark_return_type = return_schema

@@ -4987,7 +4983,7 @@ def _cumprod(self, skipna, part_cols=()):
         from pyspark.sql.functions import pandas_udf

         def cumprod(scol):
-            @pandas_udf(returnType=self.spark.type)
+            @pandas_udf(returnType=self.spark.data_type)
             def negative_check(s):
                 assert len(s) == 0 or ((s > 0) | (s.isnull())).all(), (
                     "values should be bigger than 0: %s" % s
@@ -5029,7 +5025,7 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=None):
             raise ValueError("Series does not support columns axis.")
         num_args = len(signature(sfun).parameters)
         col_sdf = self.spark.column
-        col_type = self.spark.type
+        col_type = self.spark.data_type
         if isinstance(col_type, BooleanType) and sfun.__name__ not in ("min", "max"):
             # Stat functions cannot be used with boolean values by default
             # Thus, cast to integer (true to 1 and false to 0)
@@ -5050,7 +5046,8 @@ def __len__(self):
     def __getitem__(self, key):
         try:
             if (isinstance(key, slice) and any(type(n) == int for n in [key.start, key.stop])) or (
-                type(key) == int and not isinstance(self.index.spark.type, (IntegerType, LongType))
+                type(key) == int
+                and not isinstance(self.index.spark.data_type, (IntegerType, LongType))
             ):
                 # Seems like pandas Series always uses int as positional search when slicing
                 # with ints, searches based on index values when the value is int.
@@ -5104,10 +5101,10 @@ def __repr__(self):
         return pser.to_string(name=self.name, dtype=self.dtype)

     def __dir__(self):
-        if not isinstance(self.spark.type, StructType):
+        if not isinstance(self.spark.data_type, StructType):
             fields = []
         else:
-            fields = [f for f in self.spark.type.fieldNames() if " " not in f]
+            fields = [f for f in self.spark.data_type.fieldNames() if " " not in f]
         return super(Series, self).__dir__() + fields

     def __iter__(self):
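
The user-facing effect of the rename in this file: the old Series.spark_type keeps working but warns, and Series.spark.data_type is the replacement. A minimal sketch, not part of the commit:

    import warnings
    import databricks.koalas as ks

    s = ks.Series([1.0, 2.0, 3.0])
    print(s.spark.data_type)  # DoubleType

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        _ = s.spark_type  # deprecated alias, still returns the same DataType
        print(w[-1].category)  # <class 'FutureWarning'>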

databricks/koalas/spark.py

Lines changed: 7 additions & 2 deletions

@@ -39,7 +39,7 @@ def __init__(self, data: Union["IndexOpsMixin"]):
        self._data = data

    @property
-   def type(self):
+   def data_type(self):
        """ Returns the data type as defined by Spark, as a Spark DataType object."""
        return self._data._internal.spark_type_for(self._data._internal.column_labels[0])

@@ -110,7 +110,12 @@ def transform(self, func):
                "The output of the function [%s] should be of a "
                "pyspark.sql.Column; however, got [%s]." % (func, type(output))
            )
-       return self._data._with_new_scol(scol=func(self._data.spark.column)).rename(self._data.name)
+       new_ser = self._data._with_new_scol(scol=output).rename(self._data.name)
+       # Trigger the resolution so it throws an exception if anything goes wrong
+       # within the function, for example,
+       # `df1.a.spark.transform(lambda _: F.col("non-existent"))`.
+       new_ser._internal.to_internal_spark_frame
+       return new_ser


class SparkFrameMethods(object):
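
The added resolution step means a bad expression now fails at transform() time rather than at a later action. A hedged sketch, not part of the commit, assuming a default Spark session:

    import databricks.koalas as ks
    from pyspark.sql import functions as F

    kdf = ks.DataFrame({"a": [1, 2, 3]})
    kdf.a.spark.transform(lambda c: c.cast("string"))  # resolves eagerly, succeeds

    try:
        kdf.a.spark.transform(lambda _: F.col("non-existent"))
    except Exception as e:  # Spark raises an analysis error during the eager resolution
        print(type(e).__name__)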

databricks/koalas/strings.py

Lines changed: 3 additions & 3 deletions

@@ -35,8 +35,8 @@ class StringMethods(object):
    """String methods for Koalas Series"""

    def __init__(self, series: "ks.Series"):
-       if not isinstance(series.spark.type, (StringType, BinaryType, ArrayType)):
-           raise ValueError("Cannot call StringMethods on type {}".format(series.spark.type))
+       if not isinstance(series.spark.data_type, (StringType, BinaryType, ArrayType)):
+           raise ValueError("Cannot call StringMethods on type {}".format(series.spark.data_type))
        self._data = series
        self.name = self._data.name

@@ -1271,7 +1271,7 @@ def len(self) -> "ks.Series":
        1    0
        Name: 0, dtype: int64
        """
-       if isinstance(self._data.spark.type, (ArrayType, MapType)):
+       if isinstance(self._data.spark.data_type, (ArrayType, MapType)):
            return column_op(lambda c: F.size(c).cast(LongType()))(self._data).alias(
                self._data.name
            )
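
str.len() is one place where behaviour branches on the column's Spark type; a hedged sketch, not part of the commit:

    import databricks.koalas as ks

    s = ks.Series([["x", "y"], ["z"]])  # ArrayType column
    # For ArrayType (and MapType) columns, len() counts elements via F.size.
    print(s.str.len().to_pandas())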
