
Additional SOTA ingredients on Classification Recipe #4493

Merged: 27 commits, Oct 22, 2021.
Changes shown from 15 of the 27 commits.

Commits:
063ca56
Update EMA every X iters.
datumbox Sep 28, 2021
02b4d42
Adding AdamW optimizer.
datumbox Sep 28, 2021
33a90f7
Adjusting EMA decay scheme.
datumbox Sep 28, 2021
cfdeede
Support custom weight decay for Normalization layers.
datumbox Sep 28, 2021
7ecc6d8
Fix indentation bug.
datumbox Sep 28, 2021
0563f9e
Change EMA adjustment.
datumbox Sep 29, 2021
764fe02
Merge branch 'main' into references/optimizations
datumbox Sep 30, 2021
19e7d49
Merge branch 'main' into references/optimizations
datumbox Oct 1, 2021
d188ee0
Quality of life changes to facilitate testing
datumbox Oct 4, 2021
a630986
Merge branch 'main' into references/optimizations
datumbox Oct 5, 2021
6655dac
ufmt format
datumbox Oct 5, 2021
dc0edb9
Fixing imports.
datumbox Oct 5, 2021
e4a098f
Merge branch 'main' into references/optimizations
datumbox Oct 7, 2021
2e93296
Adding FixRes improvement.
datumbox Oct 8, 2021
dadb2f5
Merge branch 'main' into references/optimizations
datumbox Oct 8, 2021
6859fa2
Support EMA in store_model_weights.
datumbox Oct 13, 2021
8a9e1a8
Merge branch 'main' into references/optimizations
datumbox Oct 13, 2021
17eaf48
Merge branch 'main' into references/optimizations
datumbox Oct 14, 2021
950636e
Adding interpolation values.
datumbox Oct 15, 2021
9a6a443
Change train_crop_size.
datumbox Oct 17, 2021
2ce484a
Merge branch 'main' into references/optimizations
datumbox Oct 17, 2021
e699eca
Add interpolation option.
datumbox Oct 17, 2021
d861b33
Merge branch 'main' into references/optimizations
datumbox Oct 21, 2021
9ee69c4
Removing hardcoded interpolation and sizes from the scripts.
datumbox Oct 21, 2021
bc5a2bd
Fixing linter.
datumbox Oct 21, 2021
14a3323
Incorporating feedback from code review.
datumbox Oct 21, 2021
c3c65d2
Merge branch 'main' into references/optimizations
datumbox Oct 22, 2021
3 changes: 2 additions & 1 deletion references/classification/presets.py
@@ -9,11 +9,12 @@ def __init__(
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size)]
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
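For illustration, a minimal example of exercising the new interpolation argument of the training preset (the values are arbitrary examples, not the recipe's defaults):

from torchvision.transforms.functional import InterpolationMode

import presets  # references/classification/presets.py

# e.g. bicubic resizing, as used for the EfficientNets in this recipe.
train_transform = presets.ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BICUBIC,
)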
111 changes: 73 additions & 38 deletions references/classification/train.py
@@ -13,22 +13,20 @@
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
yiwen-song (Contributor), Oct 21, 2021:

I'm not a huge fan of passing the whole args to a single method (it's not clear what is actually needed by this function), but I can see you do this just to reduce the number of arguments.
In the future we might want to add type hints for all the args used in this script, and also some documentation.

datumbox (Contributor, Author):

Indeed, args is passed to reduce the number of parameters (merging 4 into 1). The same pattern is already used in other places of the script, such as:

def load_data(traindir, valdir, args):

Concerning type hints/documentation, I think you are right. For some reason most of the script args don't define them. I've raised a new issue, #4694, to improve this.
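For illustration, a hypothetical type-hinted version of the new signature (the hints are an assumption, not part of this PR):

import argparse
from typing import Optional

import torch
from torch.utils.data import DataLoader


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: DataLoader,
    device: torch.device,
    epoch: int,
    args: argparse.Namespace,
    model_ema: Optional[torch.optim.swa_utils.AveragedModel] = None,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
    ...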

model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

header = "Epoch: [{}]".format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
output = model(image)

optimizer.zero_grad()
if amp:
if args.amp:
with torch.cuda.amp.autocast():
loss = criterion(output, target)
scaler.scale(loss).backward()
@@ -39,16 +37,19 @@ def train_one_epoch(
loss.backward()
optimizer.step()

if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
datumbox (Contributor, Author):

Moving the EMA updates to a per-iteration level rather than per epoch.

if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)
datumbox (Contributor, Author):

Always copy the weights during warmup.
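For context, a rough sketch of why resetting n_averaged forces a copy. This assumes the semantics of torch.optim.swa_utils.AveragedModel (which the reference ExponentialMovingAverage builds on); it is a simplification, not the actual implementation:

def ema_update(ema_model, model, decay):
    # Simplified view of AveragedModel.update_parameters with an EMA average function.
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        if ema_model.n_averaged == 0:
            # First update, or right after the warmup reset: plain copy of the raw weights.
            p_ema.detach().copy_(p)
        else:
            # Regular EMA step.
            p_ema.detach().mul_(decay).add_(p, alpha=1.0 - decay)
    ema_model.n_averaged += 1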


acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

if model_ema:
model_ema.update_parameters(model)


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
model.eval()
@@ -87,23 +88,25 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = 256, 224
val_resize_size, val_crop_size, train_crop_size = 256, 224, 224
interpolation = InterpolationMode.BILINEAR
if args.model == "inception_v3":
resize_size, crop_size = 342, 299
val_resize_size, val_crop_size, train_crop_size = 342, 299, 299
elif args.model == "resnet50":
val_resize_size, val_crop_size, train_crop_size = 256, 224, 192
elif args.model.startswith("efficientnet_"):
sizes = {
"b0": (256, 224),
"b1": (256, 240),
"b2": (288, 288),
"b3": (320, 300),
"b4": (384, 380),
"b5": (456, 456),
"b6": (528, 528),
"b7": (600, 600),
"b0": (256, 224, 224),
"b1": (256, 240, 240),
"b2": (288, 288, 288),
"b3": (320, 300, 300),
"b4": (384, 380, 380),
"b5": (456, 456, 456),
"b6": (528, 528, 528),
"b7": (600, 600, 600),
}
e_type = args.model.replace("efficientnet_", "")
resize_size, crop_size = sizes[e_type]
val_resize_size, val_crop_size, train_crop_size = sizes[e_type]
interpolation = InterpolationMode.BICUBIC

print("Loading training data")
@@ -119,7 +122,10 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
crop_size=train_crop_size,
interpolation=interpolation,
datumbox (Contributor, Author):

Passing the right interpolation value fixes the discrepancy bug discussed above.

auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
),
)
if args.cache_dataset:
@@ -137,7 +143,9 @@ def load_data(traindir, valdir, args):
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
),
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
@@ -201,26 +209,30 @@ def main(args):

criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

if args.norm_weight_decay is None:
parameters = model.parameters()
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
datumbox (Contributor, Author):

Separate the normalization params (e.g. BatchNorm) from the rest so that we can apply a different weight decay to them. Improves accuracy by 0.1-0.2 points.


opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
model.parameters(),
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "rmsprop":
optimizer = torch.optim.RMSprop(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
eps=0.0316,
alpha=0.9,
parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
datumbox (Contributor, Author):

Adding AdamW, which is necessary for training ViT.

else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if args.amp else None

@@ -265,29 +277,40 @@ def main(args):

model_ema = None
if args.model_ema:
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
# Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
#
# total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
# We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
# adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
alpha = 1.0 - args.model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
datumbox (Contributor, Author):

Parameterize the EMA decay independently from the number of epochs.
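As a worked example of the adjustment above (the numbers are illustrative, not the recipe's actual settings):

# Illustrative values only: 8 GPUs, per-GPU batch size 128, EMA update every 32 iterations, 600 epochs.
world_size, batch_size, model_ema_steps, epochs = 8, 128, 32, 600
model_ema_decay = 0.99998

adjust = world_size * batch_size * model_ema_steps / epochs   # ~54.6
alpha = min(1.0, (1.0 - model_ema_decay) * adjust)            # ~0.00109
effective_decay = 1.0 - alpha                                 # ~0.99891, passed to ExponentialMovingAverage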


if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
if not args.test_only:
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
datumbox (Contributor, Author):

Quality-of-life improvement to avoid the super annoying error messages when you run validation only and haven't defined all the optimizer params.

args.start_epoch = checkpoint["epoch"] + 1
if model_ema:
model_ema.load_state_dict(checkpoint["model_ema"])

if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
else:
evaluate(model, criterion, data_loader_test, device=device)
datumbox (Contributor, Author):

Choose which model to validate depending on the flag provided.

return

print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
@@ -335,6 +358,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
)
@@ -388,11 +417,17 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
)
parser.add_argument(
"--model-ema-steps",
type=int,
default=32,
help="the number of iterations that controls how often to update the EMA model (default: 32)",
)
parser.add_argument(
"--model-ema-decay",
type=float,
default=0.9,
help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
default=0.99998,
help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
datumbox (Contributor, Author):

Reconfiguring the default EMA decay value now that we update per iteration instead of per epoch.

Contributor:

n00b q: Is this default value 0.99998 used most often?

datumbox (Contributor, Author):

It's a good guess for ImageNet, considering the typical batch size for 8 GPUs. The reason for changing it so drastically is that we switch from an update per epoch to an update every X iterations (X=32, configurable).

)
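To relate the new per-iteration default back to the old per-epoch behaviour, a rough back-of-the-envelope calculation (the dataset and batch sizes below are illustrative assumptions):

# ImageNet has ~1.28M training images; assume 8 GPUs x per-GPU batch size 128 and an EMA update every 32 iterations.
iters_per_epoch = 1_281_167 // (8 * 128)                 # ~1251 iterations per epoch
ema_updates_per_epoch = iters_per_epoch // 32            # ~39 EMA updates per epoch
per_epoch_equivalent = 0.99998 ** ema_updates_per_epoch  # ~0.9992, i.e. comparable to a per-epoch decay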

return parser
14 changes: 12 additions & 2 deletions test/test_ops.py
@@ -9,10 +9,10 @@
import torch
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image
from torch import Tensor
from torch import nn, Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import ops
from torchvision import models, ops


class RoIOpTester(ABC):
@@ -1176,5 +1176,15 @@ def test_stochastic_depth(self, mode, p):
assert p_value > 0.0001


class TestUtils:
@pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
def test_split_normalization_params(self, norm_layer):
model = models.mobilenet_v3_large(norm_layer=norm_layer)
params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])

assert len(params[0]) == 92
assert len(params[1]) == 82


if __name__ == "__main__":
pytest.main([__file__])
29 changes: 27 additions & 2 deletions torchvision/ops/_utils.py
@@ -1,7 +1,7 @@
from typing import List, Union
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch import nn, Tensor


def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
@@ -36,3 +36,28 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
else:
assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]"
return


def split_normalization_params(
model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

for t in norm_classes:
if not issubclass(t, nn.Module):
raise ValueError(f"Class {t} is not a subclass of nn.Module.")

classes = tuple(norm_classes)

norm_params = []
other_params = []
for module in model.modules():
if next(module.children(), None):
other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
elif isinstance(module, classes):
norm_params.extend(p for p in module.parameters() if p.requires_grad)
else:
other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params
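A brief usage sketch of the new helper, mirroring how train.py above consumes it (the hyper-parameter values are illustrative):

import torch
import torchvision
from torchvision.ops._utils import split_normalization_params

model = torchvision.models.resnet50()
norm_params, other_params = split_normalization_params(model)

# Apply zero weight decay to normalization parameters and 1e-4 to everything else (illustrative values).
param_groups = [
    {"params": norm_params, "weight_decay": 0.0},
    {"params": other_params, "weight_decay": 1e-4},
]
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)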