Skip to content

Commit 3a57c42

Browse files
authored
Hashing preprocessing layer to convert categorical inputs to zero-based integer values (#1859)
* Add RoundIdentity layer to transform numeric values to integer ids * Fix the docstring * Config pre-commit with elasticdl_preprocessing * Fix the docstring by comments * Add hashing layer to preprocess categorical feature * Add a method to check ragged tensors equal * Add a unit test for hashing * fix docstring * Add description about Keras hashing layer * Remove unused import * Fix example * Fix example * fix docstring * Add hashing layer to convert categorical inputs to integer values * Remove index url * Use TensorFlow public API * Polish docstring * Polish docstring * Add note for TF version * Fix docstring by comments
1 parent 4a01631 commit 3a57c42

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import absolute_import, division, print_function
2+
3+
import tensorflow as tf
4+
5+
6+
class Hashing(tf.keras.layers.Layer):
    """Hash categorical feature values into a fixed number of buckets.

    Every element of the input is mapped to an integer id computed as
    output_id = Hash(input_feature_string) % num_bins, using
    `tf.strings.to_hash_bucket_fast`. Integer inputs are first rendered as
    strings with `tf.as_string` and then hashed by the same formula.

    This layer is a simple temporary version of the Hashing preprocessing
    layer that was developed for TensorFlow 2.2 but had not been released
    when this layer was written:
    https://github.com/tensorflow/tensorflow/blob/r2.2/tensorflow/python/keras/layers/preprocessing/hashing.py

    Note that the TensorFlow version with the layer must be greater than 2.0.0.

    Example:
    ```python
    layer = Hashing(num_bins=3)
    inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
    layer(inp)
    ```
    The output will be `[[1], [0], [1], [1], [2]]`

    Arguments:
        num_bins: Number of hash bins. Must be a positive value.
        **kwargs: Keyword arguments forwarded to the base Keras layer.

    Input: A string, int32 or int64 `tf.Tensor`,
        `tf.SparseTensor` or `tf.RaggedTensor`

    Output: An int64 tensor with the same shape as input.

    Raises:
        ValueError: If `num_bins` is `None` or not positive.
    """

    def __init__(self, num_bins, **kwargs):
        # Validate before touching the base class so that an invalid layer
        # is never partially constructed.
        if num_bins is None or num_bins <= 0:
            message = "`num_bins` cannot be `None` or non-positive values."
            raise ValueError(message)
        super(Hashing, self).__init__(**kwargs)
        self.num_bins = num_bins
        # Flag read by Keras to allow `tf.RaggedTensor` inputs to this layer.
        self._supports_ragged_inputs = True

    def call(self, inputs):
        """Return the hash bucket id of every element in `inputs`.

        The container type of the input (dense, sparse or ragged) is
        preserved; only the element values are replaced.
        """
        # Sparse inputs: only the stored values are hashed, the indices and
        # the dense shape are carried over unchanged.
        if isinstance(inputs, tf.SparseTensor):
            values = inputs.values
            if inputs.dtype.is_integer:
                values = tf.as_string(values)
            bucket_ids = tf.strings.to_hash_bucket_fast(
                values, self.num_bins, name="hash"
            )
            return tf.SparseTensor(
                indices=inputs.indices,
                values=bucket_ids,
                dense_shape=inputs.dense_shape,
            )

        # The hashing op works on strings, so integers are stringified first.
        text = tf.as_string(inputs) if inputs.dtype.is_integer else inputs

        # Ragged inputs: hash the flat values and keep the row partitioning.
        if isinstance(text, tf.RaggedTensor):
            return tf.ragged.map_flat_values(
                tf.strings.to_hash_bucket_fast,
                text,
                num_buckets=self.num_bins,
                name="hash",
            )

        # Dense inputs.
        return tf.strings.to_hash_bucket_fast(
            text, self.num_bins, name="hash"
        )

    def compute_output_shape(self, input_shape):
        """Hashing is element-wise, so the shape is returned unchanged."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with `num_bins`."""
        inherited = super(Hashing, self).get_config()
        return dict(inherited, num_bins=self.num_bins)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import unittest
2+
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
from elasticdl_preprocessing.layers.hashing import Hashing
7+
from elasticdl_preprocessing.tests.test_utils import (
8+
ragged_tensor_equal,
9+
sparse_tensor_equal,
10+
)
11+
12+
13+
class HashingTest(unittest.TestCase):
    """Checks the `Hashing` layer on dense, ragged and sparse string inputs."""

    def test_hashing(self):
        """The same five letters must hash to the same bins in every format."""
        layer = Hashing(num_bins=3)

        # Dense input: a column of single-letter strings.
        dense_in = np.asarray([["A"], ["B"], ["C"], ["D"], ["E"]])
        dense_expected = np.array([[1], [0], [1], [1], [2]])
        dense_out = layer(dense_in)
        self.assertTrue(np.array_equal(dense_out.numpy(), dense_expected))

        # Ragged input: the same letters in rows of varying length,
        # including an empty row.
        ragged_in = tf.ragged.constant([["A", "B"], ["C", "D"], ["E"], []])
        ragged_expected = tf.ragged.constant(
            [[1, 0], [1, 1], [2], []], dtype=tf.int64
        )
        ragged_out = layer(ragged_in)
        self.assertTrue(ragged_tensor_equal(ragged_out, ragged_expected))

        # Sparse input: derived from the ragged case, so the expected value
        # is the sparse form of the ragged expectation.
        sparse_in = ragged_in.to_sparse()
        sparse_expected = ragged_expected.to_sparse()
        sparse_out = layer(sparse_in)
        self.assertTrue(sparse_tensor_equal(sparse_out, sparse_expected))

0 commit comments

Comments
 (0)