Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions python/paddle/fluid/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"

GLOBAL_BLOCK_IDX = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove this line.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done



class VarBlock:
def __init__(self, varname, offset, size):
Expand Down Expand Up @@ -368,8 +370,10 @@ def get_pserver_program(self, endpoint):
else:
recv_inputs.append(single_trainer_var)

# step3
optimize_block = pserver_program.create_block(0)
# step 3
# each optimization op will has a optimize block
optimize_block = None

# step 4
# Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops
Expand Down Expand Up @@ -415,29 +419,35 @@ def __append_optimize_op__(op, block):
else:
self._append_pserver_non_opt_ops(block, op)

append_block = optimize_block
# append lr decay ops to the child block if exists
lr_decay_block = None
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block(GLOBAL_BLOCK_IDX)
for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(append_block, op)

append_block = pserver_program.create_block(append_block.idx)
self._append_pserver_non_opt_ops(lr_decay_block, op)

# append op to the current block
per_opt_block = append_block
per_opt_block = None
pre_block_idx = GLOBAL_BLOCK_IDX
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No globals is needed here, just record the latest block id.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if lr_decay_block is not None:
pre_block_idx = lr_decay_block.idx
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
if optimize_block is None:
# first optimize block
optimize_block = per_opt_block
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and \
op not in global_ops:
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block)
if idx == len(opt_op_on_pserver) - 1 and global_ops:
per_opt_block = pserver_program.create_block(append_block.idx)

# append global ops
opt_state_block = None
if global_ops:
opt_state_block = pserver_program.create_block(per_opt_block.idx)
for glb_op in global_ops:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can move inside if statement.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

__append_optimize_op__(glb_op, per_opt_block)
__append_optimize_op__(glb_op, opt_state_block)

# NOT USED: single block version:
#
Expand All @@ -451,10 +461,11 @@ def __append_optimize_op__(op, block):
prefetch_block = None
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
self._create_table_optimize_block(pserver_index, pserver_program,
append_block)
table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, opt_state_block or
pserver_program.global_block())
prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, optimize_block)
pserver_index, pserver_program, table_opt_block)

# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
Expand Down Expand Up @@ -724,6 +735,8 @@ def _clone_var(block, var, persistable=True):
outputs=outputs,
attrs=table_opt_op.attrs)

return table_opt_block

# ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self,
program,
Expand Down