Skip to content

Commit 11d8dc6

Browse files
authored
Use Spark column equality instead of the column name. (#1524)
1 parent 1e32e0c commit 11d8dc6

File tree

1 file changed

+54
-50
lines changed

1 file changed

+54
-50
lines changed

databricks/koalas/internal.py

Lines changed: 54 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -675,7 +675,7 @@ def offset(id):
675675

676676
def spark_column_name_for(self, labels: Tuple[str, ...]) -> str:
677677
""" Return the actual Spark column name for the given column name. """
678-
return self._sdf.select(self.spark_column_for(labels)).columns[0]
678+
return self.spark_frame.select(self.spark_column_for(labels)).columns[0]
679679

680680
def spark_column_for(self, labels: Tuple[str, ...]):
681681
""" Return Spark Column for the given column name. """
@@ -687,7 +687,7 @@ def spark_column_for(self, labels: Tuple[str, ...]):
687687

688688
def spark_type_for(self, labels: Tuple[str, ...]) -> DataType:
689689
""" Return DataType for the given column name. """
690-
return self._sdf.select(self.spark_column_for(labels)).schema[0].dataType
690+
return self.spark_frame.select(self.spark_column_for(labels)).schema[0].dataType
691691

692692
@property
693693
def spark_frame(self) -> spark.DataFrame:
@@ -717,24 +717,21 @@ def index_spark_column_names(self) -> List[str]:
717717
@lazy_property
718718
def index_spark_columns(self) -> List[spark.Column]:
719719
""" Return Spark Columns for the managed index columns. """
720-
return [scol_for(self._sdf, column) for column in self.index_spark_column_names]
720+
return [scol_for(self.spark_frame, column) for column in self.index_spark_column_names]
721721

722722
@lazy_property
723723
def spark_column_names(self) -> List[str]:
724724
""" Return all the field names including index field names. """
725-
index_columns = set(self.index_spark_column_names)
726-
return self.index_spark_column_names + [
727-
column for column in self.data_spark_column_names if column not in index_columns
728-
]
725+
return self.spark_frame.select(self.spark_columns).columns
729726

730727
@lazy_property
731728
def spark_columns(self) -> List[spark.Column]:
732729
""" Return Spark Columns for the managed columns including index columns. """
733-
index_columns = set(self.index_spark_column_names)
734-
return self.index_spark_columns + [
735-
self.spark_column_for(label)
736-
for label in self.column_labels
737-
if self.spark_column_name_for(label) not in index_columns
730+
index_spark_columns = self.index_spark_columns
731+
return index_spark_columns + [
732+
spark_column
733+
for label, spark_column in zip(self.column_labels, self.data_spark_columns)
734+
if all(not spark_column._jc.equals(scol._jc) for scol in index_spark_columns)
738735
]
739736

740737
@property
@@ -769,28 +766,30 @@ def to_internal_spark_frame(self) -> spark.DataFrame:
769766
Return as Spark DataFrame. This contains index columns as well
770767
and should be only used for internal purposes.
771768
"""
772-
index_columns = set(self.index_spark_column_names)
769+
index_spark_columns = self.index_spark_columns
773770
data_columns = []
774-
for i, (column, label) in enumerate(zip(self.data_spark_column_names, self.column_labels)):
775-
if column not in index_columns:
776-
scol = self.spark_column_for(label)
771+
for i, (label, spark_column, column_name) in enumerate(
772+
zip(self.column_labels, self.data_spark_columns, self.data_spark_column_names)
773+
):
774+
if all(not spark_column._jc.equals(scol._jc) for scol in index_spark_columns):
777775
name = str(i) if label is None else name_like_string(label)
778-
if column != name:
779-
scol = scol.alias(name)
780-
data_columns.append(scol)
781-
return self._sdf.select(self.index_spark_columns + data_columns)
776+
if column_name != name:
777+
spark_column = spark_column.alias(name)
778+
data_columns.append(spark_column)
779+
return self.spark_frame.select(index_spark_columns + data_columns)
782780

783781
@lazy_property
784782
def to_external_spark_frame(self) -> spark.DataFrame:
785783
""" Return as new Spark DataFrame. """
786784
data_columns = []
787-
for i, (column, label) in enumerate(zip(self.data_spark_column_names, self.column_labels)):
788-
scol = self.spark_column_for(label)
785+
for i, (label, spark_column, column_name) in enumerate(
786+
zip(self.column_labels, self.data_spark_columns, self.data_spark_column_names)
787+
):
789788
name = str(i) if label is None else name_like_string(label)
790-
if column != name:
791-
scol = scol.alias(name)
792-
data_columns.append(scol)
793-
return self._sdf.select(data_columns)
789+
if column_name != name:
790+
spark_column = spark_column.alias(name)
791+
data_columns.append(spark_column)
792+
return self.spark_frame.select(data_columns)
794793

795794
@lazy_property
796795
def to_pandas_frame(self) -> pd.DataFrame:
@@ -802,25 +801,28 @@ def to_pandas_frame(self) -> pd.DataFrame:
802801
{field.name: spark_type_to_pandas_dtype(field.dataType) for field in sdf.schema}
803802
)
804803

805-
index_columns = self.index_spark_column_names
806-
if len(index_columns) > 0:
807-
append = False
808-
for index_field in index_columns:
809-
drop = index_field not in self.data_spark_column_names
810-
pdf = pdf.set_index(index_field, drop=drop, append=append)
811-
append = True
812-
pdf = pdf[
813-
[
814-
col
815-
if col in index_columns
816-
else str(i)
817-
if label is None
818-
else name_like_string(label)
819-
for i, (col, label) in enumerate(
820-
zip(self.data_spark_column_names, self.column_labels)
821-
)
822-
]
823-
]
804+
column_names = []
805+
for i, (label, spark_column, column_name) in enumerate(
806+
zip(self.column_labels, self.data_spark_columns, self.data_spark_column_names)
807+
):
808+
for index_spark_column_name, index_spark_column in zip(
809+
self.index_spark_column_names, self.index_spark_columns
810+
):
811+
if spark_column._jc.equals(index_spark_column._jc):
812+
column_names.append(index_spark_column_name)
813+
break
814+
else:
815+
name = str(i) if label is None else name_like_string(label)
816+
if column_name != name:
817+
column_name = name
818+
column_names.append(column_name)
819+
820+
append = False
821+
for index_field in self.index_spark_column_names:
822+
drop = index_field not in column_names
823+
pdf = pdf.set_index(index_field, drop=drop, append=append)
824+
append = True
825+
pdf = pdf[column_names]
824826

825827
if self.column_labels_level > 1:
826828
pdf.columns = pd.MultiIndex.from_tuples(self._column_labels)
@@ -910,7 +912,9 @@ def with_new_columns(
910912
if keep_order:
911913
hidden_columns.append(NATURAL_ORDER_COLUMN_NAME)
912914

913-
sdf = self._sdf.select(self.index_spark_columns + data_spark_columns + hidden_columns)
915+
sdf = self.spark_frame.select(
916+
self.index_spark_columns + data_spark_columns + hidden_columns
917+
)
914918

915919
if column_label_names is _NoValue:
916920
column_label_names = self._column_label_names
@@ -919,7 +923,7 @@ def with_new_columns(
919923
spark_frame=sdf,
920924
column_labels=column_labels,
921925
data_spark_columns=[
922-
scol_for(sdf, col) for col in self._sdf.select(data_spark_columns).columns
926+
scol_for(sdf, col) for col in self.spark_frame.select(data_spark_columns).columns
923927
],
924928
column_label_names=column_label_names,
925929
spark_column=None,
@@ -937,10 +941,10 @@ def with_filter(self, pred: Union[spark.Column, "Series"]):
937941
assert isinstance(pred.spark_type, BooleanType), pred.spark_type
938942
pred = pred.spark_column
939943
else:
940-
spark_type = self._sdf.select(pred).schema[0].dataType
944+
spark_type = self.spark_frame.select(pred).schema[0].dataType
941945
assert isinstance(spark_type, BooleanType), spark_type
942946

943-
return self.copy(spark_frame=self._sdf.drop(NATURAL_ORDER_COLUMN_NAME).filter(pred))
947+
return self.copy(spark_frame=self.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME).filter(pred))
944948

945949
def copy(
946950
self,
@@ -962,7 +966,7 @@ def copy(
962966
:return: the copied immutable DataFrame.
963967
"""
964968
if spark_frame is _NoValue:
965-
spark_frame = self._sdf
969+
spark_frame = self.spark_frame
966970
if index_map is _NoValue:
967971
index_map = self._index_map
968972
if column_labels is _NoValue:

0 commit comments

Comments (0)