diff --git a/elasticdl_preprocessing/layers/hashing.py b/elasticdl_preprocessing/layers/hashing.py
new file mode 100644
index 000000000..23ddf78cc
--- /dev/null
+++ b/elasticdl_preprocessing/layers/hashing.py
@@ -0,0 +1,87 @@
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+
+
+class Hashing(tf.keras.layers.Layer):
+    """Distribute categorical feature values into a finite number of buckets
+    by hashing.
+
+    This layer converts a sequence of int or string to a sequence of int.
+    For string input, output_id = Hash(input_feature_string) % num_bins.
+    For int input, the layer first converts the value to string and then
+    applies the same formula. TensorFlow 2.2 introduces
+    `tf.keras.layers.preprocessing.Hashing`, but it has not been released
+    yet, so this layer is a simple temporary version of it:
+    https://github.com/tensorflow/tensorflow/blob/r2.2/tensorflow/python/keras/layers/preprocessing/hashing.py
+
+    Note that this layer requires a TensorFlow version greater than 2.0.0.
+
+    Example:
+    ```python
+    layer = Hashing(num_bins=3)
+    inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+    layer(inp)
+    ```
+    The output will be `[[1], [0], [1], [1], [2]]`.
+
+    Arguments:
+        num_bins: Number of hash bins.
+        **kwargs: Keyword arguments to construct a layer.
+
+    Input: A string, int32 or int64 `tf.Tensor`,
+        `tf.SparseTensor` or `tf.RaggedTensor`.
+
+    Output: An int64 tensor with the same shape as the input.
+
+    """
+
+    def __init__(self, num_bins, **kwargs):
+        if num_bins is None or num_bins <= 0:
+            raise ValueError(
+                "`num_bins` cannot be `None` or a non-positive value."
+            )
+        super(Hashing, self).__init__(**kwargs)
+        self.num_bins = num_bins
+        self._supports_ragged_inputs = True
+
+    def call(self, inputs):
+        # Convert integer inputs to strings before hashing.
+        if inputs.dtype.is_integer:
+            if isinstance(inputs, tf.SparseTensor):
+                inputs = tf.SparseTensor(
+                    indices=inputs.indices,
+                    values=tf.as_string(inputs.values),
+                    dense_shape=inputs.dense_shape,
+                )
+            else:
+                inputs = tf.as_string(inputs)
+        if isinstance(inputs, tf.RaggedTensor):
+            return tf.ragged.map_flat_values(
+                tf.strings.to_hash_bucket_fast,
+                inputs,
+                num_buckets=self.num_bins,
+                name="hash",
+            )
+        elif isinstance(inputs, tf.SparseTensor):
+            sparse_values = inputs.values
+            sparse_hashed_values = tf.strings.to_hash_bucket_fast(
+                sparse_values, self.num_bins, name="hash"
+            )
+            return tf.SparseTensor(
+                indices=inputs.indices,
+                values=sparse_hashed_values,
+                dense_shape=inputs.dense_shape,
+            )
+        else:
+            return tf.strings.to_hash_bucket_fast(
+                inputs, self.num_bins, name="hash"
+            )
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = {"num_bins": self.num_bins}
+        base_config = super(Hashing, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
diff --git a/elasticdl_preprocessing/tests/hashing_test.py b/elasticdl_preprocessing/tests/hashing_test.py
new file mode 100644
index 000000000..fa9b8cb33
--- /dev/null
+++ b/elasticdl_preprocessing/tests/hashing_test.py
@@ -0,0 +1,31 @@
+import unittest
+
+import numpy as np
+import tensorflow as tf
+
+from elasticdl_preprocessing.layers.hashing import Hashing
+from elasticdl_preprocessing.tests.test_utils import (
+    ragged_tensor_equal,
+    sparse_tensor_equal,
+)
+
+
+class HashingTest(unittest.TestCase):
+    def test_hashing(self):
+        hash_layer = Hashing(num_bins=3)
+        inp = np.asarray([["A"], ["B"], ["C"], ["D"], ["E"]])
+        hash_out = hash_layer(inp)
+        expected_out = np.array([[1], [0], [1], [1], [2]])
+        self.assertTrue(np.array_equal(hash_out.numpy(), expected_out))
+
+        ragged_in = tf.ragged.constant([["A", "B"], ["C", "D"], ["E"], []])
+        hash_out = hash_layer(ragged_in)
+        expected_ragged_out = tf.ragged.constant(
+            [[1, 0], [1, 1], [2], []], dtype=tf.int64
+        )
+        self.assertTrue(ragged_tensor_equal(hash_out, expected_ragged_out))
+
+        sparse_in = ragged_in.to_sparse()
+        hash_out = hash_layer(sparse_in)
+        expected_sparse_out = expected_ragged_out.to_sparse()
+        self.assertTrue(sparse_tensor_equal(hash_out, expected_sparse_out))
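Usage sketch (not part of the diff): because the layer maps strings or ints to bucket ids without building a vocabulary, it can feed an embedding directly inside a Keras functional model. The bucket count of 16, the `Embedding`/`Dense` head, and the `"category"` input name below are illustrative assumptions, not part of this change.

```python
import numpy as np
import tensorflow as tf

from elasticdl_preprocessing.layers.hashing import Hashing

# A raw string feature enters the model, is hashed into 16 buckets,
# and the bucket id is looked up in a trainable embedding table.
inputs = tf.keras.Input(shape=(1,), dtype=tf.string, name="category")
hashed = Hashing(num_bins=16)(inputs)  # int64 bucket ids, shape (None, 1)
embedded = tf.keras.layers.Embedding(input_dim=16, output_dim=4)(hashed)
outputs = tf.keras.layers.Dense(1)(tf.keras.layers.Flatten()(embedded))
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Each row holds one raw categorical value.
model.predict(np.asarray([["A"], ["B"], ["C"]]))
```

The trade-off is the usual one for feature hashing: no vocabulary needs to be collected ahead of time, but distinct values may collide in the same bucket, so `num_bins` should be chosen with the feature's cardinality in mind.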