Skip to content

Commit 0842093

Browse files
committed
update train.py
1 parent 15af4fc commit 0842093

File tree

1 file changed

+31
-1
lines changed
  • ImageNet/training_scripts/imagenet_training

1 file changed

+31
-1
lines changed

ImageNet/training_scripts/imagenet_training/train.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from contextlib import suppress
2626
from datetime import datetime
2727
import matplotlib.pyplot as plt
28-
28+
from fvcore.nn import FlopCountAnalysis,flop_count_table
2929

3030
import pickle
3131
import torch
@@ -208,6 +208,11 @@
208208
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
209209
parser.add_argument('--netmode', type=int, default=1, help='which stride mode to use(1 to 5)')
210210

211+
parser.add_argument('--freeze-top', action='store_true', default=False, help='Freeze early layers up to idx36')
212+
parser.add_argument('--freeze-bot', action='store_true', default=False, help='Freeze late layers from idx36 to the end')
213+
parser.add_argument('--debug', action='store_true', default=False, help='whether to enable set_detect_anomaly() or not')
214+
# Path to a checkpoint whose state_dict is loaded into the model before the epoch loop (weights only; no optimizer state is restored).
parser.add_argument('--use-avgs', default='', type=str, metavar='PATH', help='Load averaged model weights (state_dict) from this checkpoint before training (default: none)')
215+
211216
# torch.autograd.set_detect_anomaly(True)
212217
def _parse_args():
213218
# Do we have a config file to parse?
@@ -304,6 +309,9 @@ def main():
304309
if args.fuser:
305310
set_jit_fuser(args.fuser)
306311

312+
if args.debug:
313+
torch.autograd.set_detect_anomaly(True)
314+
307315
# convert into int
308316
args.drop_rates = {int(key):float(value) for key,value in args.drop_rates.items()}
309317
print(f'args.drop_rates: {args.drop_rates}')
@@ -334,7 +342,10 @@ def main():
334342
model.set_grad_checkpointing(enable=True)
335343

336344
if args.local_rank == 0:
345+
flops = FlopCountAnalysis(model, torch.randn(size=(1,3,224,224)))
337346
_logger.info(f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()]):,}')
347+
_logger.info(f'Model {safe_model_name(args.model)} created, FLOPS count:{flops.total():,}')
348+
# _logger.info(f'Model {safe_model_name(args.model)} Flops table\n{flop_count_table(flops)}')
338349

339350
_logger.info(f'Model: {model}')
340351
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
@@ -570,6 +581,25 @@ def main():
570581
f.write(args_text)
571582

572583
try:
584+
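# Freeze part of model.features by layer index: --freeze-bot freezes indices <= 36, --freeze-top freezes indices >= 36 (index 36 is frozen in both modes).
# NOTE(review): the --freeze-top / --freeze-bot help texts describe the opposite halves ("early" vs "late"); confirm which is intended.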
585+
if args.freeze_bot:
586+
for n,v in model.features.named_parameters():
587+
if int(n.split('.')[0]) <=36:
588+
v.requires_grad=False
589+
elif args.freeze_top:
590+
for n,v in model.features.named_parameters():
591+
if int(n.split('.')[0]) >=36:
592+
v.requires_grad=False
593+
594+
if args.freeze_top or args.freeze_bot:
595+
for n,v in model.features.named_parameters():
596+
print(f'{n}.requires_grad: {v.requires_grad}')
597+
598+
if args.use_avgs:
599+
checkpoint_avgs = torch.load(args.use_avgs,map_location='cuda')
600+
model.load_state_dict(checkpoint_avgs)
601+
print(f'avg model loaded')
602+
573603
for epoch in range(start_epoch, num_epochs):
574604
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
575605
loader_train.sampler.set_epoch(epoch)

0 commit comments

Comments
 (0)