Skip to content

Add concat_column in ElasticDL feature column #1719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions elasticdl/python/elasticdl/feature_column/feature_column.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import itertools
import math

import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.framework import tensor_shape
Expand Down Expand Up @@ -185,3 +187,135 @@ def reset(self):
@property
def embedding_and_ids(self):
return self._embedding_delegate.embedding_and_ids


def concat_column(categorical_columns):
    """Create a `ConcatColumn` that fuses several categorical columns.

    Args:
        categorical_columns: a non-empty list of
            `tf.feature_column` `CategoricalColumn` objects.

    Returns:
        A `ConcatColumn` whose id space is the disjoint union of the
        given columns' id spaces.

    Raises:
        ValueError: if the argument is not a list, is empty, or contains
            an item that is not a `CategoricalColumn`.
    """
    if not isinstance(categorical_columns, list):
        raise ValueError("categorical_columns should be a list")

    if not categorical_columns:
        raise ValueError("categorical_columns shouldn't be empty")

    # Reject the first item that is not a CategoricalColumn, if any.
    invalid = [
        column
        for column in categorical_columns
        if not isinstance(column, fc_lib.CategoricalColumn)
    ]
    if invalid:
        raise ValueError(
            "Items of categorical_columns should be CategoricalColumn."
            " Given:{}".format(invalid[0])
        )

    return ConcatColumn(categorical_columns=tuple(categorical_columns))


class ConcatColumn(
    fc_lib.CategoricalColumn,
    fc_old._CategoricalColumn,
    collections.namedtuple("ConcatColumn", ("categorical_columns")),
):
    """Categorical column that concatenates several categorical columns.

    Ids of the i-th leaf column are shifted by the total bucket count of
    all preceding columns, so each leaf column owns a disjoint id range
    and the concatenated column has the sum of all leaf bucket counts.

    NOTE: the ``categorical_columns`` field is assigned by the namedtuple
    base's ``__new__``; ``__init__`` only derives the offsets from it.
    """

    def __init__(self, **kwargs):
        # Calculate the offset tensor.
        # accumulated_offsets[i] is the sum of num_buckets of all columns
        # preceding column i (0 for the first column).
        total_num_buckets = 0
        leaf_column_num_buckets = []
        for categorical_column in self.categorical_columns:
            leaf_column_num_buckets.append(categorical_column.num_buckets)
            total_num_buckets += categorical_column.num_buckets
        self.accumulated_offsets = list(
            itertools.accumulate([0] + leaf_column_num_buckets[:-1])
        )
        self.total_num_buckets = total_num_buckets

    @property
    def _is_v2_column(self):
        # v2 only when every leaf column is itself a v2 column.
        for categorical_column in self.categorical_columns:
            if not categorical_column._is_v2_column:
                return False

        return True

    @property
    def name(self):
        """Joined leaf names, sorted so the name is order-independent."""
        feature_names = []
        for categorical_column in self.categorical_columns:
            feature_names.append(categorical_column.name)

        return "_C_".join(sorted(feature_names))

    @property
    def num_buckets(self):
        """Total number of buckets across all leaf columns."""
        return self.total_num_buckets

    @property
    def _num_buckets(self):
        # Legacy (v1 feature column) alias of num_buckets.
        return self.total_num_buckets

    def transform_feature(self, transformation_cache, state_manager):
        """Return one SparseTensor with offset-shifted leaf ids.

        Each leaf column's id tensor is fetched, its values are shifted
        by that column's accumulated offset, and the shifted tensors are
        concatenated along the last axis.
        """
        feature_tensors = []
        for categorical_column in self.categorical_columns:
            ids_and_weights = categorical_column.get_sparse_tensors(
                transformation_cache, state_manager
            )
            feature_tensors.append(ids_and_weights.id_tensor)

        # Shift each leaf tensor's ids into its own disjoint id range.
        feature_tensors_with_offset = []
        for index, offset in enumerate(self.accumulated_offsets):
            feature_tensor = feature_tensors[index]
            feature_tensor_with_offset = tf.SparseTensor(
                indices=feature_tensor.indices,
                values=tf.cast(
                    tf.add(feature_tensor.values, offset), tf.int64
                ),
                dense_shape=feature_tensor.dense_shape,
            )
            feature_tensors_with_offset.append(feature_tensor_with_offset)

        return tf.sparse.concat(axis=-1, sp_inputs=feature_tensors_with_offset)

    def get_sparse_tensors(self, transformation_cache, state_manager):
        """Return the transformed ids; this column carries no weights."""
        return fc_lib.CategoricalColumn.IdWeightPair(
            transformation_cache.get(self, state_manager), None
        )

    @property
    def parents(self):
        """The leaf categorical columns this column is built from."""
        return list(self.categorical_columns)

    @property
    def parse_example_spec(self):
        """Merged tf.Example parsing spec of all leaf columns."""
        config = {}
        for categorical_column in self.categorical_columns:
            config.update(categorical_column.parse_example_spec)

        return config

    @property
    def _parse_example_spec(self):
        # Legacy (v1 feature column) alias of parse_example_spec.
        return self.parse_example_spec

    def get_config(self):
        """Serialize this column, recursively serializing leaf columns."""
        from tensorflow.python.feature_column.serialization import (
            serialize_feature_column,
        )  # pylint: disable=g-import-not-at-top

        config = dict(zip(self._fields, self))
        config["categorical_columns"] = tuple(
            [serialize_feature_column(fc) for fc in self.categorical_columns]
        )

        return config

    @classmethod
    def from_config(cls, config, custom_objects=None, columns_by_name=None):
        """See 'FeatureColumn` base class."""
        from tensorflow.python.feature_column.serialization import (
            deserialize_feature_column,
        )  # pylint: disable=g-import-not-at-top

        fc_lib._check_config_keys(config, cls._fields)
        kwargs = fc_lib._standardize_and_copy_config(config)
        kwargs["categorical_columns"] = tuple(
            [
                deserialize_feature_column(c, custom_objects, columns_by_name)
                for c in config["categorical_columns"]
            ]
        )

        return cls(**kwargs)
29 changes: 29 additions & 0 deletions elasticdl/python/tests/feature_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,35 @@ def _mock_gather_embedding(name, ids):
np.isclose(grad_values.numpy(), expected_grads).all()
)

def test_concat_column(self):
    """Item ids should be shifted by the user-id bucket count."""
    user_id = tf.feature_column.categorical_column_with_identity(
        "user_id", num_buckets=32
    )
    item_id = tf.feature_column.categorical_column_with_identity(
        "item_id", num_buckets=128
    )

    concat = feature_column.concat_column([user_id, item_id])
    concat_indicator = tf.feature_column.indicator_column(concat)

    output = call_feature_columns(
        [concat_indicator], {"user_id": [10, 20], "item_id": [1, 120]},
    )

    # The concatenated id space has 32 + 128 = 160 buckets: user ids
    # keep their positions, item ids are offset by the 32 user buckets.
    depth = 32 + 128
    expected_output = tf.one_hot(indices=[10, 20], depth=depth) + tf.one_hot(
        indices=[1 + 32, 120 + 32], depth=depth
    )

    self.assertTrue(
        np.array_equal(output.numpy(), expected_output.numpy())
    )


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()