@@ -181,13 +181,18 @@ def __init__(
181181 self .use_delta_updates = use_delta_updates
182182
183183 self .model = None # Initialize the model attribute to None
184- if self .persistent_db and self ._recover ():
185- logger .info ("recovered state of aggregator" )
186184
187- # The model is built by recovery if at least one round has finished
188- if self .model :
189- logger .info ("Model was loaded by recovery" )
190- elif initial_tensor_dict :
185+ # Callbacks
186+ self .callbacks = callbacks_module .CallbackList (
187+ callbacks ,
188+ add_memory_profiler = log_memory_usage ,
189+ add_metric_writer = write_logs ,
190+ origin = "aggregator" ,
191+ )
192+
193+ self .collaborator_tensor_results = {} # {TensorKey: nparray}}
194+
195+ if initial_tensor_dict :
191196 self ._load_initial_tensors_from_dict (initial_tensor_dict )
192197 self .model = utils .construct_model_proto (
193198 tensor_dict = initial_tensor_dict ,
@@ -198,15 +203,8 @@ def __init__(
198203 self .model : base_pb2 .ModelProto = utils .load_proto (self .init_state_path )
199204 self ._load_initial_tensors () # keys are TensorKeys
200205
201- self .collaborator_tensor_results = {} # {TensorKey: nparray}}
202-
203- # Callbacks
204- self .callbacks = callbacks_module .CallbackList (
205- callbacks ,
206- add_memory_profiler = log_memory_usage ,
207- add_metric_writer = write_logs ,
208- origin = "aggregator" ,
209- )
206+ if self .persistent_db and self ._recover ():
207+ logger .info ("Recovered state of aggregator" )
210208
211209 # TODO: Aggregator has no concrete notion of round_begin.
212210 # https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
0 commit comments