-
Notifications
You must be signed in to change notification settings - Fork 5.9k
split optimization ops on pserver to independenty blocks #10123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
7907b6a
e05f4df
39f6274
18e0b73
ba1e68d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,8 @@ | |
| LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" | ||
| RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" | ||
|
|
||
| GLOBAL_BLOCK_IDX = 0 | ||
|
|
||
|
|
||
| class VarBlock: | ||
| def __init__(self, varname, offset, size): | ||
|
|
@@ -368,8 +370,10 @@ def get_pserver_program(self, endpoint): | |
| else: | ||
| recv_inputs.append(single_trainer_var) | ||
|
|
||
| # step3 | ||
| optimize_block = pserver_program.create_block(0) | ||
| # step 3 | ||
| # each optimization op will has a optimize block | ||
| optimize_block = None | ||
|
|
||
| # step 4 | ||
| # Create a union-find data structure from optimize ops, | ||
| # If two ops are connected, we could add these two ops | ||
|
|
@@ -415,29 +419,35 @@ def __append_optimize_op__(op, block): | |
| else: | ||
| self._append_pserver_non_opt_ops(block, op) | ||
|
|
||
| append_block = optimize_block | ||
| # append lr decay ops to the child block if exists | ||
| lr_decay_block = None | ||
| lr_ops = self._get_lr_ops() | ||
| if len(lr_ops) > 0: | ||
| lr_decay_block = pserver_program.create_block(GLOBAL_BLOCK_IDX) | ||
| for _, op in enumerate(lr_ops): | ||
| self._append_pserver_non_opt_ops(append_block, op) | ||
|
|
||
| append_block = pserver_program.create_block(append_block.idx) | ||
| self._append_pserver_non_opt_ops(lr_decay_block, op) | ||
|
|
||
| # append op to the current block | ||
| per_opt_block = append_block | ||
| per_opt_block = None | ||
| pre_block_idx = GLOBAL_BLOCK_IDX | ||
|
||
| if lr_decay_block is not None: | ||
| pre_block_idx = lr_decay_block.idx | ||
| for idx, opt_op in enumerate(opt_op_on_pserver): | ||
| per_opt_block = pserver_program.create_block(pre_block_idx) | ||
| if optimize_block is None: | ||
| # first optimize block | ||
| optimize_block = per_opt_block | ||
| for _, op in enumerate(self.optimize_ops): | ||
| # optimizer is connected to itself | ||
| if ufind.is_connected(op, opt_op) and \ | ||
| op not in global_ops: | ||
| if ufind.is_connected(op, opt_op) and op not in global_ops: | ||
| __append_optimize_op__(op, per_opt_block) | ||
| if idx == len(opt_op_on_pserver) - 1 and global_ops: | ||
| per_opt_block = pserver_program.create_block(append_block.idx) | ||
|
|
||
| # append global ops | ||
| opt_state_block = None | ||
| if global_ops: | ||
| opt_state_block = pserver_program.create_block(per_opt_block.idx) | ||
| for glb_op in global_ops: | ||
|
||
| __append_optimize_op__(glb_op, per_opt_block) | ||
| __append_optimize_op__(glb_op, opt_state_block) | ||
|
|
||
| # NOT USED: single block version: | ||
| # | ||
|
|
@@ -451,10 +461,11 @@ def __append_optimize_op__(op, block): | |
| prefetch_block = None | ||
| if self.has_distributed_lookup_table: | ||
| pserver_index = self.pserver_endpoints.index(endpoint) | ||
| self._create_table_optimize_block(pserver_index, pserver_program, | ||
| append_block) | ||
| table_opt_block = self._create_table_optimize_block( | ||
| pserver_index, pserver_program, opt_state_block or | ||
| pserver_program.global_block()) | ||
| prefetch_block = self._create_prefetch_block( | ||
| pserver_index, pserver_program, optimize_block) | ||
| pserver_index, pserver_program, table_opt_block) | ||
|
|
||
| # NOTE: if has_distributed_lookup_table is False, then prefetch_block will | ||
| # not be executed, so it's safe to use optimize_block to hold the place | ||
|
|
@@ -724,6 +735,8 @@ def _clone_var(block, var, persistable=True): | |
| outputs=outputs, | ||
| attrs=table_opt_op.attrs) | ||
|
|
||
| return table_opt_block | ||
|
|
||
| # ====================== private transpiler functions ===================== | ||
| def _create_vars_from_blocklist(self, | ||
| program, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can remove this line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done