Skip to content

Hashing preprocessing layer to convert categorical inputs to zero-based integer values #1859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Mar 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
be3c822
Add RoundIdentity layer to transform numeric values to integer ids
workingloong Mar 17, 2020
ea319da
Fix the docstring
workingloong Mar 17, 2020
69495b6
Config pre-commit with elasticdl_preprocessing
workingloong Mar 17, 2020
d082113
Merge branch 'develop' into add_round_identity
workingloong Mar 17, 2020
4a25f34
Fix the docstring by comments
workingloong Mar 19, 2020
a6789ad
Add hashing layer to preprocess categorical feature
workingloong Mar 19, 2020
7cfe9cd
Merge branch 'develop' into add_round_identity
workingloong Mar 19, 2020
9dca79c
Add a method to check ragged tensors equal
workingloong Mar 19, 2020
640c94f
Merge branch 'add_round_identity' into add_hashing
workingloong Mar 19, 2020
a724134
Add an unit test for hashing
workingloong Mar 19, 2020
72899b7
fix docstring
workingloong Mar 19, 2020
9061080
Add description about Keras hashing layer
workingloong Mar 19, 2020
1c39600
Remove unused import
workingloong Mar 19, 2020
2b81cd0
Fix example
workingloong Mar 19, 2020
fab901a
Fix example
workingloong Mar 19, 2020
b043ac5
fix docstring
workingloong Mar 19, 2020
0d8132c
Merge branch 'add_round_identity' into add_hashing
workingloong Mar 19, 2020
4a6bc26
Add hashing layer to convert catgorical inputs to integer values
workingloong Mar 19, 2020
c4b7655
Remove index url
workingloong Mar 19, 2020
6aa2270
Use TensorFlow public API
workingloong Mar 20, 2020
ab3376e
Merge branch 'develop' into add_hashing
workingloong Mar 20, 2020
2a1941d
Polish docstring
workingloong Mar 23, 2020
2077901
Polish docstring
workingloong Mar 23, 2020
cb55297
Add note for TF version
workingloong Mar 23, 2020
c4ffabe
Fix docstring by comments
workingloong Mar 25, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions elasticdl_preprocessing/layers/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import absolute_import, division, print_function

import tensorflow as tf


class Hashing(tf.keras.layers.Layer):
"""Distribute categorical feature values into a finite number of buckets
by hashing.

This layer converts a sequence of int or string to a sequence of int.
output_id = Hash(input_feature_string) % num_bins for string type input.
For int type input, the layer converts the value to string and then
processes it by the same formula. TensorFlow 2.2 has developed
`tf.keras.layers.preprocessing.Hashing` but not released it yet. So the
layer is a simple temporary version.
https://github.com/tensorflow/tensorflow/blob/r2.2/tensorflow/python/keras/layers/preprocessing/hashing.py

Note that the TensorFlow version with the layer must be greater than 2.0.0.

Example:
```python
layer = Hashing(num_bins=3)
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
layer(inp)
```
The output will be `[[1], [0], [1], [1], [2]]`

Arguments:
num_bins: Number of hash bins.
**kwargs: Keyword arguments to construct a layer.

Input: A string, int32 or int64 `tf.Tensor`,
`tf.SparseTensor` or `tf.RaggedTensor`

Output: An int64 tensor with the same shape as input.

"""

def __init__(self, num_bins, **kwargs):
if num_bins is None or num_bins <= 0:
raise ValueError(
"`num_bins` cannot be `None` or non-positive values."
)
super(Hashing, self).__init__(**kwargs)
self.num_bins = num_bins
self._supports_ragged_inputs = True

def call(self, inputs):
# Converts integer inputs to string.
if inputs.dtype.is_integer:
if isinstance(inputs, tf.SparseTensor):
inputs = tf.SparseTensor(
indices=inputs.indices,
values=tf.as_string(inputs.values),
dense_shape=inputs.dense_shape,
)
else:
inputs = tf.as_string(inputs)
if isinstance(inputs, tf.RaggedTensor):
return tf.ragged.map_flat_values(
tf.strings.to_hash_bucket_fast,
inputs,
num_buckets=self.num_bins,
name="hash",
)
elif isinstance(inputs, tf.SparseTensor):
sparse_values = inputs.values
sparse_hashed_values = tf.strings.to_hash_bucket_fast(
sparse_values, self.num_bins, name="hash"
)
return tf.SparseTensor(
indices=inputs.indices,
values=sparse_hashed_values,
dense_shape=inputs.dense_shape,
)
else:
return tf.strings.to_hash_bucket_fast(
inputs, self.num_bins, name="hash"
)

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = {"num_bins": self.num_bins}
base_config = super(Hashing, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
31 changes: 31 additions & 0 deletions elasticdl_preprocessing/tests/hashing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest

import numpy as np
import tensorflow as tf

from elasticdl_preprocessing.layers.hashing import Hashing
from elasticdl_preprocessing.tests.test_utils import (
ragged_tensor_equal,
sparse_tensor_equal,
)


class HashingTest(unittest.TestCase):
def test_hashing(self):
hash_layer = Hashing(num_bins=3)
inp = np.asarray([["A"], ["B"], ["C"], ["D"], ["E"]])
hash_out = hash_layer(inp)
expected_out = np.array([[1], [0], [1], [1], [2]])
self.assertTrue(np.array_equal(hash_out.numpy(), expected_out))

ragged_in = tf.ragged.constant([["A", "B"], ["C", "D"], ["E"], []])
hash_out = hash_layer(ragged_in)
expected_ragged_out = tf.ragged.constant(
[[1, 0], [1, 1], [2], []], dtype=tf.int64
)
self.assertTrue(ragged_tensor_equal(hash_out, expected_ragged_out))

sparse_in = ragged_in.to_sparse()
hash_out = hash_layer(sparse_in)
expected_sparse_out = expected_ragged_out.to_sparse()
self.assertTrue(sparse_tensor_equal(hash_out, expected_sparse_out))