Skip to content

Support concatenating tensors after adding an offset to each. #1846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
65 changes: 65 additions & 0 deletions elasticdl_preprocessing/layers/concatenate_with_offset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import tensorflow as tf


class ConcatenateWithOffset(tf.keras.layers.Concatenate):
    """Layer that adds an offset to each input tensor and concatenates them.

    It takes a list of tensors as input and returns a single tensor.
    First, `offsets[i]` is added to the i-th tensor in `inputs`; then the
    shifted tensors are concatenated along `axis`. All tensors in `inputs`
    must be of the same kind (`Tensor`, `RaggedTensor` or `SparseTensor`)
    and have the same shape.

    Example :
    ```python
        a1 = tf.constant([[1], [1], [1]])
        a2 = tf.constant([[2], [2], [2]])
        offsets = [0, 10]
        layer = ConcatenateWithOffset(offsets=offsets, axis=1)
        layer([a1, a2])
        [[ 1 12]
         [ 1 12]
         [ 1 12]]
    ```

    Arguments:
        offsets: numeric list; `offsets[i]` is added to `inputs[i]`. Its
            length must equal the number of input tensors.
        axis: Axis along which to concatenate.
        **kwargs: standard layer keyword arguments, forwarded to
            `tf.keras.layers.Concatenate`.
    """

    def __init__(self, offsets, axis=-1, **kwargs):
        # Forward `axis` and the standard layer kwargs (name, dtype, ...)
        # to the parent so they are honoured instead of silently dropped.
        super(ConcatenateWithOffset, self).__init__(axis=axis, **kwargs)
        self.offsets = offsets

    def call(self, inputs):
        """Add the configured offsets to `inputs` and concatenate them.

        Arguments:
            inputs: list of tensors, one per entry in `self.offsets`.

        Returns:
            The concatenation along `self.axis` of the offset tensors. A
            `SparseTensor` is returned when the first input is sparse.

        Raises:
            ValueError: if the number of inputs differs from the number
                of offsets.
        """
        if len(self.offsets) != len(inputs):
            raise ValueError(
                "The offsets length is not equal to inputs length, "
                "the inputs are {}, offsets are {}".format(
                    inputs, self.offsets
                )
            )

        ids_with_offset = [
            self._add_offset(tensor, offset)
            for tensor, offset in zip(inputs, self.offsets)
        ]

        # The kind of the first tensor selects the concat op; inputs are
        # documented to all be of the same kind.
        if isinstance(ids_with_offset[0], tf.SparseTensor):
            return tf.sparse.concat(
                axis=self.axis, sp_inputs=ids_with_offset
            )
        return tf.keras.layers.concatenate(ids_with_offset, axis=self.axis)

    @staticmethod
    def _add_offset(tensor, offset):
        """Return `tensor` with `offset` added to its values."""
        if isinstance(tensor, tf.SparseTensor):
            # Shift only the stored values and rebuild the tensor with
            # the same indices and dense shape.
            return tf.SparseTensor(
                indices=tensor.indices,
                values=tensor.values + offset,
                dense_shape=tensor.dense_shape,
            )
        return tensor + offset

    def get_config(self):
        """Return the layer config, including `offsets`."""
        config = super(ConcatenateWithOffset, self).get_config()
        config["offsets"] = list(self.offsets)
        return config
31 changes: 31 additions & 0 deletions elasticdl_preprocessing/tests/concatenate_with_offset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest

import numpy as np
import tensorflow as tf

from elasticdl_preprocessing.layers.concatenate_with_offset import (
ConcatenateWithOffset
)


class ConcatenateWithOffsetTest(unittest.TestCase):
    """Tests for `ConcatenateWithOffset` on dense, ragged and sparse inputs.

    Each input kind has its own test method so that a failure in one does
    not hide the results of the others.
    """

    def setUp(self):
        # Every scenario uses two inputs concatenated along axis 1, with
        # the second input shifted by 10.
        self.offsets = [0, 10]
        self.concat_layer = ConcatenateWithOffset(
            offsets=self.offsets, axis=1
        )
        self.ragged_tensor_1 = tf.ragged.constant([[1], [], [1]])
        self.ragged_tensor_2 = tf.ragged.constant([[2], [2], []])
        # Flat values of the row-wise concatenation [[1, 12], [12], [1]].
        self.expected_flat_values = np.array([1, 12, 12, 1])

    def test_concatenate_with_offset(self):
        """Dense tensors are offset and concatenated column-wise."""
        tensor_1 = tf.constant([[1], [1], [1]])
        tensor_2 = tf.constant([[2], [2], [2]])

        output = self.concat_layer([tensor_1, tensor_2])

        expected_out = np.array([[1, 12], [1, 12], [1, 12]])
        # assert_array_equal reports the differing elements on failure,
        # unlike assertTrue(np.array_equal(...)).
        np.testing.assert_array_equal(output.numpy(), expected_out)

    def test_concatenate_with_offset_ragged(self):
        """Ragged tensors are offset and concatenated row by row."""
        output = self.concat_layer(
            [self.ragged_tensor_1, self.ragged_tensor_2]
        )

        np.testing.assert_array_equal(
            output.values.numpy(), self.expected_flat_values
        )

    def test_concatenate_with_offset_sparse(self):
        """Sparse tensors give the same values as their ragged sources."""
        sparse_tensor_1 = self.ragged_tensor_1.to_sparse()
        sparse_tensor_2 = self.ragged_tensor_2.to_sparse()

        output = self.concat_layer([sparse_tensor_1, sparse_tensor_2])

        np.testing.assert_array_equal(
            output.values.numpy(), self.expected_flat_values
        )

    def test_mismatched_offsets_raise(self):
        """Passing more inputs than offsets raises a ValueError."""
        tensor = tf.constant([[1], [1], [1]])

        with self.assertRaises(ValueError):
            self.concat_layer([tensor, tensor, tensor])