1919import data_feeder
2020import contextlib
2121import io
22- import transpiler
22+ import unique_name
2323
2424# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
2525import optimizer as opt_module
@@ -56,26 +56,62 @@ def __init__(self, epoch_id, step_id):
5656 self .step = step_id
5757
5858
59+ def check_and_get_place (place ):
60+ """
61+ Check the type of place or get the default place
62+ Args:
63+ place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
64+
65+ Raises:
66+ TypeError if the type mismatched.
67+
68+ Returns:
69+ the original place if it is not None.
70+ if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
71+ Otherwise returns CPUPlace by default.
72+ """
73+ if place is None :
74+ if core .is_compiled_with_cuda ():
75+ return core .CUDAPlace (0 )
76+ else :
77+ return core .CPUPlace ()
78+ else :
79+ if not isinstance (place , core .CUDAPlace ) and not isinstance (
80+ place , core .CPUPlace ):
81+ raise TypeError ("Place should be either CUDAPlace or CPUPlace" )
82+ return place
83+
84+
5985class Trainer (object ):
6086 """
6187
6288 Args:
63- program_func(callable): A function which will return loss. The loss must be a scaler.
89+ train_func(callable): A function which will return loss. The loss must be a scalar.
90+ infer_func(callable): A function which will return predict, used to save inference model
6491 optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
6592 place: The device place of this trainer.
6693 """
6794
68- def __init__ (self , program_func , optimizer , param_path = None , place = None ):
95+ def __init__ (self ,
96+ train_func ,
97+ infer_func ,
98+ optimizer ,
99+ param_path = None ,
100+ place = None ):
69101 # 1. we need to generate a framework.Program by calling
70102 # program_func. Reference: fluid.program_guard in
71103 # test_word2vec.py
104+ if not isinstance (optimizer , opt_module .Optimizer ):
105+ raise TypeError ("The optimizer should be an instance of Optimizer" )
106+
107+ self .infer_func = infer_func
72108 self .scope = core .Scope ()
73109
74110 self .startup_program = framework .Program ()
75111 self .train_program = framework .Program ()
76112
77113 with framework .program_guard (self .train_program , self .startup_program ):
78- program_func_outs = program_func ()
114+ program_func_outs = train_func ()
79115 self .test_outputs = program_func_outs if isinstance (
80116 program_func_outs , list ) else [program_func_outs ]
81117 self .test_program = self .train_program .clone ()
@@ -86,9 +122,9 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
86122 loss = self .test_outputs [0 ]
87123 optimize_ops , params_grads = optimizer .minimize (loss )
88124
89- self .place = Trainer . _check_and_get_place (place )
125+ self .place = check_and_get_place (place )
90126
91- self .dist_transpile_if_necessary (optimize_ops , params_grads )
127+ self ._dist_transpile_if_necessary (optimize_ops , params_grads )
92128
93129 # 2. move the default_main_program to self.program and run the
94130 # default_startup program on an empty core.Scope()
@@ -101,7 +137,7 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
101137 # load params from param_path into scope
102138 io .load_persistables (exe , dirname = param_path )
103139
104- def dist_transpile_if_necessary (self , optimize_ops , params_grads ):
140+ def _dist_transpile_if_necessary (self , optimize_ops , params_grads ):
105141 if "PADDLE_TRAINING_ROLE" not in os .environ :
106142 return
107143
@@ -190,31 +226,14 @@ def save_params(self, param_path):
190226 exe = executor .Executor (self .place )
191227 io .save_persistables (exe , dirname = param_path )
192228
193- @staticmethod
194- def _check_and_get_place (place ):
195- """
196- Check the type of place or get the default place
197- Args:
198- place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
199-
200- Raises:
201- TypeError if the type mismatched.
202-
203- Returns:
204- the original place if it is not None.
205- if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
206- Otherwise returns CPUPlace by default.
207- """
208- if place is None :
209- if core .is_compiled_with_cuda ():
210- return core .CUDAPlace (0 )
211- else :
212- return core .CPUPlace ()
213- else :
214- if not isinstance (place , core .CUDAPlace ) and not isinstance (
215- place , core .CPUPlace ):
216- raise TypeError ("Place should be either CUDAPlace or CPUPlace" )
217- return place
229+ def save_inference_model (self , model_path ):
230+ inference_program = framework .Program ()
231+ with framework .program_guard (inference_program ):
232+ with unique_name .guard ():
233+ predict_var = self .infer_func ()
234+ predict_var = self .train_program .block (0 ).var (predict_var .name )
235+ exe = executor .Executor (self .place )
236+ io .save_inference_model (model_path , [], [predict_var ], exe )
218237
219238 @contextlib .contextmanager
220239 def _prog_and_scope_guard (self ):
0 commit comments