|
52 | 52 | from databricks.koalas.missing.frame import _MissingPandasLikeDataFrame |
53 | 53 | from databricks.koalas.ml import corr |
54 | 54 | from databricks.koalas.utils import column_index_level, scol_for |
55 | | -from databricks.koalas.typedef import as_spark_type, as_python_type |
| 55 | +from databricks.koalas.typedef import _infer_return_type, as_spark_type, as_python_type |
56 | 56 | from databricks.koalas.plot import KoalasFramePlotMethods |
57 | 57 | from databricks.koalas.config import get_option |
58 | 58 |
|
@@ -6854,6 +6854,224 @@ def filter(self, items=None, like=None, regex=None, axis=None): |
6854 | 6854 | else: |
6855 | 6855 | raise TypeError("Must pass either `items`, `like`, or `regex`") |
6856 | 6856 |
|
def rename(self,
           mapper=None,
           index=None,
           columns=None,
           axis='index',
           inplace=False,
           level=None,
           errors='ignore'):
    """
    Alter axes labels.

    Function / dict values must be unique (1-to-1). Labels not contained in a dict
    will be left as-is (unless ``errors='raise'``, see below).

    Parameters
    ----------
    mapper : dict or function
        Transformation to apply to the labels of the axis selected by `axis`.
        Use either `mapper` and `axis`, or `index` and `columns`. When a non-empty
        `mapper` is given, `index` and `columns` are not consulted.
    index : dict or function
        Alternative to specifying axis ("mapper, axis=0" is equivalent to "index=mapper").
        A function used for index labels must carry a return type hint, because the
        hint determines the Spark return type of the UDF applied to the index column.
    columns : dict or function
        Alternative to specifying axis ("mapper, axis=1" is equivalent to "columns=mapper").
    axis : int or str, default 'index'
        Axis to target with mapper. Can be either the axis name ('index', 'columns') or
        number (0, 1).
    inplace : bool, default False
        If True, replace the internal frame of this DataFrame and return this
        DataFrame; otherwise return a new DataFrame.
    level : int, default None
        In case of a MultiIndex, only rename labels in the specified level. Only
        non-negative integer positions are validated here; the same `level` is applied
        to both the index and the columns when both are renamed.
    errors : {'ignore', 'raise'}, default 'ignore'
        If 'raise', a `KeyError` is raised when a label being transformed is not a key
        of the dict mapper, or when the only mapper supplied is an empty dict. For
        index labels this check runs inside the pandas UDF, i.e. when Spark evaluates
        the renamed index column. If 'ignore', labels missing from the dict are kept
        unchanged and an empty dict leaves the frame unchanged.

    Returns
    -------
    DataFrame with the renamed axis labels.

    Raises
    ------
    `KeyError`
        If "errors='raise'" and a label of the selected axis is not a key of the dict
        mapper.
    `ValueError`
        If no mapper is provided, the mapper is neither a dict nor a function, the dict
        values have different types, `axis` is invalid, or `level` is out of range.

    Examples
    --------
    >>> kdf1 = ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
    >>> kdf1.rename(columns={"A": "a", "B": "c"})  # doctest: +NORMALIZE_WHITESPACE
       a  c
    0  1  4
    1  2  5
    2  3  6

    >>> kdf1.rename(index={1: 10, 2: 20})  # doctest: +NORMALIZE_WHITESPACE
        A  B
    0   1  4
    10  2  5
    20  3  6

    >>> def str_lower(s) -> str:
    ...    return str.lower(s)
    >>> kdf1.rename(str_lower, axis='columns')  # doctest: +NORMALIZE_WHITESPACE
       a  b
    0  1  4
    1  2  5
    2  3  6

    >>> def mul10(x) -> int:
    ...    return x * 10
    >>> kdf1.rename(mul10, axis='index')  # doctest: +NORMALIZE_WHITESPACE
        A  B
    0   1  4
    10  2  5
    20  3  6

    >>> idx = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B'), ('Y', 'C'), ('Y', 'D')])
    >>> kdf2 = ks.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
    >>> kdf2.rename(columns=str_lower, level=0)  # doctest: +NORMALIZE_WHITESPACE
       x     y
       A  B  C  D
    0  1  2  3  4
    1  5  6  7  8

    >>> kdf3 = ks.DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]], index=idx, columns=list('ab'))
    >>> kdf3.rename(index=str_lower)  # doctest: +NORMALIZE_WHITESPACE
         a  b
    x a  1  2
      b  3  4
    y c  5  6
      d  7  8
    """

    def gen_mapper_fn(mapper, need_return_type):
        # Turn a (non-empty) dict or a callable into ``(label_fn, spark_return_type)``.
        # The Spark type is only consumed by the index UDF, so it is computed only
        # when `need_return_type` is True; column mappers get ``None`` instead.
        if isinstance(mapper, dict):
            type_set = set(map(type, mapper.values()))
            if len(type_set) > 1:
                raise ValueError("Mapper dict should have the same value type.")
            spark_return_type = (as_spark_type(next(iter(type_set)))
                                 if need_return_type else None)

            def mapper_fn(x):
                if x in mapper:
                    return mapper[x]
                if errors == 'raise':
                    raise KeyError('Index include value which is not in the `mapper`')
                return x
        elif callable(mapper):
            spark_return_type = (_infer_return_type(mapper).tpe
                                 if need_return_type else None)

            def mapper_fn(x):
                return mapper(x)
        else:
            raise ValueError("`mapper` or `index` or `columns` should be "
                             "either dict-like or function type.")
        return mapper_fn, spark_return_type

    index_mapper_fn = None
    index_mapper_ret_stype = None
    columns_mapper_fn = None

    if mapper:
        if axis in ('index', 0):
            index_mapper_fn, index_mapper_ret_stype = gen_mapper_fn(mapper, True)
        elif axis in ('columns', 1):
            columns_mapper_fn, _ = gen_mapper_fn(mapper, False)
        else:
            raise ValueError("argument axis should be either the axis name "
                             "(‘index’, ‘columns’) or number (0, 1)")
    else:
        if index:
            index_mapper_fn, index_mapper_ret_stype = gen_mapper_fn(index, True)
        if columns:
            columns_mapper_fn, _ = gen_mapper_fn(columns, False)

        if not index and not columns:
            # Nothing usable was supplied. An empty dict is still a deliberate
            # (empty) mapping, so tell it apart from "no mapper at all".
            if not any(isinstance(m, dict) for m in (mapper, index, columns)):
                raise ValueError("Either `index` or `columns` should be provided.")
            if errors == 'raise':
                raise KeyError('Index include label which is not in the `mapper`.')
            # errors == 'ignore': fall through with no mapper functions, which
            # makes the rename a no-op on the current internal frame.

    internal = self._internal
    if index_mapper_fn:
        # Rename index labels by rewriting the index column(s) of the underlying
        # Spark DataFrame through a pandas UDF. With `level` None every index column
        # is rewritten, otherwise only the column at that level, e.g. for level 0:
        #   sdf.withColumn("index_0", mapper_udf(col("index_0")))
        index_columns = internal.index_columns
        num_indices = len(index_columns)
        if level is not None:
            if level < 0 or level >= num_indices:
                raise ValueError("level should be an integer between [0, num_indices)")
            target_index_columns = [index_columns[level]]
        else:
            target_index_columns = index_columns

        # The UDF does not depend on the column it is applied to; build it once.
        index_mapper_udf = pandas_udf(lambda s: s.map(index_mapper_fn),
                                      returnType=index_mapper_ret_stype)

        sdf = internal.sdf
        for index_col_name in target_index_columns:
            sdf = sdf.withColumn(index_col_name,
                                 index_mapper_udf(scol_for(internal.sdf, index_col_name)))
        internal = internal.copy(sdf=sdf)

    if columns_mapper_fn:
        # Rename column labels: compute the new `column_index` entries, then alias
        # the Spark data columns so their names follow the new entries.
        if level is not None:
            if level < 0 or level >= internal.column_index_level:
                raise ValueError("level should be an integer between [0, column_index_level)")

        def gen_new_column_index_entry(column_index_entry):
            if isinstance(column_index_entry, tuple):
                if level is None:
                    # rename the labels of every level
                    return tuple(map(columns_mapper_fn, column_index_entry))
                # only rename the label at the requested level
                entry_list = list(column_index_entry)
                entry_list[level] = columns_mapper_fn(entry_list[level])
                return tuple(entry_list)
            return columns_mapper_fn(column_index_entry)

        new_column_index = list(map(gen_new_column_index_entry, internal.column_index))

        if internal.column_index_level == 1:
            new_data_columns = [col[0] for col in new_column_index]
        else:
            new_data_columns = [str(col) for col in new_column_index]
        new_data_scols = [scol_for(internal.sdf, old_col_name).alias(new_col_name)
                          for old_col_name, new_col_name
                          in zip(internal.data_columns, new_data_columns)]
        sdf = internal.sdf.select(*(internal.index_scols + new_data_scols))
        internal = internal.copy(sdf=sdf, column_index=new_column_index,
                                 data_columns=new_data_columns)

    if inplace:
        self._internal = internal
        return self
    return DataFrame(internal)
| 7074 | + |
6857 | 7075 | def _get_from_multiindex_column(self, key): |
6858 | 7076 | """ Select columns from multi-index columns. |
6859 | 7077 |
|
|
0 commit comments