Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions python/paddle/v2/fluid/layers/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,36 @@


def split_lod_tensor(input, mask, level=0):
"""
**split_lod_tensor**

This function takes in an input that contains the complete lod information,
and takes in a mask which is used to mask certain parts of the input.
The output is the true branch and the false branch with the mask applied to
the input at a certain level in the tensor.

Args:
input(tuple|list|None): The input tensor that contains complete
lod information needed to construct the output.
mask(list): A bool column vector which masks the input.
level(int): The specific lod level to rank.

Returns:
Variable: The true branch of tensor as per the mask applied to input.
Variable: The false branch of tensor as per the mask applied to input.

Examples:
.. code-block:: python

x = layers.data(name='x', shape=[1])
x.persistable = True

y = layers.data(name='y', shape=[1])
y.persistable = True

out_true, out_false = layers.split_lod_tensor(
input=x, mask=y, level=level)
"""
helper = LayerHelper('split_lod_tensor', **locals())
out_true = helper.create_tmp_variable(dtype=input.dtype)
out_false = helper.create_tmp_variable(dtype=input.dtype)
Expand All @@ -31,6 +61,40 @@ def split_lod_tensor(input, mask, level=0):


def merge_lod_tensor(in_true, in_false, x, mask, level=0):
"""
**merge_lod_tensor**

This function takes in an input :math:`x`, the True branch, the False
branch and a binary :math:`mask`. Using this information, this function
merges the True and False branches of the tensor into a single Output
at a certain lod level indiacted by :math:`level`.

Args:
in_true(tuple|list|None): The True branch to be merged.
in_false(tuple|list|None): The False branch to be merged.
x(tuple|list|None): The input tensor that contains complete
lod information needed to construct the output.
mask(list): A bool column vector which masks the input.
level(int): The specific lod level to rank.

Returns:
Variable: The merged output tensor.

Examples:
.. code-block:: python

x = layers.data(
name='x', shape=[1], dtype='float32', stop_gradient=False)
y = layers.data(
name='y', shape=[1], dtype='bool', stop_gradient=False)

level = 0

out_true, out_false = layers.split_lod_tensor(
input=x, mask=y, level=level)
out = layers.merge_lod_tensor(
in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
"""
helper = LayerHelper('merge_lod_tensor', **locals())
out = helper.create_tmp_variable(dtype=in_true.dtype)
helper.append_op(
Expand Down