-
Notifications
You must be signed in to change notification settings - Fork 116
Add RoundIdentity layer to transform numeric values to integer ids #1842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
workingloong
merged 9 commits into
sql-machine-learning:develop
from
workingloong:add_round_identity
Mar 20, 2020
Merged
Changes from 4 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
be3c822
Add RoundIdentity layer to transform numeric values to integer ids
workingloong ea319da
Fix the docstring
workingloong 69495b6
Config pre-commit with elasticdl_preprocessing
workingloong d082113
Merge branch 'develop' into add_round_identity
workingloong 4a25f34
Fix the docstring by comments
workingloong 7cfe9cd
Merge branch 'develop' into add_round_identity
workingloong 9dca79c
Add a method to check ragged tensors equal
workingloong 1c39600
Remove unused import
workingloong fab901a
Fix example
workingloong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import tensorflow as tf | ||
from tensorflow.python.ops.ragged import ragged_functional_ops, ragged_tensor | ||
|
||
|
||
class RoundIdentity(tf.keras.layers.Layer): | ||
"""Implements numeric feature rounding with a max value. | ||
|
||
This layer transforms numeric inputs to integer output. It is a special | ||
case of bucketizing to bins. The max value in the layer is the number of | ||
bins. | ||
|
||
Example : | ||
```python | ||
layer = RoundIdentity(max_value=5) | ||
inp = np.asarray([[1.2], [1.6], [0.2], [3.1], [4.9]]) | ||
layer(inputs) | ||
[[1], [2], [0], [3], [5]] | ||
``` | ||
|
||
Arguments: | ||
num_buckets: Range of inputs and outputs is `[0, num_buckets)`. | ||
**kwargs: Keyword arguments to construct a layer. | ||
|
||
Input shape: A numeric tensor of shape | ||
workingloong marked this conversation as resolved.
Show resolved
Hide resolved
|
||
`[batch_size, d1, ..., dm]` | ||
|
||
Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` | ||
|
||
""" | ||
|
||
def __init__(self, num_buckets, default_value=0): | ||
super(RoundIdentity, self).__init__() | ||
self.num_buckets = tf.cast(num_buckets, tf.int64) | ||
self.default_value = tf.cast(default_value, tf.int64) | ||
|
||
def call(self, inputs): | ||
if isinstance(inputs, tf.SparseTensor): | ||
id_values = self._round_and_truncate(inputs.values) | ||
result = tf.SparseTensor( | ||
indices=inputs.indices, | ||
values=id_values, | ||
dense_shape=inputs.dense_shape, | ||
) | ||
elif ragged_tensor.is_ragged(inputs): | ||
result = ragged_functional_ops.map_flat_values( | ||
self._round_and_truncate, inputs | ||
) | ||
else: | ||
result = self._round_and_truncate(inputs) | ||
return tf.cast(result, tf.int64) | ||
|
||
def _round_and_truncate(self, values): | ||
values = tf.keras.backend.round(values) | ||
values = tf.cast(values, tf.int64) | ||
values = tf.where( | ||
tf.logical_or(values < 0, values > self.num_buckets), | ||
x=tf.fill(dims=tf.shape(values), value=self.default_value), | ||
y=values, | ||
) | ||
return values | ||
|
||
def compute_output_shape(self, input_shape): | ||
return input_shape | ||
|
||
def get_config(self): | ||
config = { | ||
"num_buckets": self.num_buckets, | ||
"default_value": self.default_value, | ||
} | ||
base_config = super(RoundIdentity, self).get_config() | ||
return dict(list(base_config.items()) + list(config.items())) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from elasticdl_preprocessing.layers.round_identity import RoundIdentity | ||
|
||
|
||
class RoundIdentityTest(unittest.TestCase): | ||
def test_round_indentity(self): | ||
round_identity = RoundIdentity(num_buckets=10) | ||
|
||
dense_input = tf.constant([[1.2], [1.6], [0.2], [3.1], [4.9]]) | ||
output = round_identity(dense_input) | ||
expected_out = np.array([[1], [2], [0], [3], [5]]) | ||
self.assertTrue(np.array_equal(output.numpy(), expected_out)) | ||
|
||
ragged_input = tf.ragged.constant([[1.1, 3.4], [0.5]]) | ||
ragged_output = round_identity(ragged_input) | ||
ragged_output_values = ragged_output.values.numpy() | ||
expected_out = np.array([1.0, 3.0, 0.0]) | ||
self.assertTrue(np.array_equal(ragged_output_values, expected_out)) | ||
|
||
sparse_input = ragged_input.to_sparse() | ||
sparse_output = round_identity(sparse_input) | ||
sparse_output_values = sparse_output.values | ||
self.assertTrue(np.array_equal(sparse_output_values, expected_out)) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.