@@ -4441,22 +4441,29 @@ def clip(self, lower: Union[float, int] = None, upper: Union[float, int] = None)
44414441
44424442 numeric_types = (DecimalType , DoubleType , FloatType , ByteType , IntegerType , LongType ,
44434443 ShortType )
4444- numeric_columns = [(c , self ._internal .scol_for (c )) for c in self .columns
4445- if isinstance (self ._internal .spark_type_for (c ), numeric_types )]
4444+ numeric_columns = [(idx , self ._internal .scol_for (idx ))
4445+ for idx in self ._internal .column_index
4446+ if isinstance (self ._internal .spark_type_for (idx ), numeric_types )]
44464447
44474448 if lower is not None :
4448- numeric_columns = [(c , F .when (scol < lower , lower ).otherwise (scol ).alias (c ))
4449- for c , scol in numeric_columns ]
4449+ numeric_columns = [(idx , (F .when (scol < lower , lower ).otherwise (scol )
4450+ .alias (name_like_string (idx ))))
4451+ for idx , scol in numeric_columns ]
44504452 if upper is not None :
4451- numeric_columns = [(c , F .when (scol > upper , upper ).otherwise (scol ).alias (c ))
4452- for c , scol in numeric_columns ]
4453+ numeric_columns = [(idx , (F .when (scol > upper , upper ).otherwise (scol )
4454+ .alias (name_like_string (idx ))))
4455+ for idx , scol in numeric_columns ]
44534456
4454- nonnumeric_columns = [self . _internal . scol_for ( c ) for c in self . columns
4455- if not isinstance ( self . _internal . spark_type_for ( c ), numeric_types ) ]
4457+ column_index = [idx for idx , _ in numeric_columns ]
4458+ column_scols = [ scol for _ , scol in numeric_columns ]
44564459
4457- sdf = self ._sdf .select ([scol for _ , scol in numeric_columns ] + nonnumeric_columns )
4460+ for idx in self ._internal .column_index :
4461+ if not isinstance (self ._internal .spark_type_for (idx ), numeric_types ):
4462+ column_index .append (idx )
4463+ column_scols .append (self ._internal .scol_for (idx ))
44584464
4459- return ks .DataFrame (sdf )[list (self .columns )]
4465+ internal = self ._internal .with_new_columns (column_scols , column_index = column_index )
4466+ return DataFrame (internal )[list (self .columns )]
44604467
44614468 def head (self , n = 5 ):
44624469 """
0 commit comments