 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import core
 import framework
 import executor
 
 # optimizer has the same name as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
+import distribute_transpiler
 
 __all__ = [
     'Trainer',
@@ -76,22 +78,61 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
             raise TypeError(
                 "The optimizer should be an instance of Optimizer")
 
-            optimizer.minimize(loss)
+            optimize_ops, params_grads = optimizer.minimize(loss)
 
         self.place = Trainer._check_and_get_place(place)
 
+        self.dist_transpile_if_necessary(optimize_ops, params_grads)
+
         # 2. move the default_main_program to self.program and run the
         # default_startup program on an empty core.Scope()
         # Run startup program
-        exe = executor.Executor(place)
-        exe.run(self.startup_program, scope=self.scope)
+        with self._prog_and_scope_guard():
+            exe = executor.Executor(place)
+            exe.run(self.startup_program)
 
         if param_path:
             # load params from param_path into scope
             # TODO(yuyang): This depends on parameters implementation.
             pass
 
-        # TODO(helin): support distributed training
+    def dist_transpile_if_necessary(self, optimize_ops, params_grads):
+        if "PADDLE_TRAINING_ROLE" not in os.environ:
+            return
+
+        # the port of all pservers, needed by both trainer and pserver
+        port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+        # comma separated ips of all pservers, needed by trainer and
+        # pserver
+        pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+        eplist = []
+        for ip in pserver_ips.split(","):
+            eplist.append(':'.join([ip, port]))
+        pserver_endpoints = ",".join(eplist)
+        # total number of workers/trainers in the job, needed by
+        # trainer and pserver
+        trainers = int(os.getenv("PADDLE_TRAINERS"))
+        # the IP of the local machine, needed by pserver only
+        current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+        # the unique trainer id, starting from 0, needed by trainer
+        # only
+        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        # the role, should be either PSERVER or TRAINER
+        training_role = os.getenv("PADDLE_TRAINING_ROLE")
+        with self._prog_and_scope_guard():
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(
+                trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if training_role == "PSERVER":
+                self.train_program = t.get_pserver_program(current_endpoint)
+                self.startup_program = t.get_startup_program(current_endpoint,
+                                                             self.train_program)
+            elif training_role == "TRAINER":
+                self.train_program = t.get_trainer_program()
+            else:
+                raise ValueError(
+                    'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
+                )
 
     def train(self,
               num_epochs,
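
The new dist_transpile_if_necessary method above is driven entirely by environment variables. As an illustrative sketch only (not part of this diff), a launcher might export something like the following before constructing the Trainer; the addresses and counts below are placeholder values, not defaults mandated by this patch:

    # Hypothetical launcher-side setup for one process; all values must be strings.
    import os

    os.environ["PADDLE_TRAINING_ROLE"] = "TRAINER"   # or "PSERVER"; if unset, the transpile step is skipped
    os.environ["PADDLE_PSERVER_PORT"] = "6174"       # port shared by all pservers
    os.environ["PADDLE_PSERVER_IPS"] = "192.168.0.2,192.168.0.3"  # comma separated pserver IPs
    os.environ["PADDLE_TRAINERS"] = "2"              # total number of trainer processes
    os.environ["PADDLE_CURRENT_IP"] = "192.168.0.2"  # used by pservers only
    os.environ["PADDLE_TRAINER_ID"] = "0"            # used by trainers only

With these set, the Trainer constructor runs the transpiler and swaps in the pserver or trainer program for this process.
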
@@ -117,6 +158,13 @@ def train(self,
             raise NotImplementedError(
                 "Parallel Executor version of trainer is not implemented")
 
+        training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
+        if training_role == "PSERVER":
+            with self._prog_and_scope_guard():
+                exe = executor.Executor(self.place)
+                exe.run()
+                return
+
         self._train_by_executor(num_epochs, event_handler, reader, feed_order)
 
     def test(self, reader):
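
With the PSERVER branch added to train() above, the same user script can be started on every node: parameter-server processes block inside exe.run() serving parameters, while trainer processes fall through to the normal executor loop. The sketch below is a rough local-launcher illustration under that assumption; train.py is a hypothetical placeholder for a script that builds a Trainer and calls train(), not a file introduced by this change:

    import os
    import subprocess

    def launch(role, trainer_id="0"):
        # Copy the current environment and add the distributed-training variables.
        env = dict(os.environ,
                   PADDLE_TRAINING_ROLE=role,
                   PADDLE_PSERVER_PORT="6174",
                   PADDLE_PSERVER_IPS="127.0.0.1",
                   PADDLE_TRAINERS="2",
                   PADDLE_CURRENT_IP="127.0.0.1",
                   PADDLE_TRAINER_ID=trainer_id)
        return subprocess.Popen(["python", "train.py"], env=env)

    pserver = launch("PSERVER")
    trainers = [launch("TRAINER", str(i)) for i in range(2)]
    for t in trainers:
        t.wait()           # trainer processes exit once training finishes
    pserver.terminate()    # the pserver blocks in exe.run(), so stop it explicitly
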