-
Notifications
You must be signed in to change notification settings - Fork 116
Layer to convert Tensor to SparseTensor dropping ignore values #1860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
035bc94
01aa93a
17e837d
83b40f3
b9446c7
d071af5
278894d
4ef65e7
3a1da78
6c2b8df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import tensorflow as tf | ||
|
||
|
||
class ToSparse(tf.keras.layers.Layer):
    """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.

    If the input is already a `SparseTensor`, it is returned unchanged.

    Example :
    ```python
        layer = ToSparse()
        inp = tf.constant([["A", ""], ["B", "C"]], tf.string)
        out = layer.call(inp)
        # `out` is equivalent to:
        tf.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [1, 1]]),
            values=np.array(["A", "B", "C"]),
            dense_shape=(2, 2),
        )
    ```

    Arguments:
        ignore_value: Entries in inputs equal to this value will be
            absent from the output `SparseTensor`. If `None`, default value of
            inputs dtype will be used ('' for `str`, -1 for `int`).
        **kwargs: Keyword arguments forwarded to `tf.keras.layers.Layer`
            (for example `name`), so that a layer can be rebuilt from the
            dictionary returned by `get_config()`.

    Input shape: A numeric or string `Tensor` of shape
        `[batch_size, d1, ..., dm]`

    Output shape: A `SparseTensor` with the same shape as inputs
    """

    def __init__(self, ignore_value=None, **kwargs):
        super(ToSparse, self).__init__(**kwargs)
        # Stored exactly as given; `call` never overwrites it, so
        # `get_config` always reports the constructor argument.
        self.ignore_value = ignore_value

    def call(self, inputs):
        """Return `inputs` as a `SparseTensor` without the ignored cells.

        Arguments:
            inputs: A dense `Tensor` or a `SparseTensor`.

        Returns:
            `inputs` itself if it is already a `SparseTensor`; otherwise a
            `SparseTensor` holding every entry of `inputs` that differs from
            the ignore value, with the dense shape of `inputs`.

        Raises:
            ValueError: If `ignore_value` is `None` and the dtype of `inputs`
                is neither string nor integer, so no default is defined.
        """
        if isinstance(inputs, tf.SparseTensor):
            return inputs

        # Resolve the default into a local variable instead of mutating
        # `self.ignore_value`: the layer stays reusable across input dtypes
        # and its config stays a plain Python value.
        ignore_value = self.ignore_value
        if ignore_value is None:
            if inputs.dtype == tf.string:
                ignore_value = ""
            elif inputs.dtype.is_integer:
                ignore_value = -1
            else:
                raise ValueError(
                    "ToSparse has no default ignore_value for dtype %s; "
                    "pass ignore_value explicitly." % inputs.dtype.name
                )
        ignore_value = tf.cast(ignore_value, inputs.dtype)

        # Coordinates of every cell that must be kept.
        indices = tf.where(tf.not_equal(inputs, ignore_value))
        values = tf.gather_nd(inputs, indices)
        dense_shape = tf.shape(inputs, out_type=tf.int64)
        return tf.SparseTensor(
            indices=indices, values=values, dense_shape=dense_shape
        )

    def compute_output_shape(self, input_shape):
        """Return `input_shape`: the sparse output keeps the dense shape."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with `ignore_value`."""
        config = super(ToSparse, self).get_config()
        config["ignore_value"] = self.ignore_value
        return config
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from elasticdl_preprocessing.layers.to_sparse import ToSparse | ||
from elasticdl_preprocessing.tests.test_utils import sparse_tensor_equal | ||
|
||
|
||
class ToSparseTest(unittest.TestCase):
    def test_to_sparse(self):
        """Check `ToSparse` on a string input and on an int64 input.

        Each 2x2 input has its ignored cell ("" or -1) at position [0, 1],
        so both expected outputs share the same indices and dense shape.
        """
        scenarios = (
            (
                tf.constant([["A", ""], ["B", "C"]], tf.string),
                np.array(["A", "B", "C"]),
            ),
            (
                tf.constant([[12, -1], [45, 78]], tf.int64),
                np.array([12, 45, 78]),
            ),
        )
        for dense_input, kept_values in scenarios:
            # A fresh layer per scenario, as each case is independent.
            actual = ToSparse().call(dense_input)
            wanted = tf.SparseTensor(
                indices=np.array([[0, 0], [1, 0], [1, 1]]),
                values=kept_values,
                dense_shape=(2, 2),
            )
            self.assertTrue(sparse_tensor_equal(actual, wanted))
|
||
|
||
# Run the test cases in this file when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
Uh oh!
There was an error while loading. Please reload this page.