Additional SOTA ingredients on Classification Recipe #4493
The diff below is against references/classification/train.py (removed lines are prefixed with -, added lines with +).
@@ -13,22 +13,20 @@
 from torchvision.transforms.functional import InterpolationMode


-def train_one_epoch(
-    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
-):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

     header = "Epoch: [{}]".format(epoch)
-    for image, target in metric_logger.log_every(data_loader, print_freq, header):
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)

         optimizer.zero_grad()
-        if amp:
+        if args.amp:
             with torch.cuda.amp.autocast():
                 loss = criterion(output, target)
             scaler.scale(loss).backward()

@@ -39,16 +37,19 @@ def train_one_epoch(
             loss.backward()
             optimizer.step()

+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
+            if epoch < args.lr_warmup_epochs:
+                # Reset ema buffer to keep copying weights during warmup period
+                model_ema.n_averaged.fill_(0)
+
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = image.shape[0]
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
         metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

-    if model_ema:
-        model_ema.update_parameters(model)
-

 def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()

Review comment: Moving the EMA updates to a per-iteration level instead of per epoch.

Review comment: Always copy the weights during warmup.
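The EMA wrapper referenced above is not part of this diff. As rough orientation, a minimal sketch of such a helper, assuming it is built on torch.optim.swa_utils.AveragedModel (the actual utils.ExponentialMovingAverage in the references may differ in details), could look like this:

    import torch


    class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
        """Keeps an exponential moving average of a model's parameters."""

        def __init__(self, model, decay, device="cpu"):
            def ema_avg(avg_param, param, num_averaged):
                # new_avg = decay * old_avg + (1 - decay) * current_param
                return decay * avg_param + (1.0 - decay) * param

            super().__init__(model, device=device, avg_fn=ema_avg)

Under this reading, AveragedModel copies the raw weights whenever its n_averaged buffer is 0, so filling the buffer with 0 during the warmup epochs (as the added lines do) turns every EMA update into a plain copy until warmup ends.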
@@ -87,23 +88,25 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    resize_size, crop_size = 256, 224
+    val_resize_size, val_crop_size, train_crop_size = 256, 224, 224
     interpolation = InterpolationMode.BILINEAR
     if args.model == "inception_v3":
-        resize_size, crop_size = 342, 299
+        val_resize_size, val_crop_size, train_crop_size = 342, 299, 299
+    elif args.model == "resnet50":
+        val_resize_size, val_crop_size, train_crop_size = 256, 224, 192
     elif args.model.startswith("efficientnet_"):
         sizes = {
-            "b0": (256, 224),
-            "b1": (256, 240),
-            "b2": (288, 288),
-            "b3": (320, 300),
-            "b4": (384, 380),
-            "b5": (456, 456),
-            "b6": (528, 528),
-            "b7": (600, 600),
+            "b0": (256, 224, 224),
+            "b1": (256, 240, 240),
+            "b2": (288, 288, 288),
+            "b3": (320, 300, 300),
+            "b4": (384, 380, 380),
+            "b5": (456, 456, 456),
+            "b6": (528, 528, 528),
+            "b7": (600, 600, 600),
         }
         e_type = args.model.replace("efficientnet_", "")
-        resize_size, crop_size = sizes[e_type]
+        val_resize_size, val_crop_size, train_crop_size = sizes[e_type]
         interpolation = InterpolationMode.BICUBIC

     print("Loading training data")

@@ -119,7 +122,10 @@ def load_data(traindir, valdir, args):
     dataset = torchvision.datasets.ImageFolder(
         traindir,
         presets.ClassificationPresetTrain(
-            crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
+            crop_size=train_crop_size,
+            interpolation=interpolation,
+            auto_augment_policy=auto_augment_policy,
+            random_erase_prob=random_erase_prob,
         ),
     )
     if args.cache_dataset:

Review comment: Passing the right interpolation to the training preset.
@@ -137,7 +143,9 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
+            presets.ClassificationPresetEval(
+                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+            ),
         )
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))

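To make clear why the separate val_resize_size/val_crop_size values and the interpolation argument matter, here is a sketch of what an evaluation preset along these lines typically does; the real presets.ClassificationPresetEval lives in references/classification/presets.py and may differ in details such as the normalization constants:

    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode


    class ClassificationPresetEval:
        def __init__(
            self,
            crop_size,
            resize_size=256,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            interpolation=InterpolationMode.BILINEAR,
        ):
            # Resize the shorter side, center-crop, convert to tensor and normalize.
            self.transforms = transforms.Compose(
                [
                    transforms.Resize(resize_size, interpolation=interpolation),
                    transforms.CenterCrop(crop_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=mean, std=std),
                ]
            )

        def __call__(self, img):
            return self.transforms(img)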
@@ -201,26 +209,30 @@ def main(args):

     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(),
+            parameters,
             lr=args.lr,
             momentum=args.momentum,
             weight_decay=args.weight_decay,
             nesterov="nesterov" in opt_name,
         )
     elif opt_name == "rmsprop":
         optimizer = torch.optim.RMSprop(
-            model.parameters(),
-            lr=args.lr,
-            momentum=args.momentum,
-            weight_decay=args.weight_decay,
-            eps=0.0316,
-            alpha=0.9,
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
         )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
-        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

     scaler = torch.cuda.amp.GradScaler() if args.amp else None

Review comment: Separate the normalization params (BatchNorm etc.) from the rest so that we can apply a different weight decay to them. Improves accuracy by 0.1-0.2 points.

Review comment: Adding AdamW, which is necessary for training ViT.
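For reference, torchvision.ops._utils.split_normalization_params is a private helper; if it is unavailable in a given version, an equivalent split can be sketched by grouping parameters per leaf module (an assumption-laden sketch, not the library implementation):

    import torch.nn as nn


    def split_normalization_params(model):
        # Put the parameters of normalization layers in one group and everything else
        # in another, so a different (often zero) weight decay can be applied to them.
        norm_classes = (nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm)
        norm_params, other_params = [], []
        for module in model.modules():
            if next(module.children(), None) is not None:
                continue  # only visit leaf modules so each parameter is counted once
            params = [p for p in module.parameters() if p.requires_grad]
            if isinstance(module, norm_classes):
                norm_params.extend(params)
            else:
                other_params.extend(params)
        return norm_params, other_params

The two returned groups can then be paired with [args.norm_weight_decay, args.weight_decay] exactly as the diff does.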
@@ -265,29 +277,40 @@ def main(args):

     model_ema = None
     if args.model_ema:
-        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
+        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+        #
+        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
+        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
+        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
+        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
+        alpha = 1.0 - args.model_ema_decay
+        alpha = min(1.0, alpha * adjust)
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

     if args.resume:
         checkpoint = torch.load(args.resume, map_location="cpu")
         model_without_ddp.load_state_dict(checkpoint["model"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
         if model_ema:
             model_ema.load_state_dict(checkpoint["model_ema"])

     if args.test_only:
-        evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+        else:
+            evaluate(model, criterion, data_loader_test, device=device)
         return

     print("Start training")
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
-        )
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:

Review comment: Parameterize the EMA decay independently of the number of epochs.

Review comment: Quality-of-life improvement that avoids the annoying error messages when not all optimizer parameters are defined during validation.

Review comment: Choose which model to validate depending on the flag provided.
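As a concrete illustration of the decay adjustment, with assumed (not PR-specified) values of 8 GPUs, a per-GPU batch size of 128, the default 32 EMA steps and 600 epochs:

    # Illustrative numbers only; none of these values are prescribed by the PR.
    world_size, batch_size, model_ema_steps, epochs = 8, 128, 32, 600
    model_ema_decay = 0.99998

    adjust = world_size * batch_size * model_ema_steps / epochs  # ~54.6
    alpha = min(1.0, (1.0 - model_ema_decay) * adjust)           # ~0.0011
    print(f"effective per-update decay: {1 - alpha:.5f}")        # ~0.99891

Doubling the number of epochs halves adjust and therefore alpha, so the EMA averaging window grows in step with the total number of EMA updates, which is what makes the decay behave consistently across setups.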
@@ -335,6 +358,12 @@ def get_args_parser(add_help=True):
         help="weight decay (default: 1e-4)",
         dest="weight_decay",
     )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
     )

@@ -388,11 +417,17 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
     )
+    parser.add_argument(
+        "--model-ema-steps",
+        type=int,
+        default=32,
+        help="the number of iterations that controls how often to update the EMA model (default: 32)",
+    )
     parser.add_argument(
         "--model-ema-decay",
         type=float,
-        default=0.9,
-        help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
+        default=0.99998,
+        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
     )

     return parser

Review comment: Reconfiguring the default EMA decay value now that we update per iteration instead of per epoch.

Question: Is this default value of 0.99998 the one used most often?

Reply: It's a good guess for ImageNet, considering the typical batch size for 8 GPUs. The reason for changing it so drastically is that we switch from one update per epoch to an update every X iterations (X=32, configurable).
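To give a feel for what the new 0.99998 default means in practice, here is a back-of-the-envelope calculation with assumed values (the ImageNet-1k size and a global batch size of 1024 are illustrative assumptions, not part of the PR):

    # Assumed setup: ~1.28M training images, global batch size 1024,
    # EMA update every 32 iterations, 600 epochs, decay 0.99998.
    dataset_size, global_batch, ema_steps, epochs = 1_281_167, 1024, 32, 600
    decay = 0.99998

    iters_per_epoch = dataset_size // global_batch        # ~1251
    ema_updates_per_epoch = iters_per_epoch // ema_steps  # ~39

    # global_batch plays the role of world_size * batch_size in the diff.
    adjust = global_batch * ema_steps / epochs            # ~54.6
    alpha = min(1.0, (1.0 - decay) * adjust)              # ~0.0011
    per_epoch_decay = (1.0 - alpha) ** ema_updates_per_epoch
    print(f"{per_epoch_decay:.3f}")                       # ~0.958

Under these assumptions the per-iteration scheme applies a comparable order of smoothing per epoch to the old once-per-epoch update with decay 0.9, which is why the raw decay value had to change so drastically.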
Review comment: Though I'm not a huge fan of passing the whole args object to a single method (it's not clear which arguments the function actually needs), I can see you do this just to reduce the number of arguments. In the future we might want to add type hints for all the args used in this script, and also some documentation.

Reply: Indeed, args is passed to reduce the number of parameters (merging 4 into 1). The same pattern is already used elsewhere in the script (e.g. vision/references/classification/train.py, line 106 at commit e08c9e3), so I just follow it. Concerning type hints/documentation, I think you are right. For some reason most of the string args don't define it. I've raised #4694 to improve this.