Additional SOTA ingredients on Classification Recipe #4493
Changes from all commits
references/classification/presets.py

@@ -9,21 +9,22 @@ def __init__(
         crop_size,
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
+        interpolation=InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
         random_erase_prob=0.0,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size)]
+        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment())
+                trans.append(autoaugment.RandAugment(interpolation=interpolation))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide())
+                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
             else:
                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy))
+                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
         trans.extend(
             [
                 transforms.PILToTensor(),

Reviewer comment: The change to interpolation here is non-BC, but I consider it a bug fix rather than the removal of a feature. In the previous recipe there was a mismatch between the interpolation used for resizing and the one used by the AutoAugment methods.
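To make the fix concrete, here is a minimal usage sketch, not part of the diff: it assumes presets.py is importable and uses illustrative values for the crop size, policy, and erase probability. The point is that a single InterpolationMode now drives both the resize and the AutoAugment transform.

```python
# Illustrative sketch only: the values below (176, "ta_wide", 0.1) are example
# hyper-parameters, not ones mandated by this PR.
from torchvision.transforms.functional import InterpolationMode

from presets import ClassificationPresetTrain  # references/classification/presets.py assumed on the path

train_transform = ClassificationPresetTrain(
    crop_size=176,
    interpolation=InterpolationMode.BICUBIC,  # now forwarded to RandomResizedCrop *and* the AA transform
    auto_augment_policy="ta_wide",
    random_erase_prob=0.1,
)
```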
references/classification/train.py

@@ -14,22 +14,20 @@
 from torchvision.transforms.functional import InterpolationMode


-def train_one_epoch(
-    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
-):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

     header = "Epoch: [{}]".format(epoch)
-    for image, target in metric_logger.log_every(data_loader, print_freq, header):
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)

         optimizer.zero_grad()
-        if amp:
+        if args.amp:
             with torch.cuda.amp.autocast():
                 loss = criterion(output, target)
             scaler.scale(loss).backward()

Reviewer comment: Though I'm not a huge fan of passing the whole args object into train_one_epoch.

Author reply: Indeed (see vision/references/classification/train.py, line 106 at e08c9e3). Concerning type hints/documentation, I think you are right; for some reason most of the string args don't define them. I've raised #4694 to improve this.
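As a rough sketch of the kind of annotation #4694 asks for, the new signature could be typed along these lines; the exact types are my assumption, not the final API.

```python
from typing import Optional

import torch
import torch.utils.data


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    epoch: int,
    args,  # argparse.Namespace carrying print_freq, amp, lr_warmup_epochs, model_ema_steps, ...
    model_ema: Optional[torch.nn.Module] = None,
    scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> None:
    ...  # body unchanged; only the annotations are illustrated here
```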
@@ -40,16 +38,19 @@ def train_one_epoch(
             loss.backward()
             optimizer.step()

+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
+            if epoch < args.lr_warmup_epochs:
+                # Reset ema buffer to keep copying weights during warmup period
+                model_ema.n_averaged.fill_(0)
+
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = image.shape[0]
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
         metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

-    if model_ema:
-        model_ema.update_parameters(model)
-

 def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()

Reviewer comments: The EMA update is moved from once per epoch to once every args.model_ema_steps iterations. During the warmup epochs the EMA buffer is reset so that it always copies the raw weights.
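The warmup reset above works because of how the averaged model behaves when its counter is zero. A minimal sketch, assuming only the standard torch.optim.swa_utils.AveragedModel behaviour (which torchvision's ExponentialMovingAverage builds on): while n_averaged == 0, update_parameters copies the source weights instead of averaging them.

```python
import torch
from torch.optim.swa_utils import AveragedModel

# Toy model and a simple 50/50 averaging function, purely for illustration.
model = torch.nn.Linear(4, 2)
ema = AveragedModel(model, avg_fn=lambda avg, new, num_averaged: 0.5 * avg + 0.5 * new)

ema.update_parameters(model)  # n_averaged was 0 -> plain copy of the weights
ema.n_averaged.fill_(0)       # what the warmup branch does on every EMA step
ema.update_parameters(model)  # copies again instead of averaging
```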
@@ -106,24 +107,8 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    resize_size, crop_size = 256, 224
-    interpolation = InterpolationMode.BILINEAR
-    if args.model == "inception_v3":
-        resize_size, crop_size = 342, 299
-    elif args.model.startswith("efficientnet_"):
-        sizes = {
-            "b0": (256, 224),
-            "b1": (256, 240),
-            "b2": (288, 288),
-            "b3": (320, 300),
-            "b4": (384, 380),
-            "b5": (456, 456),
-            "b6": (528, 528),
-            "b7": (600, 600),
-        }
-        e_type = args.model.replace("efficientnet_", "")
-        resize_size, crop_size = sizes[e_type]
-        interpolation = InterpolationMode.BICUBIC
+    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    interpolation = InterpolationMode(args.interpolation)

     print("Loading training data")
     st = time.time()

Reviewer comment: This removes the hardcoding of resize/crop sizes based on model names; they are now passed as parameters instead.
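For reference, the values that used to be hardcoded can still be reproduced through the new parameters. A sketch using the deleted efficientnet_b3 entry from the table above; the preset class is assumed importable as in presets.py.

```python
from torchvision.transforms.functional import InterpolationMode

from presets import ClassificationPresetEval  # references/classification/presets.py assumed on the path

# Values copied from the removed efficientnet_b3 entry: resize 320, crop 300, bicubic.
val_transform = ClassificationPresetEval(
    crop_size=300,
    resize_size=320,
    interpolation=InterpolationMode.BICUBIC,
)
```

On the command line this corresponds to passing --val-resize-size 320 --val-crop-size 300 --interpolation bicubic instead of relying on the old model-name lookup.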
@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args):
         dataset = torchvision.datasets.ImageFolder(
             traindir,
             presets.ClassificationPresetTrain(
-                crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
+                crop_size=train_crop_size,
+                interpolation=interpolation,
+                auto_augment_policy=auto_augment_policy,
+                random_erase_prob=random_erase_prob,
             ),
         )
         if args.cache_dataset:

Reviewer comment: Passing the right interpolation value to the train preset.
@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
+            presets.ClassificationPresetEval(
+                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+            ),
         )
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
@@ -224,26 +214,30 @@ def main(args):

     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(),
+            parameters,
             lr=args.lr,
             momentum=args.momentum,
             weight_decay=args.weight_decay,
             nesterov="nesterov" in opt_name,
         )
     elif opt_name == "rmsprop":
         optimizer = torch.optim.RMSprop(
-            model.parameters(),
-            lr=args.lr,
-            momentum=args.momentum,
-            weight_decay=args.weight_decay,
-            eps=0.0316,
-            alpha=0.9,
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
         )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
-        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

     scaler = torch.cuda.amp.GradScaler() if args.amp else None

Reviewer comments: The normalization-layer parameters are separated from the rest so that a different weight decay can be applied to them; this improves accuracy by 0.1-0.2 points. AdamW is added because it is necessary for training ViT.
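A small sketch of the parameter-group split shown in this hunk, using split_normalization_params the same way the diff does. The concrete weight-decay values are illustrative, and the return order (normalization parameters first, the rest second) follows the wd_groups ordering above.

```python
import torch
import torchvision
from torchvision.ops._utils import split_normalization_params

model = torchvision.models.resnet50()

# Two groups: normalization layers get their own (typically smaller) weight decay.
norm_params, other_params = split_normalization_params(model)
parameters = [
    {"params": norm_params, "weight_decay": 0.0},     # e.g. --norm-weight-decay 0.0
    {"params": other_params, "weight_decay": 2e-05},  # e.g. --weight-decay 2e-05
]
optimizer = torch.optim.SGD(parameters, lr=0.5, momentum=0.9)
```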
@@ -288,13 +282,23 @@ def main(args):

     model_ema = None
     if args.model_ema:
-        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
+        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+        #
+        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
+        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
+        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
+        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
+        alpha = 1.0 - args.model_ema_decay
+        alpha = min(1.0, alpha * adjust)
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

     if args.resume:
         checkpoint = torch.load(args.resume, map_location="cpu")
         model_without_ddp.load_state_dict(checkpoint["model"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
         if model_ema:
             model_ema.load_state_dict(checkpoint["model_ema"])

Reviewer comments: The EMA decay is now parameterized independently from the number of epochs. Loading the optimizer and LR-scheduler state only when not in test-only mode is a quality-of-life improvement that avoids annoying errors when you don't define all optimizer parameters during validation.
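To make the decay adjustment above concrete, here is a worked example with assumed values (8 GPUs, batch size 32 per GPU, --model-ema-steps 32, 600 epochs, --model-ema-decay 0.99998); these numbers are illustrative, not prescribed by the PR.

```python
world_size = 8          # assumed number of GPUs
batch_size = 32         # assumed per-GPU batch size
model_ema_steps = 32    # default of the new flag
epochs = 600            # assumed schedule length
model_ema_decay = 0.99998

adjust = world_size * batch_size * model_ema_steps / epochs  # ~13.65
alpha = 1.0 - model_ema_decay                                # 2e-05
alpha = min(1.0, alpha * adjust)                             # ~2.73e-04
effective_decay = 1.0 - alpha
print(f"{effective_decay:.6f}")                              # ~0.999727
```

In other words, shortening the run or updating the EMA more often automatically shifts the effective decay so that the overall amount of averaging stays roughly constant.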
@@ -303,18 +307,18 @@ def main(args):
         # We disable the cudnn benchmarking because it can noticeably affect the accuracy
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True

-        evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+        else:
+            evaluate(model, criterion, data_loader_test, device=device)
         return

     print("Start training")
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
-        )
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:

Reviewer comment: In test-only mode, choose which model to validate depending on the flag provided.
@@ -362,6 +366,12 @@ def get_args_parser(add_help=True):
         help="weight decay (default: 1e-4)",
         dest="weight_decay",
     )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
     )
@@ -415,15 +425,33 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
     )
+    parser.add_argument(
+        "--model-ema-steps",
+        type=int,
+        default=32,
+        help="the number of iterations that controls how often to update the EMA model (default: 32)",
+    )
     parser.add_argument(
         "--model-ema-decay",
         type=float,
-        default=0.9,
-        help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
+        default=0.99998,
+        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
     )
     parser.add_argument(
         "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
     )
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )

     return parser

Reviewer comment: The default EMA decay is reconfigured now that updates happen per iteration instead of per epoch.

Question: Is the default value of 0.99998 the one used most often?

Author reply: It's a good guess for ImageNet, considering the typical batch size on 8 GPUs. The reason for changing it so drastically is that we switch from one update per epoch to an update every X iterations (X=32, configurable).
Author comment: Since we removed the hardcoding of parameters based on model names, we now need to provide these extra parameters.

Reviewer comment: One thing on my mind, not related to this PR: could we also let users pass kwargs to the models through the command line, in addition to the train.py arguments? For example, when I train the ViT model, training from scratch and fine-tuning require two different heads; in this case I want to configure representation_size differently, and currently I need to manually change the Python defaults to reflect this. What do you think?

Author reply: We will probably need to introduce more parameters to be able to do this. We will do it to enable your work, but it's also part of the reason why the ArgumentParser is a poor solution. Hopefully it will be deprecated by the STL work you are preparing!
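Purely as an illustration of the idea floated above, and explicitly not part of this PR: extra model kwargs could be forwarded from the command line roughly as sketched below. The flag name --model-kwargs and the JSON encoding are my own assumptions, and num_classes is used only because it is a kwarg every torchvision classification builder accepts.

```python
import argparse
import json

import torchvision

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument(
    "--model-kwargs",
    default="{}",
    type=str,
    help="JSON dict of extra keyword arguments forwarded to the model builder",
)
args = parser.parse_args(["--model-kwargs", '{"num_classes": 10}'])

# Build the model the same way train.py does, forwarding the parsed kwargs.
model = torchvision.models.__dict__[args.model](**json.loads(args.model_kwargs))
print(model.fc)  # Linear(in_features=512, out_features=10, bias=True)
```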