|
25 | 25 | from contextlib import suppress
|
26 | 26 | from datetime import datetime
|
27 | 27 | import matplotlib.pyplot as plt
|
28 |
| - |
| 28 | +from fvcore.nn import FlopCountAnalysis, flop_count_table |
29 | 29 |
|
30 | 30 | import pickle
|
31 | 31 | import torch
|
|
208 | 208 | parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
|
209 | 209 | parser.add_argument('--netmode', type=int, default=1, help='which stride mode to use(1 to 5)')
|
210 | 210 |
|
| 211 | +parser.add_argument('--freeze-top', action='store_true', default=False, help='Freeze late layers from idx36 to the end') |
| 212 | +parser.add_argument('--freeze-bot', action='store_true', default=False, help='Freeze early layers up to idx36') |
| 213 | +parser.add_argument('--debug', action='store_true', default=False, help='whether to enable set_detect_anomaly() or not') |
| 214 | +parser.add_argument('--use-avgs', default='', type=str, metavar='PATH', help='Load averaged model weights from checkpoint at PATH (default: none)') |
| 215 | + |
211 | 216 | # torch.autograd.set_detect_anomaly(True)
|
212 | 217 | def _parse_args():
|
213 | 218 | # Do we have a config file to parse?
|
@@ -304,6 +309,9 @@ def main():
|
304 | 309 | if args.fuser:
|
305 | 310 | set_jit_fuser(args.fuser)
|
306 | 311 |
|
| 312 | + if args.debug: |
| 313 | + torch.autograd.set_detect_anomaly(True) |
| 314 | + |
307 | 315 | # convert into int
|
308 | 316 | args.drop_rates = {int(key):float(value) for key,value in args.drop_rates.items()}
|
309 | 317 | print(f'args.drop_rates: {args.drop_rates}')
|
@@ -334,7 +342,10 @@ def main():
|
334 | 342 | model.set_grad_checkpointing(enable=True)
|
335 | 343 |
|
336 | 344 | if args.local_rank == 0:
|
| 345 | + flops = FlopCountAnalysis(model, torch.randn(size=(1,3,224,224))) |
337 | 346 | _logger.info(f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()]):,}')
|
| 347 | + _logger.info(f'Model {safe_model_name(args.model)} created, FLOPS count:{flops.total():,}') |
| 348 | + # _logger.info(f'Model {safe_model_name(args.model)} Flops table\n{flop_count_table(flops)}') |
338 | 349 |
|
339 | 350 | _logger.info(f'Model: {model}')
|
340 | 351 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
@@ -570,6 +581,25 @@ def main():
|
570 | 581 | f.write(args_text)
|
571 | 582 |
|
572 | 583 | try:
|
| 584 | + # optionally freeze layers: --freeze-bot freezes early layers (idx <= 36), --freeze-top freezes late layers (idx >= 36) |
| 585 | + if args.freeze_bot: |
| 586 | + for n,v in model.features.named_parameters(): |
| 587 | + if int(n.split('.')[0]) <=36: |
| 588 | + v.requires_grad=False |
| 589 | + elif args.freeze_top: |
| 590 | + for n,v in model.features.named_parameters(): |
| 591 | + if int(n.split('.')[0]) >=36: |
| 592 | + v.requires_grad=False |
| 593 | + |
| 594 | + if args.freeze_top or args.freeze_bot: |
| 595 | + for n,v in model.features.named_parameters(): |
| 596 | + print(f'{n}.requires_grad: {v.requires_grad}') |
| 597 | + |
| 598 | + if args.use_avgs: |
| 599 | + checkpoint_avgs = torch.load(args.use_avgs,map_location='cuda') |
| 600 | + model.load_state_dict(checkpoint_avgs) |
| 601 | + print(f'avg model loaded') |
| 602 | + |
573 | 603 | for epoch in range(start_epoch, num_epochs):
|
574 | 604 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
|
575 | 605 | loader_train.sampler.set_epoch(epoch)
|
|
0 commit comments