Skip to content

Commit bc36a42

Browse files
Add concat_column in ElasticDL feature column (#1719)
* Add concat_column in elasticdl feature column
* Update the output type of concat_column from tf.int32 to tf.int64
* Add the placeholder for concat_column test case
* Add test case for concat_column
* Add the syntax example for wide and deep model
1 parent 4b784ee commit bc36a42

File tree

2 files changed

+163
-0
lines changed

2 files changed

+163
-0
lines changed

elasticdl/python/elasticdl/feature_column/feature_column.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import collections
2+
import itertools
23
import math
34

5+
import tensorflow as tf
46
from tensorflow.python.feature_column import feature_column as fc_old
57
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
68
from tensorflow.python.framework import tensor_shape
@@ -185,3 +187,135 @@ def reset(self):
185187
@property
def embedding_and_ids(self):
    """Expose `embedding_and_ids` of the wrapped embedding delegate."""
    delegate = self._embedding_delegate
    return delegate.embedding_and_ids
190+
191+
192+
def concat_column(categorical_columns):
    """Build a `ConcatColumn` from a list of categorical columns.

    Args:
        categorical_columns: A non-empty list whose items are all
            `fc_lib.CategoricalColumn` instances.

    Returns:
        A `ConcatColumn` holding the given columns as a tuple, in the
        order in which they were passed.

    Raises:
        ValueError: If `categorical_columns` is not a list, is empty, or
            contains an item that is not a `CategoricalColumn`.
    """
    is_list = isinstance(categorical_columns, list)
    if not is_list:
        raise ValueError("categorical_columns should be a list")

    if not categorical_columns:
        raise ValueError("categorical_columns shouldn't be empty")

    for item in categorical_columns:
        if isinstance(item, fc_lib.CategoricalColumn):
            continue
        # Report the first offending item.
        raise ValueError(
            "Items of categorical_columns should be CategoricalColumn."
            " Given:{}".format(item)
        )

    leaf_columns = tuple(categorical_columns)
    return ConcatColumn(categorical_columns=leaf_columns)
207+
208+
209+
class ConcatColumn(
    fc_lib.CategoricalColumn,
    fc_old._CategoricalColumn,
    collections.namedtuple("ConcatColumn", ("categorical_columns",)),
):
    """Categorical column that concatenates the ids of several columns.

    The ids produced by the i-th wrapped column are shifted by the summed
    `num_buckets` of all columns that precede it, so `num_buckets` of this
    column is the sum of the wrapped columns' `num_buckets`.

    Fields:
        categorical_columns: The wrapped categorical columns, in the order
            that determines their id offsets.

    Attributes set at construction time:
        accumulated_offsets: List with one id offset per wrapped column.
        total_num_buckets: Sum of `num_buckets` over the wrapped columns.
    """

    def __init__(self, *args, **kwargs):
        # The namedtuple field is populated by `__new__`. The constructor
        # arguments are accepted here (positionally or by keyword) only so
        # that both call styles supported by `__new__` work; they are not
        # read again.
        leaf_column_num_buckets = [
            column.num_buckets for column in self.categorical_columns
        ]
        # Offset of column i is the bucket count of columns 0..i-1.
        self.accumulated_offsets = list(
            itertools.accumulate([0] + leaf_column_num_buckets[:-1])
        )
        self.total_num_buckets = sum(leaf_column_num_buckets)

    @property
    def _is_v2_column(self):
        """True only if every wrapped column reports `_is_v2_column`."""
        return all(
            column._is_v2_column for column in self.categorical_columns
        )

    @property
    def name(self):
        """Names of the wrapped columns, sorted and joined with "_C_"."""
        # NOTE(review): the name is order-independent because of the sort,
        # while the id offsets depend on the column order -- confirm that
        # two differently ordered concat columns are never used together.
        feature_names = [
            column.name for column in self.categorical_columns
        ]
        return "_C_".join(sorted(feature_names))

    @property
    def num_buckets(self):
        """Total number of buckets over all wrapped columns."""
        return self.total_num_buckets

    @property
    def _num_buckets(self):
        """Same value as `num_buckets`, under the underscore-prefixed name."""
        return self.total_num_buckets

    def transform_feature(self, transformation_cache, state_manager):
        """Return one sparse id tensor covering all wrapped columns.

        Each wrapped column's id tensor is shifted by that column's offset
        and the shifted tensors are concatenated along the last axis.

        Args:
            transformation_cache: Passed through to each wrapped column's
                `get_sparse_tensors`.
            state_manager: Passed through to each wrapped column's
                `get_sparse_tensors`.

        Returns:
            The result of `tf.sparse.concat` over the shifted id tensors;
            its values are int64.
        """
        shifted_id_tensors = []
        for column, offset in zip(
            self.categorical_columns, self.accumulated_offsets
        ):
            id_tensor = column.get_sparse_tensors(
                transformation_cache, state_manager
            ).id_tensor
            # Cast before adding so the offset arithmetic always happens
            # in int64, even when the wrapped column yields narrower ids.
            shifted_values = tf.add(
                tf.cast(id_tensor.values, tf.int64), offset
            )
            shifted_id_tensors.append(
                tf.SparseTensor(
                    indices=id_tensor.indices,
                    values=shifted_values,
                    dense_shape=id_tensor.dense_shape,
                )
            )

        return tf.sparse.concat(axis=-1, sp_inputs=shifted_id_tensors)

    def get_sparse_tensors(self, transformation_cache, state_manager):
        """Return an `IdWeightPair` with the concatenated ids.

        The weight tensor is always `None`; weights of the wrapped columns
        are not propagated.
        """
        return fc_lib.CategoricalColumn.IdWeightPair(
            transformation_cache.get(self, state_manager), None
        )

    @property
    def parents(self):
        """The wrapped columns, as a new list."""
        return list(self.categorical_columns)

    @property
    def parse_example_spec(self):
        """Merged `parse_example_spec` of all wrapped columns.

        Entries of later columns overwrite entries of earlier ones that
        use the same key.
        """
        config = {}
        for column in self.categorical_columns:
            config.update(column.parse_example_spec)

        return config

    @property
    def _parse_example_spec(self):
        """Same value as `parse_example_spec`."""
        return self.parse_example_spec

    def get_config(self):
        """Return a config dict with every wrapped column serialized."""
        from tensorflow.python.feature_column.serialization import (
            serialize_feature_column,
        )  # pylint: disable=g-import-not-at-top

        config = dict(zip(self._fields, self))
        config["categorical_columns"] = tuple(
            serialize_feature_column(column)
            for column in self.categorical_columns
        )

        return config

    @classmethod
    def from_config(cls, config, custom_objects=None, columns_by_name=None):
        """Rebuild a `ConcatColumn` from the dict made by `get_config`.

        Args:
            config: Dict whose keys are checked against `cls._fields`.
            custom_objects: Forwarded to `deserialize_feature_column`.
            columns_by_name: Forwarded to `deserialize_feature_column`.

        Returns:
            A new instance whose wrapped columns are deserialized from
            `config["categorical_columns"]`.
        """
        from tensorflow.python.feature_column.serialization import (
            deserialize_feature_column,
        )  # pylint: disable=g-import-not-at-top

        fc_lib._check_config_keys(config, cls._fields)
        kwargs = fc_lib._standardize_and_copy_config(config)
        kwargs["categorical_columns"] = tuple(
            deserialize_feature_column(c, custom_objects, columns_by_name)
            for c in config["categorical_columns"]
        )

        return cls(**kwargs)

elasticdl/python/tests/feature_column_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,35 @@ def _mock_gather_embedding(name, ids):
253253
np.isclose(grad_values.numpy(), expected_grads).all()
254254
)
255255

256+
def test_concat_column(self):
    """Check the indicator output of a two-column `concat_column`.

    `user_id` (32 buckets) is listed first, so the `item_id` ids are
    expected to appear shifted by 32 in the 160-wide output.
    """
    user_buckets = 32
    item_buckets = 128
    total_buckets = user_buckets + item_buckets

    user_id_column = tf.feature_column.categorical_column_with_identity(
        "user_id", num_buckets=user_buckets
    )
    item_id_column = tf.feature_column.categorical_column_with_identity(
        "item_id", num_buckets=item_buckets
    )
    concatenated = feature_column.concat_column(
        [user_id_column, item_id_column]
    )
    indicator = tf.feature_column.indicator_column(concatenated)

    features = {"user_id": [10, 20], "item_id": [1, 120]}
    actual = call_feature_columns([indicator], features)

    user_one_hot = tf.one_hot(indices=[10, 20], depth=total_buckets)
    item_one_hot = tf.one_hot(
        indices=[1 + user_buckets, 120 + user_buckets],
        depth=total_buckets,
    )
    expected = user_one_hot + item_one_hot

    outputs_match = np.array_equal(actual.numpy(), expected.numpy())
    self.assertTrue(outputs_match)
284+
256285

257286
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)