diff --git a/elasticdl/python/elasticdl/feature_column/feature_column.py b/elasticdl/python/elasticdl/feature_column/feature_column.py
index aa3e1fa28..7a2c9156a 100644
--- a/elasticdl/python/elasticdl/feature_column/feature_column.py
+++ b/elasticdl/python/elasticdl/feature_column/feature_column.py
@@ -1,6 +1,8 @@
 import collections
+import itertools
 import math
 
+import tensorflow as tf
 from tensorflow.python.feature_column import feature_column as fc_old
 from tensorflow.python.feature_column import feature_column_v2 as fc_lib
 from tensorflow.python.framework import tensor_shape
@@ -185,3 +187,160 @@ def reset(self):
     @property
     def embedding_and_ids(self):
         return self._embedding_delegate.embedding_and_ids
+
+
+def concat_column(categorical_columns):
+    """Create a `ConcatColumn` that fuses several categorical columns.
+
+    The resulting column behaves like a single categorical column whose
+    id space is the concatenation of the id spaces of the given columns:
+    ids of the i-th column are shifted by the total bucket count of the
+    columns before it.
+
+    Args:
+        categorical_columns: A non-empty list of `CategoricalColumn`s.
+
+    Returns:
+        A `ConcatColumn`.
+
+    Raises:
+        ValueError: If `categorical_columns` is not a non-empty list of
+            `CategoricalColumn`s.
+    """
+    if not isinstance(categorical_columns, list):
+        raise ValueError("categorical_columns should be a list")
+
+    if not categorical_columns:
+        raise ValueError("categorical_columns shouldn't be empty")
+
+    for column in categorical_columns:
+        if not isinstance(column, fc_lib.CategoricalColumn):
+            raise ValueError(
+                "Items of categorical_columns should be CategoricalColumn."
+                " Given:{}".format(column)
+            )
+
+    return ConcatColumn(categorical_columns=tuple(categorical_columns))
+
+
+class ConcatColumn(
+    fc_lib.CategoricalColumn,
+    fc_old._CategoricalColumn,
+    # NOTE: the field spec must be a tuple of names; a bare
+    # ("categorical_columns") is just a string and only works by accident.
+    collections.namedtuple("ConcatColumn", ("categorical_columns",)),
+):
+    """A categorical column that concatenates the id spaces of its leaves."""
+
+    def __init__(self, **kwargs):
+        # Pre-compute, per leaf column, the offset to add to its ids so
+        # that the id ranges of all leaf columns are disjoint.
+        total_num_buckets = 0
+        leaf_column_num_buckets = []
+        for categorical_column in self.categorical_columns:
+            leaf_column_num_buckets.append(categorical_column.num_buckets)
+            total_num_buckets += categorical_column.num_buckets
+        self.accumulated_offsets = list(
+            itertools.accumulate([0] + leaf_column_num_buckets[:-1])
+        )
+        self.total_num_buckets = total_num_buckets
+
+    @property
+    def _is_v2_column(self):
+        for categorical_column in self.categorical_columns:
+            if not categorical_column._is_v2_column:
+                return False
+
+        return True
+
+    @property
+    def name(self):
+        feature_names = []
+        for categorical_column in self.categorical_columns:
+            feature_names.append(categorical_column.name)
+
+        return "_C_".join(sorted(feature_names))
+
+    @property
+    def num_buckets(self):
+        return self.total_num_buckets
+
+    @property
+    def _num_buckets(self):
+        return self.total_num_buckets
+
+    def transform_feature(self, transformation_cache, state_manager):
+        """Return the sparse concatenation of the offset leaf id tensors."""
+        feature_tensors = []
+        for categorical_column in self.categorical_columns:
+            ids_and_weights = categorical_column.get_sparse_tensors(
+                transformation_cache, state_manager
+            )
+            feature_tensors.append(ids_and_weights.id_tensor)
+
+        feature_tensors_with_offset = []
+        for index, offset in enumerate(self.accumulated_offsets):
+            feature_tensor = feature_tensors[index]
+            feature_tensor_with_offset = tf.SparseTensor(
+                indices=feature_tensor.indices,
+                values=tf.cast(
+                    tf.add(feature_tensor.values, offset), tf.int64
+                ),
+                dense_shape=feature_tensor.dense_shape,
+            )
+            feature_tensors_with_offset.append(feature_tensor_with_offset)
+
+        return tf.sparse.concat(
+            axis=-1, sp_inputs=feature_tensors_with_offset
+        )
+
+    def get_sparse_tensors(self, transformation_cache, state_manager):
+        return fc_lib.CategoricalColumn.IdWeightPair(
+            transformation_cache.get(self, state_manager), None
+        )
+
+    @property
+    def parents(self):
+        return list(self.categorical_columns)
+
+    @property
+    def parse_example_spec(self):
+        config = {}
+        for categorical_column in self.categorical_columns:
+            config.update(categorical_column.parse_example_spec)
+
+        return config
+
+    @property
+    def _parse_example_spec(self):
+        return self.parse_example_spec
+
+    def get_config(self):
+        from tensorflow.python.feature_column.serialization import (
+            serialize_feature_column,
+        )  # pylint: disable=g-import-not-at-top
+
+        config = dict(zip(self._fields, self))
+        config["categorical_columns"] = tuple(
+            [serialize_feature_column(fc) for fc in self.categorical_columns]
+        )
+
+        return config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None, columns_by_name=None):
+        """See `FeatureColumn` base class."""
+        from tensorflow.python.feature_column.serialization import (
+            deserialize_feature_column,
+        )  # pylint: disable=g-import-not-at-top
+
+        fc_lib._check_config_keys(config, cls._fields)
+        kwargs = fc_lib._standardize_and_copy_config(config)
+        kwargs["categorical_columns"] = tuple(
+            [
+                deserialize_feature_column(c, custom_objects, columns_by_name)
+                for c in config["categorical_columns"]
+            ]
+        )
+
+        return cls(**kwargs)
diff --git a/elasticdl/python/tests/feature_column_test.py b/elasticdl/python/tests/feature_column_test.py
index ba1c18627..073cec439 100644
--- a/elasticdl/python/tests/feature_column_test.py
+++ b/elasticdl/python/tests/feature_column_test.py
@@ -253,6 +253,37 @@ def _mock_gather_embedding(name, ids):
             np.isclose(grad_values.numpy(), expected_grads).all()
         )
 
+    def test_concat_column(self):
+        user_id = tf.feature_column.categorical_column_with_identity(
+            "user_id", num_buckets=32
+        )
+
+        item_id = tf.feature_column.categorical_column_with_identity(
+            "item_id", num_buckets=128
+        )
+
+        item_id_user_id_concat = feature_column.concat_column(
+            [user_id, item_id]
+        )
+
+        concat_indicator = tf.feature_column.indicator_column(
+            item_id_user_id_concat
+        )
+
+        output = call_feature_columns(
+            [concat_indicator], {"user_id": [10, 20], "item_id": [1, 120]},
+        )
+
+        # user_id comes first in the concat, so its ids keep offset 0 and
+        # item_id ids are shifted by user_id's 32 buckets; depth = 32 + 128.
+        expected_output = tf.one_hot(indices=[10, 20], depth=160) + tf.one_hot(
+            indices=[1 + 32, 120 + 32], depth=160
+        )
+
+        self.assertTrue(
+            np.array_equal(output.numpy(), expected_output.numpy())
+        )
+
 
 if __name__ == "__main__":
     unittest.main()