@@ -348,14 +348,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
348
348
lines .append ("" )
349
349
350
350
351
- def generate_weights_table (module , table_name , metrics , include_pattern = None , exclude_pattern = None ):
351
+ def generate_weights_table (module , table_name , metrics , include_patterns = None , exclude_patterns = None ):
352
352
weight_enums = [getattr (module , name ) for name in dir (module ) if name .endswith ("_Weights" )]
353
353
weights = [w for weight_enum in weight_enums for w in weight_enum ]
354
354
355
- if include_pattern is not None :
356
- weights = [w for w in weights if include_pattern in str (w )]
357
- if exclude_pattern is not None :
358
- weights = [w for w in weights if exclude_pattern not in str (w )]
355
+ if include_patterns is not None :
356
+ weights = [w for w in weights if any ( p in str (w ) for p in include_patterns )]
357
+ if exclude_patterns is not None :
358
+ weights = [w for w in weights if all ( p not in str (w ) for p in exclude_patterns )]
359
359
360
360
metrics_keys , metrics_names = zip (* metrics )
361
361
column_names = ["Weight" ] + list (metrics_names ) + ["Params" , "Recipe" ]
@@ -383,13 +383,19 @@ def generate_weights_table(module, table_name, metrics, include_pattern=None, ex
383
383
384
384
generate_weights_table (module = M , table_name = "classification" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
385
385
generate_weights_table (
386
- module = M .detection , table_name = "detection" , metrics = [("box_map" , "Box MAP" )], exclude_pattern = "Keypoint"
386
+ module = M .detection , table_name = "detection" , metrics = [("box_map" , "Box MAP" )], exclude_patterns = ["Mask" , "Keypoint" ]
387
+ )
388
+ generate_weights_table (
389
+ module = M .detection ,
390
+ table_name = "instance_segmentation" ,
391
+ metrics = [("box_map" , "Box MAP" ), ("mask_map" , "Mask MAP" )],
392
+ include_patterns = ["Mask" ],
387
393
)
388
394
generate_weights_table (
389
395
module = M .detection ,
390
396
table_name = "detection_keypoint" ,
391
397
metrics = [("box_map" , "Box MAP" ), ("kp_map" , "Keypoint MAP" )],
392
- include_pattern = "Keypoint" ,
398
+ include_patterns = [ "Keypoint" ] ,
393
399
)
394
400
generate_weights_table (
395
401
module = M .segmentation , table_name = "segmentation" , metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )]
0 commit comments