Skip to content

Commit 35c373d

Browse files
authored
Support copy in Fluid channels (#9138)
* Support copy in Fluid channels * Address PR review comments
1 parent 484cff6 commit 35c373d

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

python/paddle/fluid/concurrency.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0):
131131
return channel
132132

133133

134-
def channel_send(channel, value):
134+
def channel_send(channel, value, copy=False):
135135
"""
136136
Sends a value through a channel variable. Used by an unbuffered or buffered
137137
channel to pass data from within or to a concurrent Go block, where
@@ -141,6 +141,8 @@ def channel_send(channel, value):
141141
channel (Variable|Channel): Channel variable created using
142142
`make_channel`.
143143
value (Variable): Value to send to channel
144+
copy (bool): Copy data while channel send. If False, then data
145+
is moved. The input cannot be used after move.
144146
Returns:
145147
Variable: The boolean status on whether or not the channel
146148
successfully sent the passed value.
@@ -162,11 +164,26 @@ def channel_send(channel, value):
162164
type=core.VarDesc.VarType.LOD_TENSOR,
163165
dtype=core.VarDesc.VarType.BOOL)
164166

167+
X = value
168+
169+
if copy is True:
170+
copied_X = helper.create_variable(
171+
name=unique_name.generate(value.name + '_copy'),
172+
type=value.type,
173+
dtype=value.dtype,
174+
shape=value.shape,
175+
lod_level=value.lod_level,
176+
capacity=value.capacity)
177+
178+
assign_op = channel_send_block.append_op(
179+
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X})
180+
X = copied_X
181+
165182
channel_send_op = channel_send_block.append_op(
166183
type="channel_send",
167184
inputs={
168185
"Channel": channel,
169-
"X": value,
186+
"X": X,
170187
},
171188
outputs={"Status": status})
172189

0 commit comments

Comments
 (0)