@@ -208,45 +208,42 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str]
208
208
for var_name , var_vals in var2vals .items ():
209
209
if isinstance (var_vals [0 ], str ):
210
210
continue
211
+
211
212
metric = {}
212
213
n_resps = len (var_vals )
213
214
metric [f"mean@{ n_resps } " ] = np .mean (var_vals )
214
- metric [f"std@{ n_resps } " ] = np .std (var_vals )
215
-
216
- ns = []
217
- n = 2
218
- while n < n_resps :
219
- ns .append (n )
220
- n *= 2
221
- ns .append (n_resps )
222
215
223
- # If there are multiple responses, we can compute the best/worst-of-N metrics
224
- # If not, they are the same as the single response metrics
225
216
if n_resps > 1 :
217
+ # n = n_resps
218
+ metric [f"std@{ n_resps } " ] = np .std (var_vals )
219
+
220
+ metric [f"best@{ n_resps } /mean" ] = np .max (var_vals )
221
+ metric [f"worst@{ n_resps } /mean" ] = np .min (var_vals )
222
+ if var2vals .get ("pred" , None ) is not None :
223
+ vote_data = [{"val" : val , "pred" : pred } for val , pred in zip (var_vals , var2vals ["pred" ])]
224
+ metric [f"maj@{ n_resps } /mean" ] = calc_maj_val (vote_data , vote_key = "pred" , val_key = "val" )
225
+ # 1 < n < n_resps
226
+ ns = []
227
+ n = 2
228
+ while n < n_resps :
229
+ ns .append (n )
230
+ n *= 2
231
+
226
232
for n in ns :
227
- if n == n_resps :
228
- # Non-bootstrapped
229
- metric [f"best@{ n } /mean" ] = np .max (var_vals )
230
- metric [f"worst@{ n } /mean" ] = np .min (var_vals )
231
- if var2vals .get ("pred" , None ) is not None :
232
- vote_data = [{"val" : val , "pred" : pred } for val , pred in zip (var_vals , var2vals ["pred" ])]
233
- metric [f"maj@{ n } /mean" ] = calc_maj_val (vote_data , vote_key = "pred" , val_key = "val" )
234
- else :
235
- # Bootstrapped
236
- [(bon_mean , bon_std ), (won_mean , won_std )] = bootstrap_metric (data = var_vals , subset_size = n , reduce_fns = [np .max , np .min ], seed = seed )
237
- metric [f"best@{ n } /mean" ], metric [f"best@{ n } /std" ] = bon_mean , bon_std
238
- metric [f"worst@{ n } /mean" ], metric [f"worst@{ n } /std" ] = won_mean , won_std
239
- if var2vals .get ("pred" , None ) is not None :
240
- vote_data = [{"val" : val , "pred" : pred } for val , pred in zip (var_vals , var2vals ["pred" ])]
241
- [(maj_n_mean , maj_n_std )] = bootstrap_metric (
242
- data = vote_data ,
243
- subset_size = n ,
244
- reduce_fns = [partial (calc_maj_val , vote_key = "pred" , val_key = "val" )],
245
- seed = seed ,
246
- )
247
- metric [f"maj@{ n } /mean" ], metric [f"maj@{ n } /std" ] = maj_n_mean , maj_n_std
248
-
249
- data_src2prompt2var2metric [data_source ][prompt ][var_name ] = metric
233
+ [(bon_mean , bon_std ), (won_mean , won_std )] = bootstrap_metric (data = var_vals , subset_size = n , reduce_fns = [np .max , np .min ], seed = seed )
234
+ metric [f"best@{ n } /mean" ], metric [f"best@{ n } /std" ] = bon_mean , bon_std
235
+ metric [f"worst@{ n } /mean" ], metric [f"worst@{ n } /std" ] = won_mean , won_std
236
+ if var2vals .get ("pred" , None ) is not None :
237
+ vote_data = [{"val" : val , "pred" : pred } for val , pred in zip (var_vals , var2vals ["pred" ])]
238
+ [(maj_n_mean , maj_n_std )] = bootstrap_metric (
239
+ data = vote_data ,
240
+ subset_size = n ,
241
+ reduce_fns = [partial (calc_maj_val , vote_key = "pred" , val_key = "val" )],
242
+ seed = seed ,
243
+ )
244
+ metric [f"maj@{ n } /mean" ], metric [f"maj@{ n } /std" ] = maj_n_mean , maj_n_std
245
+
246
+ data_src2prompt2var2metric [data_source ][prompt ][var_name ] = metric
250
247
251
248
# Aggregate metrics across prompts
252
249
data_src2var2metric2prompt_vals = defaultdict (lambda : defaultdict (lambda : defaultdict (list )))
0 commit comments