@@ -168,20 +168,19 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
168168 if "pixel_threshold_class" in state_dict :
169169 self .pixel_threshold = self ._get_instance (state_dict , "pixel_threshold_class" )
170170
171- if "anomaly_maps_normalization_class" in state_dict :
172- self .anomaly_maps_normalization_metrics = self ._get_instance (state_dict , "anomaly_maps_normalization_class" )
173- if "box_scores_normalization_class" in state_dict :
174- self .box_scores_normalization_metrics = self ._get_instance (state_dict , "box_scores_normalization_class" )
171+ # check only for pred score normalization metrics, because if this one is present, all others are too
175172 if "pred_scores_normalization_class" in state_dict :
173+ self .box_scores_normalization_metrics = self ._get_instance (state_dict , "box_scores_normalization_class" )
174+ self .anomaly_maps_normalization_metrics = self ._get_instance (state_dict , "anomaly_maps_normalization_class" )
176175 self .pred_scores_normalization_metrics = self ._get_instance (state_dict , "pred_scores_normalization_class" )
177176
178- self .normalization_metrics = MetricCollection (
179- {
180- "anomaly_maps" : self .anomaly_maps_normalization_metrics ,
181- "box_scores" : self .box_scores_normalization_metrics ,
182- "pred_scores" : self .pred_scores_normalization_metrics ,
183- },
184- )
177+ self .normalization_metrics = MetricCollection (
178+ {
179+ "anomaly_maps" : self .anomaly_maps_normalization_metrics ,
180+ "box_scores" : self .box_scores_normalization_metrics ,
181+ "pred_scores" : self .pred_scores_normalization_metrics ,
182+ },
183+ )
185184 # Used to load metrics if there is any related data in state_dict
186185 self ._load_metrics (state_dict )
187186
0 commit comments