Skip to content

Support concatenating the tensor adding offset. #1846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
66 changes: 66 additions & 0 deletions elasticdl_preprocessing/layers/concatenate_with_offset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import tensorflow as tf


class ConcatenateWithOffset(tf.keras.layers.Concatenate):
"""Layer that add offset to each id tensor in the input list and
then concatenate these tensors.

It takes as input a list of tensors and returns a single tensor.
Firstly, it will add an offset in offsets for each tensor in inputs.
Then concatenate them to a single tensor. The tensor in inputs
must have the same type, `Tensor` or `RaggedTensor` or `SparseTensor` and
the same shape.

Example :
```python
a1 = tf.constant([[1], [1], [1]])
a2 = tf.constant([[2], [2], [2]])
offsets = [0, 10]
layer = ConcatenateWithOffset(offsets=offsets, axis=1)
layer([a1, a2])
[[ 1 12]
[ 1 12]
[ 1 12]]
```

Arguments:
offsets: numeric list to add
axis: Axis along which to concatenate.
**kwargs: standard layer keyword arguments.
"""

def __init__(self, offsets, axis=-1):
super(ConcatenateWithOffset, self).__init__()
self.offsets = offsets
self.axis = axis

def call(self, inputs):
ids_with_offset = []
if len(self.offsets) != len(inputs):
raise ValueError(
"The offsets length is not equal to inputs length"
"the inputs are {}, offsets are {}".format(
inputs, self.offsets
)
)
for i, tensor in enumerate(inputs):
if isinstance(tensor, tf.SparseTensor):
ids_with_offset.append(
tf.SparseTensor(
indices=tensor.indices,
values=tensor.values + self.offsets[i],
dense_shape=tensor.dense_shape,
)
)
else:
ids_with_offset.append(tensor + self.offsets[i])

if isinstance(ids_with_offset[0], tf.SparseTensor):
result = tf.sparse.concat(
axis=self.axis, sp_inputs=ids_with_offset
)
else:
result = tf.keras.layers.concatenate(
ids_with_offset, axis=self.axis
)
return result
40 changes: 40 additions & 0 deletions elasticdl_preprocessing/tests/concatenate_with_offset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

import numpy as np
import tensorflow as tf

from elasticdl_preprocessing.layers.concatenate_with_offset import (
ConcatenateWithOffset,
)
from elasticdl_preprocessing.tests.test_utils import (
ragged_tensor_equal,
sparse_tensor_equal,
)


class ConcatenateWithOffsetTest(unittest.TestCase):
def test_concatenate_with_offset(self):
tensor_1 = tf.constant([[1], [1], [1]])
tensor_2 = tf.constant([[2], [2], [2]])
offsets = [0, 10]
concat_layer = ConcatenateWithOffset(offsets=offsets, axis=1)

output = concat_layer([tensor_1, tensor_2])
expected_out = np.array([[1, 12], [1, 12], [1, 12]])
self.assertTrue(np.array_equal(output.numpy(), expected_out))

ragged_tensor_1 = tf.ragged.constant([[1], [], [1]], dtype=tf.int64)
ragged_tensor_2 = tf.ragged.constant([[2], [2], []], dtype=tf.int64)
output = concat_layer([ragged_tensor_1, ragged_tensor_2])
expected_out = tf.ragged.constant([[1, 12], [12], [1]], dtype=tf.int64)
self.assertTrue(ragged_tensor_equal(output, expected_out))

sparse_tensor_1 = ragged_tensor_1.to_sparse()
sparse_tensor_2 = ragged_tensor_2.to_sparse()
output = concat_layer([sparse_tensor_1, sparse_tensor_2])
expected_out = tf.SparseTensor(
indices=np.array([[0, 0], [0, 1], [1, 1], [2, 0]]),
values=np.array([1, 12, 12, 1]),
dense_shape=(3, 2),
)
self.assertTrue(sparse_tensor_equal(output, expected_out))
16 changes: 12 additions & 4 deletions elasticdl_preprocessing/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor


def sparse_tensor_equal(sp_a, sp_b):
if not isinstance(sp_a, tf.SparseTensor) or not isinstance(
sp_b, tf.SparseTensor
):
return False

if not np.array_equal(sp_a.dense_shape.numpy(), sp_b.dense_shape.numpy()):
return False

Expand All @@ -20,15 +24,19 @@ def sparse_tensor_equal(sp_a, sp_b):


def ragged_tensor_equal(rt_a, rt_b):
print(rt_a, rt_b)
if not isinstance(rt_a, tf.RaggedTensor) or not isinstance(
rt_b, tf.RaggedTensor
):
return False

if rt_a.shape.as_list() != rt_b.shape.as_list():
return False

for i in range(rt_a.shape[0]):
sub_rt_a = rt_a[i]
sub_rt_b = rt_b[i]
if ragged_tensor.is_ragged(sub_rt_a) and ragged_tensor.is_ragged(
sub_rt_b
if isinstance(sub_rt_a, tf.RaggedTensor) and isinstance(
sub_rt_b, tf.RaggedTensor
):
if not ragged_tensor_equal(sub_rt_a, sub_rt_b):
return False
Expand Down