
Commit 2f3e894

WeichenXu123 authored and HyukjinKwon committed
Add missing method rename for koalas dataframe (#806)
Add missing method rename for koalas dataframe. Some limitations:

* In-place operation is not supported.
* The mapper function is required to include a return type hint, such as:

  ```
  def f1(x) -> int:
      return x * 10
  ```

* When renaming index labels, it is possible to raise a SparkException instead of a KeyError. (Discussion: could we get the nested "KeyError" exception and re-throw it?)
1 parent aeccfb5 commit 2f3e894
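For reference, a minimal usage sketch of the API added in this commit, assuming `databricks.koalas` is installed and a Spark session is available; the sample data is illustrative only:

```python
import databricks.koalas as ks

kdf = ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

# Dict-like mapper: rename column labels; a new DataFrame is returned.
print(kdf.rename(columns={"A": "a", "B": "b"}).columns)

# Callable mapper: the return type hint (-> int) is required so the renamed
# index values can be mapped to a Spark return type.
def mul10(x) -> int:
    return x * 10

print(kdf.rename(mul10, axis='index').index)
```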

3 files changed: +281 -2 lines changed

databricks/koalas/frame.py

Lines changed: 219 additions & 1 deletion
@@ -52,7 +52,7 @@
 from databricks.koalas.missing.frame import _MissingPandasLikeDataFrame
 from databricks.koalas.ml import corr
 from databricks.koalas.utils import column_index_level, scol_for
-from databricks.koalas.typedef import as_spark_type, as_python_type
+from databricks.koalas.typedef import _infer_return_type, as_spark_type, as_python_type
 from databricks.koalas.plot import KoalasFramePlotMethods
 from databricks.koalas.config import get_option

@@ -6854,6 +6854,224 @@ def filter(self, items=None, like=None, regex=None, axis=None):
         else:
             raise TypeError("Must pass either `items`, `like`, or `regex`")
 
+    def rename(self,
+               mapper=None,
+               index=None,
+               columns=None,
+               axis='index',
+               inplace=False,
+               level=None,
+               errors='ignore'):
+
+        """
+        Alter axes labels.
+        Function / dict values must be unique (1-to-1). Labels not contained in a dict / Series
+        will be left as-is. Extra labels listed don’t throw an error.
+
+        Parameters
+        ----------
+        mapper : dict-like or function
+            Dict-like or function transformations to apply to that axis’ values.
+            Use either `mapper` and `axis` to specify the axis to target with `mapper`, or `index`
+            and `columns`.
+        index : dict-like or function
+            Alternative to specifying axis ("mapper, axis=0" is equivalent to "index=mapper").
+        columns : dict-like or function
+            Alternative to specifying axis ("mapper, axis=1" is equivalent to "columns=mapper").
+        axis : int or str, default 'index'
+            Axis to target with mapper. Can be either the axis name ('index', 'columns') or
+            number (0, 1).
+        inplace : bool, default False
+            Whether to return a new DataFrame.
+        level : int or level name, default None
+            In case of a MultiIndex, only rename labels in the specified level.
+        errors : {'ignore', 'raise'}, default 'ignore'
+            If 'raise', raise a `KeyError` when a dict-like `mapper`, `index`, or `columns`
+            contains labels that are not present in the Index being transformed. If 'ignore',
+            existing keys will be renamed and extra keys will be ignored.
+
+        Returns
+        -------
+        DataFrame with the renamed axis labels.
+
+        Raises
+        ------
+        `KeyError`
+            If any of the labels is not found in the selected axis and "errors='raise'".
+
+        Examples
+        --------
+        >>> kdf1 = ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
+        >>> kdf1.rename(columns={"A": "a", "B": "c"})  # doctest: +NORMALIZE_WHITESPACE
+           a  c
+        0  1  4
+        1  2  5
+        2  3  6
+
+        >>> kdf1.rename(index={1: 10, 2: 20})  # doctest: +NORMALIZE_WHITESPACE
+            A  B
+        0   1  4
+        10  2  5
+        20  3  6
+
+        >>> def str_lower(s) -> str:
+        ...     return str.lower(s)
+        >>> kdf1.rename(str_lower, axis='columns')  # doctest: +NORMALIZE_WHITESPACE
+           a  b
+        0  1  4
+        1  2  5
+        2  3  6
+
+        >>> def mul10(x) -> int:
+        ...     return x * 10
+        >>> kdf1.rename(mul10, axis='index')  # doctest: +NORMALIZE_WHITESPACE
+            A  B
+        0   1  4
+        10  2  5
+        20  3  6
+
+        >>> idx = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B'), ('Y', 'C'), ('Y', 'D')])
+        >>> kdf2 = ks.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
+        >>> kdf2.rename(columns=str_lower, level=0)  # doctest: +NORMALIZE_WHITESPACE
+           x     y
+           A  B  C  D
+        0  1  2  3  4
+        1  5  6  7  8
+
+        >>> kdf3 = ks.DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]], index=idx, columns=list('ab'))
+        >>> kdf3.rename(index=str_lower)  # doctest: +NORMALIZE_WHITESPACE
+             a  b
+        x a  1  2
+          b  3  4
+        y c  5  6
+          d  7  8
+        """
+
+        def gen_mapper_fn(mapper):
+            if isinstance(mapper, dict):
+                if len(mapper) == 0:
+                    if errors == 'raise':
+                        raise KeyError('Index include label which is not in the `mapper`.')
+                    else:
+                        return DataFrame(self._internal)
+
+                type_set = set(map(lambda x: type(x), mapper.values()))
+                if len(type_set) > 1:
+                    raise ValueError("Mapper dict should have the same value type.")
+                spark_return_type = as_spark_type(list(type_set)[0])
+
+                def mapper_fn(x):
+                    if x in mapper:
+                        return mapper[x]
+                    else:
+                        if errors == 'raise':
+                            raise KeyError('Index include value which is not in the `mapper`')
+                        return x
+            elif callable(mapper):
+                spark_return_type = _infer_return_type(mapper).tpe
+
+                def mapper_fn(x):
+                    return mapper(x)
+            else:
+                raise ValueError("`mapper` or `index` or `columns` should be "
+                                 "either dict-like or function type.")
+            return mapper_fn, spark_return_type
+
+        index_mapper_fn = None
+        index_mapper_ret_stype = None
+        columns_mapper_fn = None
+
+        if mapper:
+            if axis == 'index' or axis == 0:
+                index_mapper_fn, index_mapper_ret_stype = gen_mapper_fn(mapper)
+            elif axis == 'columns' or axis == 1:
+                columns_mapper_fn, columns_mapper_ret_stype = gen_mapper_fn(mapper)
+            else:
+                raise ValueError("argument axis should be either the axis name "
+                                 "(‘index’, ‘columns’) or number (0, 1)")
+        else:
+            if index:
+                index_mapper_fn, index_mapper_ret_stype = gen_mapper_fn(index)
+            if columns:
+                columns_mapper_fn, _ = gen_mapper_fn(columns)
+
+            if not index and not columns:
+                raise ValueError("Either `index` or `columns` should be provided.")
+
+        internal = self._internal
+        if index_mapper_fn:
+            # Rename index labels. If `level` is None, rename all index columns; otherwise only
+            # rename the index column of the corresponding level.
+            # We implement this by transforming the underlying Spark DataFrame.
+            # Example:
+            # Suppose the kdf index columns in the underlying Spark DataFrame are "index_0"
+            # and "index_1". Renaming level-0 index labels does:
+            #   ``kdf._sdf.withColumn("index_0", mapper_fn_udf(col("index_0"))``
+            # Renaming all index labels (`level` is None) does:
+            #   ```
+            #   kdf._sdf.withColumn("index_0", mapper_fn_udf(col("index_0"))
+            #           .withColumn("index_1", mapper_fn_udf(col("index_1"))
+            #   ```
+
+            index_columns = internal.index_columns
+            num_indices = len(index_columns)
+            if level:
+                if level < 0 or level >= num_indices:
+                    raise ValueError("level should be an integer between [0, num_indices)")
+
+            def gen_new_index_column(level):
+                index_col_name = index_columns[level]
+
+                index_mapper_udf = pandas_udf(lambda s: s.map(index_mapper_fn),
+                                              returnType=index_mapper_ret_stype)
+                return index_mapper_udf(scol_for(internal.sdf, index_col_name))
+
+            sdf = internal.sdf
+            if level is None:
+                for i in range(num_indices):
+                    sdf = sdf.withColumn(index_columns[i], gen_new_index_column(i))
+            else:
+                sdf = sdf.withColumn(index_columns[level], gen_new_index_column(level))
+            internal = internal.copy(sdf=sdf)
+        if columns_mapper_fn:
+            # Rename column names.
+            # This modifies `_internal._column_index` and renames the columns of the underlying
+            # Spark DataFrame to match `_internal._column_index`.
+            if level:
+                if level < 0 or level >= internal.column_index_level:
+                    raise ValueError("level should be an integer between [0, column_index_level)")
+
+            def gen_new_column_index_entry(column_index_entry):
+                if isinstance(column_index_entry, tuple):
+                    if level is None:
+                        # rename all level columns
+                        return tuple(map(columns_mapper_fn, column_index_entry))
+                    else:
+                        # only rename the specified level column
+                        entry_list = list(column_index_entry)
+                        entry_list[level] = columns_mapper_fn(entry_list[level])
+                        return tuple(entry_list)
+                else:
+                    return columns_mapper_fn(column_index_entry)
+
+            new_column_index = list(map(gen_new_column_index_entry, internal.column_index))
+
+            if internal.column_index_level == 1:
+                new_data_columns = [col[0] for col in new_column_index]
+            else:
+                new_data_columns = [str(col) for col in new_column_index]
+            new_data_scols = [scol_for(internal.sdf, old_col_name).alias(new_col_name)
+                              for old_col_name, new_col_name
+                              in zip(internal.data_columns, new_data_columns)]
+            sdf = internal.sdf.select(*(internal.index_scols + new_data_scols))
+            internal = internal.copy(sdf=sdf, column_index=new_column_index,
+                                     data_columns=new_data_columns)
+        if inplace:
+            self._internal = internal
+            return self
+        else:
+            return DataFrame(internal)
+
     def _get_from_multiindex_column(self, key):
         """ Select columns from multi-index columns.
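To make the index-renaming strategy described in the in-code comments concrete, here is a standalone sketch (not the Koalas code path itself) of applying a mapper through a scalar `pandas_udf` with `withColumn`. The index column name `__index_level_0__`, the sample data, and the session setup are assumptions for illustration; it requires pyspark with pyarrow:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

spark = SparkSession.builder.getOrCreate()

# A toy Spark DataFrame standing in for kdf._sdf, with one hypothetical index column.
sdf = spark.createDataFrame([(0, 1), (1, 2), (2, 3)], ["__index_level_0__", "A"])

# Wrap a plain Python mapper into a scalar pandas_udf, analogous to gen_new_index_column.
mapper_fn_udf = pandas_udf(lambda s: s.map(lambda x: x * 10), returnType=LongType())

# "Rename" the index labels by replacing the index column with the mapped values.
sdf = sdf.withColumn("__index_level_0__", mapper_fn_udf(col("__index_level_0__")))
sdf.show()
```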

databricks/koalas/missing/frame.py

Lines changed: 0 additions & 1 deletion
@@ -80,7 +80,6 @@ class _MissingPandasLikeDataFrame(object):
     quantile = unsupported_function('quantile')
     query = unsupported_function('query')
     reindex_like = unsupported_function('reindex_like')
-    rename = unsupported_function('rename')
     rename_axis = unsupported_function('rename_axis')
     reorder_levels = unsupported_function('reorder_levels')
     resample = unsupported_function('resample')

databricks/koalas/tests/test_dataframe.py

Lines changed: 62 additions & 0 deletions
@@ -379,6 +379,68 @@ def test_rename_columns(self):
         self.assert_eq(kdf._internal.data_columns, ["('A', '0')", "('B', 1)"])
         self.assert_eq(kdf._internal.spark_df.columns, ["('A', '0')", "('B', 1)"])
 
+    def test_rename_dataframe(self):
+        kdf1 = ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
+        result_kdf = kdf1.rename(columns={"A": "a", "B": "b"})
+        self.assert_eq(result_kdf.columns, pd.Index(['a', 'b']))
+
+        result_kdf = kdf1.rename(index={1: 10, 2: 20})
+        self.assert_eq(result_kdf.index, pd.Index([0, 10, 20]))
+        self.assertTrue(kdf1 is not result_kdf,
+                        "expect return new dataframe when inplace argument is False")
+
+        result_kdf2 = result_kdf.rename(index={1: 10, 2: 20}, inplace=True)
+        self.assertTrue(result_kdf2 is result_kdf,
+                        "expect return the same dataframe when inplace argument is True")
+
+        def str_lower(s) -> str:
+            return str.lower(s)
+
+        result_kdf = kdf1.rename(str_lower, axis='columns')
+        self.assert_eq(result_kdf.columns, pd.Index(['a', 'b']))
+
+        def mul10(x) -> int:
+            return x * 10
+
+        result_kdf = kdf1.rename(mul10, axis='index')
+        self.assert_eq(result_kdf.index, pd.Index([0, 10, 20]))
+
+        result_kdf = kdf1.rename(columns=str_lower, index={1: 10, 2: 20})
+        self.assert_eq(result_kdf.columns, pd.Index(['a', 'b']))
+        self.assert_eq(result_kdf.index, pd.Index([0, 10, 20]))
+
+        idx = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B'), ('Y', 'C'), ('Y', 'D')])
+        kdf2 = ks.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
+
+        result_kdf = kdf2.rename(columns=str_lower)
+        self.assert_eq(result_kdf.columns,
+                       pd.MultiIndex.from_tuples([('x', 'a'), ('x', 'b'), ('y', 'c'), ('y', 'd')]))
+
+        result_kdf = kdf2.rename(columns=str_lower, level=0)
+        self.assert_eq(result_kdf.columns,
+                       pd.MultiIndex.from_tuples([('x', 'A'), ('x', 'B'), ('y', 'C'), ('y', 'D')]))
+
+        result_kdf = kdf2.rename(columns=str_lower, level=1)
+        self.assert_eq(result_kdf.columns,
+                       pd.MultiIndex.from_tuples([('X', 'a'), ('X', 'b'), ('Y', 'c'), ('Y', 'd')]))
+
+        kdf3 = ks.DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]], index=idx, columns=list('ab'))
+
+        # For Spark 2.3, disable arrow optimization, because Koalas multi-index does not
+        # support arrow optimization in Spark 2.3.
+
+        result_kdf = kdf3.rename(index=str_lower)
+        self.assert_eq(result_kdf.index,
+                       pd.MultiIndex.from_tuples([('x', 'a'), ('x', 'b'), ('y', 'c'), ('y', 'd')]))
+
+        result_kdf = kdf3.rename(index=str_lower, level=0)
+        self.assert_eq(result_kdf.index,
+                       pd.MultiIndex.from_tuples([('x', 'A'), ('x', 'B'), ('y', 'C'), ('y', 'D')]))
+
+        result_kdf = kdf3.rename(index=str_lower, level=1)
+        self.assert_eq(result_kdf.index,
+                       pd.MultiIndex.from_tuples([('X', 'a'), ('X', 'b'), ('Y', 'c'), ('Y', 'd')]))
+
     def test_dot_in_column_name(self):
         self.assert_eq(
             ks.DataFrame(ks.range(1)._sdf.selectExpr("1 as `a.b`"))['a.b'],
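Not covered by the tests above: a hedged sketch of the `errors='raise'` behavior documented in the new docstring (column renames raise `KeyError`; per the commit message, index renames may surface a `SparkException` instead). This again assumes a working `databricks.koalas` installation:

```python
import databricks.koalas as ks

kdf = ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

try:
    # "C" does not cover the existing labels "A" and "B", so with errors='raise'
    # the dict-like column mapper is documented to raise KeyError.
    kdf.rename(columns={"C": "c"}, errors='raise')
except KeyError as exc:
    print("rename failed:", exc)
```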
