@@ -675,7 +675,7 @@ def offset(id):
675675
676676 def spark_column_name_for(self, labels: Tuple[str, ...]) -> str:
677677 """Return the actual Spark column name for the given column name."""
678- return self._sdf.select(self.spark_column_for(labels)).columns[0]
678+ return self.spark_frame.select(self.spark_column_for(labels)).columns[0]
679679
680680 def spark_column_for(self, labels: Tuple[str, ...]):
681681 """Return Spark Column for the given column name."""
@@ -687,7 +687,7 @@ def spark_column_for(self, labels: Tuple[str, ...]):
687687
688688 def spark_type_for(self, labels: Tuple[str, ...]) -> DataType:
689689 """Return DataType for the given column name."""
690- return self._sdf.select(self.spark_column_for(labels)).schema[0].dataType
690+ return self.spark_frame.select(self.spark_column_for(labels)).schema[0].dataType
691691
692692 @property
693693 def spark_frame(self) -> spark.DataFrame:
@@ -717,24 +717,21 @@ def index_spark_column_names(self) -> List[str]:
717717 @lazy_property
718718 def index_spark_columns(self) -> List[spark.Column]:
719719 """Return Spark Columns for the managed index columns."""
720- return [scol_for(self._sdf, column) for column in self.index_spark_column_names]
720+ return [scol_for(self.spark_frame, column) for column in self.index_spark_column_names]
721721
722722 @lazy_property
723723 def spark_column_names(self) -> List[str]:
724724 """Return all the field names including index field names."""
725- index_columns = set(self.index_spark_column_names)
726- return self.index_spark_column_names + [
727- column for column in self.data_spark_column_names if column not in index_columns
728- ]
725+ return self.spark_frame.select(self.spark_columns).columns
729726
730727 @lazy_property
731728 def spark_columns(self) -> List[spark.Column]:
732729 """Return Spark Columns for the managed columns including index columns."""
733- index_columns = set(self.index_spark_column_names)
734- return self.index_spark_columns + [
735- self.spark_column_for(label)
736- for label in self.column_labels
737- if self.spark_column_name_for(label) not in index_columns
730+ index_spark_columns = self.index_spark_columns
731+ return index_spark_columns + [
732+ spark_column
733+ for label, spark_column in zip(self.column_labels, self.data_spark_columns)
734+ if all(not spark_column._jc.equals(scol._jc) for scol in index_spark_columns)
738735 ]
739736
740737 @property
@@ -769,28 +766,30 @@ def to_internal_spark_frame(self) -> spark.DataFrame:
769766 Return as Spark DataFrame. This contains index columns as well
770767 and should be only used for internal purposes.
771768 """
772- index_columns = set(self.index_spark_column_names)
769+ index_spark_columns = self.index_spark_columns
773770 data_columns = []
774- for i, (column, label) in enumerate(zip(self.data_spark_column_names, self.column_labels)):
775- if column not in index_columns:
776- scol = self.spark_column_for(label)
771+ for i, (label, spark_column, column_name) in enumerate(
772+ zip(self.column_labels, self.data_spark_columns, self.data_spark_column_names)
773+ ):
774+ if all(not spark_column._jc.equals(scol._jc) for scol in index_spark_columns):
777775 name = str(i) if label is None else name_like_string(label)
778- if column != name:
779- scol = scol.alias(name)
780- data_columns.append(scol)
781- return self._sdf.select(self.index_spark_columns + data_columns)
776+ if column_name != name:
777+ spark_column = spark_column.alias(name)
778+ data_columns.append(spark_column)
779+ return self.spark_frame.select(index_spark_columns + data_columns)
782780
783781 @lazy_property
784782 def to_external_spark_frame(self) -> spark.DataFrame:
785783 """Return as new Spark DataFrame."""
786784 data_columns = []
787- for i, (column, label) in enumerate(zip(self.data_spark_column_names, self.column_labels)):
788- scol = self.spark_column_for(label)
785+ for i, (label, spark_column, column_name) in enumerate(
786+ zip(self.column_labels, self.data_spark_columns, self.data_spark_column_names)
787+ ):
789788 name = str(i) if label is None else name_like_string(label)
790- if column != name:
791- scol = scol.alias(name)
792- data_columns.append(scol)
793- return self._sdf.select(data_columns)
789+ if column_name != name:
790+ spark_column = spark_column.alias(name)
791+ data_columns.append(spark_column)
792+ return self.spark_frame.select(data_columns)
794793
795794 @lazy_property
796795 def to_pandas_frame(self) -> pd.DataFrame:
@@ -802,25 +801,28 @@ def to_pandas_frame(self) -> pd.DataFrame:
802801 {field.name: spark_type_to_pandas_dtype(field.dataType) for field in sdf.schema}
803802 )
804803
805- index_columns = self.index_spark_column_names
806- if len(index_columns) > 0:
807- append = False
808- for index_field in index_columns:
809- drop = index_field not in self.data_spark_column_names
810- pdf = pdf.set_index(index_field, drop=drop, append=append)
811- append = True
812- pdf = pdf[
813- [
814- col
815- if col in index_columns
816- else str(i)
817- if label is None
818- else name_like_string(label)
819- for i, (col, label) in enumerate(
820- zip(self.data_spark_column_names, self.column_labels)
821- )
822- ]
823- ]
804+ column_names = []
805+ for i, (label, spark_column, column_name) in enumerate(
806+ zip(self.column_labels, self.data_spark_columns, self.data_spark_column_names)
807+ ):
808+ for index_spark_column_name, index_spark_column in zip(
809+ self.index_spark_column_names, self.index_spark_columns
810+ ):
811+ if spark_column._jc.equals(index_spark_column._jc):
812+ column_names.append(index_spark_column_name)
813+ break
814+ else:
815+ name = str(i) if label is None else name_like_string(label)
816+ if column_name != name:
817+ column_name = name
818+ column_names.append(column_name)
819+
820+ append = False
821+ for index_field in self.index_spark_column_names:
822+ drop = index_field not in column_names
823+ pdf = pdf.set_index(index_field, drop=drop, append=append)
824+ append = True
825+ pdf = pdf[column_names]
824826
825827 if self.column_labels_level > 1:
826828 pdf.columns = pd.MultiIndex.from_tuples(self._column_labels)
@@ -910,7 +912,9 @@ def with_new_columns(
910912 if keep_order:
911913 hidden_columns.append(NATURAL_ORDER_COLUMN_NAME)
912914
913- sdf = self._sdf.select(self.index_spark_columns + data_spark_columns + hidden_columns)
915+ sdf = self.spark_frame.select(
916+ self.index_spark_columns + data_spark_columns + hidden_columns
917+ )
914918
915919 if column_label_names is _NoValue:
916920 column_label_names = self._column_label_names
@@ -919,7 +923,7 @@ def with_new_columns(
919923 spark_frame=sdf,
920924 column_labels=column_labels,
921925 data_spark_columns=[
922- scol_for(sdf, col) for col in self._sdf.select(data_spark_columns).columns
926+ scol_for(sdf, col) for col in self.spark_frame.select(data_spark_columns).columns
923927 ],
924928 column_label_names=column_label_names,
925929 spark_column=None,
@@ -937,10 +941,10 @@ def with_filter(self, pred: Union[spark.Column, "Series"]):
937941 assert isinstance(pred.spark_type, BooleanType), pred.spark_type
938942 pred = pred.spark_column
939943 else:
940- spark_type = self._sdf.select(pred).schema[0].dataType
944+ spark_type = self.spark_frame.select(pred).schema[0].dataType
941945 assert isinstance(spark_type, BooleanType), spark_type
942946
943- return self.copy(spark_frame=self._sdf.drop(NATURAL_ORDER_COLUMN_NAME).filter(pred))
947+ return self.copy(spark_frame=self.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME).filter(pred))
944948
945949 def copy(
946950 self,
@@ -962,7 +966,7 @@ def copy(
962966 :return: the copied immutable DataFrame.
963967 """
964968 if spark_frame is _NoValue :
965- spark_frame = self ._sdf
969+ spark_frame = self .spark_frame
966970 if index_map is _NoValue :
967971 index_map = self ._index_map
968972 if column_labels is _NoValue :
0 commit comments