Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions python/paddle/fluid/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,14 @@ class SelectCase(object):
RECEIVE = 2

def __init__(self,
select,
case_idx,
case_to_execute,
channel_action_fn=None,
channel=None,
value=None):
value=None,
is_copy=False):
self.select = select
self.helper = LayerHelper('conditional_block')
self.main_program = self.helper.main_program
self.is_scalar_condition = True
Expand All @@ -99,7 +102,24 @@ def __init__(self,
self.action = (self.SEND
if channel_action_fn.__name__ == ('channel_send') else
self.RECEIVE) if channel_action_fn else self.DEFAULT
self.value = value

X = value
if self.action == self.SEND and is_copy:
# We create of copy of the data we want to send
copied_X = self.select.parent_block.create_var(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
dtype=value.dtype,
shape=value.shape,
lod_level=value.lod_level,
capacity=value.capacity
if hasattr(value, 'capacity') else None, )

self.select.parent_block.append_op(
type="assign", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X

self.value = X
self.channel = channel

def __enter__(self):
Expand Down Expand Up @@ -173,6 +193,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class Select(BlockGuard):
def __init__(self, name=None):
self.helper = LayerHelper('select', name=name)
self.parent_block = self.helper.main_program.current_block()
self.cases = []

super(Select, self).__init__(self.helper.main_program)
Expand All @@ -183,12 +204,12 @@ def __enter__(self):
super(Select, self).__enter__()
return self

def case(self, channel_action_fn, channel, value):
def case(self, channel_action_fn, channel, value, is_copy=False):
"""Create a new block for this condition.
"""
select_case = SelectCase(
len(self.cases), self.case_to_execute, channel_action_fn, channel,
value)
select_case = SelectCase(self,
len(self.cases), self.case_to_execute,
channel_action_fn, channel, value, is_copy)

self.cases.append(select_case)

Expand All @@ -197,7 +218,7 @@ def case(self, channel_action_fn, channel, value):
def default(self):
"""Create a default case block for this condition.
"""
default_case = SelectCase(len(self.cases), self.case_to_execute)
default_case = SelectCase(self, len(self.cases), self.case_to_execute)

self.cases.append(default_case)

Expand Down Expand Up @@ -346,17 +367,17 @@ def channel_send(channel, value, is_copy=False):

X = value

if is_copy is True:
if is_copy:
copied_X = helper.create_variable(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
dtype=value.dtype,
shape=value.shape,
lod_level=value.lod_level,
capacity=value.capacity)
capacity=value.capacity if hasattr(value, 'capacity') else None)

assign_op = channel_send_block.append_op(
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X})
type="assign", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X

channel_send_op = channel_send_block.append_op(
Expand Down
23 changes: 4 additions & 19 deletions python/paddle/fluid/tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,10 @@ def fibonacci(channel, quit_channel):
with while_op.block():
result2 = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
x_to_send_tmp = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)

# TODO(abhinav): Need to perform copy when doing a channel send.
# Once this is complete, we can remove these lines
assign(input=x, output=x_to_send_tmp)

with fluid.Select() as select:
with select.case(fluid.channel_send, channel,
x_to_send_tmp):
with select.case(
fluid.channel_send, channel, x, is_copy=True):
assign(input=x, output=x_tmp)
assign(input=y, output=x)
assign(elementwise_add(x=x_tmp, y=y), output=y)
Expand Down Expand Up @@ -230,21 +224,12 @@ def test_ping_pong(self):
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.FP64)

pong_result = self._create_tensor('pong_return_value',
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.FP64)

def ping(ch, message):
message_to_send_tmp = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.FP64, value=0)

assign(input=message, output=message_to_send_tmp)
fluid.channel_send(ch, message_to_send_tmp)
fluid.channel_send(ch, message, is_copy=True)

def pong(ch1, ch2):
fluid.channel_recv(ch1, ping_result)
assign(input=ping_result, output=pong_result)
fluid.channel_send(ch2, pong_result)
fluid.channel_send(ch2, ping_result, is_copy=True)

pings = fluid.make_channel(
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
Expand Down