@@ -18,8 +18,9 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """

-    def __init__(self, global_step=None):
+    def __init__(self, global_step=None, regularization=None):
         self._global_step = global_step
+        self.regularization = regularization
         # Dictionary of accumulators. Some optimizer subclasses need to
         # allocate and manage extra variables associated with the parameters
         # to train. These variables are called accumulators.
@@ -199,7 +200,8 @@ def minimize(self,
         """
         params_grads = append_backward_ops(loss, parameter_list, no_grad_set)
         # Add regularization if any
-        params_grads = append_regularization_ops(params_grads)
+        params_grads = append_regularization_ops(params_grads,
+                                                 self.regularization)
         optimize_ops = self.create_optimization_pass(params_grads, loss,
                                                      startup_program)
         return optimize_ops
@@ -209,9 +211,9 @@ class SGDOptimizer(Optimizer):
     """ Simple SGD optimizer without any state.
     """

-    def __init__(self, learning_rate, global_step=None):
+    def __init__(self, learning_rate, **kwargs):
         assert learning_rate is not None
-        super(SGDOptimizer, self).__init__(global_step)
+        super(SGDOptimizer, self).__init__(**kwargs)
         self.type = "sgd"
         self._learning_rate = learning_rate

@@ -236,14 +238,10 @@ class MomentumOptimizer(Optimizer):
     """
     _velocity_acc_str = "velocity"

-    def __init__(self,
-                 learning_rate,
-                 momentum,
-                 use_nesterov=False,
-                 global_step=None):
+    def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
         assert learning_rate is not None
         assert momentum is not None
-        super(MomentumOptimizer, self).__init__(global_step)
+        super(MomentumOptimizer, self).__init__(**kwargs)
         self.type = "momentum"
         self._learning_rate = learning_rate
         self._momentum = momentum
@@ -284,10 +282,10 @@ class AdagradOptimizer(Optimizer):
     """
     _moment_acc_str = "moment"

-    def __init__(self, learning_rate, epsilon=1.0e-6, global_step=None):
+    def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs):
         assert learning_rate is not None
         assert epsilon is not None
-        super(AdagradOptimizer, self).__init__(global_step)
+        super(AdagradOptimizer, self).__init__(**kwargs)
         self.type = "adagrad"
         self._learning_rate = learning_rate
         self._epsilon = epsilon
@@ -331,12 +329,12 @@ def __init__(self,
                  beta1=0.9,
                  beta2=0.999,
                  epsilon=1e-8,
-                 global_step=None):
+                 **kwargs):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        super(AdamOptimizer, self).__init__(global_step)
+        super(AdamOptimizer, self).__init__(**kwargs)
         self.type = "adam"
         self._learning_rate = learning_rate
         self._beta1 = beta1
@@ -436,12 +434,12 @@ def __init__(self,
                  beta1=0.9,
                  beta2=0.999,
                  epsilon=1e-8,
-                 global_step=None):
+                 **kwargs):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        super(AdamaxOptimizer, self).__init__()
+        super(AdamaxOptimizer, self).__init__(**kwargs)
         self.type = "adamax"
         self._learning_rate = learning_rate
         self._beta1 = beta1
@@ -514,16 +512,12 @@ class DecayedAdagradOptimizer(Optimizer):
     """
     _moment_acc_str = "moment"

-    def __init__(self,
-                 learning_rate,
-                 decay=0.95,
-                 epsilon=1.0e-6,
-                 global_step=None):
+    def __init__(self, learning_rate, decay=0.95, epsilon=1.0e-6, **kwargs):
         assert learning_rate is not None
         assert decay is not None
         assert epsilon is not None

-        super(DecayedAdagradOptimizer, self).__init__(global_step)
+        super(DecayedAdagradOptimizer, self).__init__(**kwargs)
         self.type = "decayed_adagrad"
         self._learning_rate = learning_rate
         self._decay = decay
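
Taken together, the hunks above move the regularization argument into the base Optimizer and let every subclass forward extra keywords through **kwargs instead of repeating global_step in each signature. Below is a minimal standalone sketch of that forwarding pattern; the classes are trimmed stand-ins for illustration, not the full fluid optimizers, and the regularizer value is a hypothetical placeholder.

# Sketch of the **kwargs forwarding introduced in this commit (stand-in classes).

class Optimizer(object):
    def __init__(self, global_step=None, regularization=None):
        self._global_step = global_step
        # Stored once on the base class; in the real code, minimize() later
        # passes it to append_regularization_ops(params_grads, self.regularization).
        self.regularization = regularization


class SGDOptimizer(Optimizer):
    def __init__(self, learning_rate, **kwargs):
        assert learning_rate is not None
        # Any base-class keyword (global_step, regularization) is simply
        # forwarded rather than being listed in every subclass signature.
        super(SGDOptimizer, self).__init__(**kwargs)
        self.type = "sgd"
        self._learning_rate = learning_rate


# Usage: a regularizer object (placeholder string here) can now be passed to
# any optimizer subclass without changing that subclass's signature.
opt = SGDOptimizer(learning_rate=0.01, regularization="l2_decay_placeholder")
print(opt.regularization)  # -> "l2_decay_placeholder"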