
RAFT training reference Improvement #5590

Merged: 10 commits, Mar 15, 2022
references/optical_flow/train.py (79 additions, 43 deletions)

@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):


@torch.no_grad()
-def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
+def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with DDP, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size
+    device = torch.device(args.device)

model.eval()

-    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    else:
+        sampler = torch.utils.data.SequentialSampler(val_dataset)

val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
@@ -88,7 +93,7 @@ def inner_loop(blob):
image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]

-        image1, image2 = image1.cuda(), image2.cuda()
+        image1, image2 = image1.to(device), image2.to(device)

padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)
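For readers unfamiliar with the padder: RAFT predicts flow at 1/8 resolution, so inputs need H and W divisible by 8, and the padder rounds each image up before inference. A minimal sketch of what `utils.InputPadder` is assumed to do, following the convention of the original RAFT code (the class name and details below are illustrative, not the reference implementation):

```python
import torch.nn.functional as F

class InputPadderSketch:
    """Pads H and W up to multiples of 8, since RAFT downsamples by a factor of 8."""

    def __init__(self, shape, mode="sintel"):
        h, w = shape[-2:]
        pad_h = (8 - h % 8) % 8
        pad_w = (8 - w % 8) % 8
        if mode == "sintel":
            # split the padding evenly between top/bottom and left/right
            self._pad = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
        else:  # "kitti": put all vertical padding on one side
            self._pad = [pad_w // 2, pad_w - pad_w // 2, 0, pad_h]

    def pad(self, *images):
        # F.pad takes (left, right, top, bottom) for the last two dims
        return [F.pad(x, self._pad, mode="replicate") for x in images]
```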
@@ -115,21 +120,22 @@ def inner_loop(blob):
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size

-    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
-    print(
-        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
-        "Going to process the remaining samples individually, if any."
-    )
-
-    if args.rank == 0:  # we only need to process the rest on a single worker
-        for i in range(num_processed_samples, len(val_dataset)):
-            inner_loop(val_dataset[i])
-
-    logger.synchronize_between_processes()
+    if args.distributed:
+        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+        print(
+            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
+            "Going to process the remaining samples individually, if any."
+        )
+        if args.rank == 0:  # we only need to process the rest on a single worker
+            for i in range(num_processed_samples, len(val_dataset)):
+                inner_loop(val_dataset[i])
+
+        logger.synchronize_between_processes()
+
print(header, logger)
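The reason a leftover loop is needed at all: `drop_last=True` makes the `DistributedSampler` truncate the dataset so that every rank sees the same number of samples, silently skipping the tail. A worked example with assumed sizes:

```python
# Assumed sizes, for illustration only.
dataset_len, world_size, batch_size = 1041, 4, 2

per_rank = dataset_len // world_size     # 260 samples per rank (drop_last=True)
batch_processed = per_rank * world_size  # 1040 samples covered by the DDP loaders

leftover = dataset_len - batch_processed # 1 sample left over
# Rank 0 then runs inner_loop() on indices 1040..len-1 so that the reported
# metrics cover the full validation set.
print(batch_processed, leftover)         # -> 1040 1
```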


-def validate(model, args):
+def evaluate(model, args):
val_datasets = args.val_dataset or []

if args.prototype:
@@ -145,21 +151,21 @@ def validate(model, args):
if name == "kitti":
# Kitti has different image sizes so we need to individually pad them, we can't batch.
# see comment in InputPadder
-            if args.batch_size != 1 and args.rank == 0:
+            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)

val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
-            _validate(
+            _evaluate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
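The batch-size-1 restriction exists because KITTI frames come in slightly different resolutions per sample, and the default collate cannot stack unequal tensors. A quick demonstration with assumed KITTI-like shapes:

```python
import torch

a = torch.zeros(3, 370, 1224)  # assumed KITTI-like frame sizes
b = torch.zeros(3, 375, 1242)
try:
    torch.stack([a, b])        # what a batch-size-2 collate would attempt
except RuntimeError as err:
    print("cannot batch:", err)  # stack expects each tensor to be equal size
```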
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
-                _validate(
+                _evaluate(
model,
args,
val_dataset,
@@ -172,11 +178,12 @@


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
+    device = torch.device(args.device)
for data_blob in logger.log_every(train_loader):

optimizer.zero_grad()

-        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
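For context on the loss being called here: RAFT emits one flow estimate per recurrent update, and the reference loss weights later estimates more heavily. A simplified sketch of that exponentially weighted sequence loss, following the RAFT paper; the actual `utils.sequence_loss` may differ in details such as flow-magnitude clipping, so treat this as an assumption about its shape:

```python
import torch

def sequence_loss_sketch(flow_preds, flow_gt, valid_flow_mask, gamma=0.8):
    """L1 loss over all iterative predictions; later iterations weigh more."""
    n = len(flow_preds)
    loss = 0.0
    for i, pred in enumerate(flow_preds):
        weight = gamma ** (n - i - 1)             # last prediction gets weight 1.0
        diff = (pred - flow_gt).abs().sum(dim=1)  # per-pixel L1 over the 2 flow channels
        if valid_flow_mask is not None:
            diff = diff[valid_flow_mask]          # drop pixels without ground truth
        loss = loss + weight * diff.mean()
    return loss
```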
@@ -200,36 +207,68 @@ def main(args):
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
utils.setup_ddp(args)

if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)

if args.prototype:
model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

-    model = model.to(args.local_rank)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+    if args.distributed:
+        model = model.to(args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model
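Keeping a `model_without_ddp` handle is the standard idiom for checkpointing under DDP: the wrapper nests the network under `.module`, so its `state_dict()` keys carry a `module.` prefix that a plain (non-distributed) model cannot load. A small illustration:

```python
import torch

net = torch.nn.Linear(4, 2)
print(list(net.state_dict()))  # ['weight', 'bias']

# Inside a distributed run, DistributedDataParallel(net).state_dict() would
# instead yield ['module.weight', 'module.bias']. Saving the inner module's
# state_dict keeps checkpoints loadable in both modes.
```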

if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])

if args.train_dataset is None:
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
-        validate(model, args)
+        evaluate(model, args)
return

print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
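Two things are worth spelling out here. `steps_per_epoch` must count optimizer steps, and one step consumes `world_size * batch_size` samples across all ranks; and because `OneCycleLR` plans its whole trajectory up front, resuming has to restore the scheduler state too, which is exactly what the block above does. A worked example with assumed numbers:

```python
from math import ceil

# Assumed sizes, for illustration only.
len_train_dataset, world_size, batch_size, epochs = 22_872, 8, 2, 10

steps_per_epoch = ceil(len_train_dataset / (world_size * batch_size))  # ceil(22872 / 16) = 1430
total_steps = steps_per_epoch * epochs                                 # OneCycleLR plans 14300 steps
print(steps_per_epoch, total_steps)                                    # -> 1430 14300
```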

torch.backends.cudnn.benchmark = True

model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)

-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
-
-    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    else:
+        sampler = torch.utils.data.RandomSampler(train_dataset)
+
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
@@ -238,25 +277,15 @@ def main(args):
num_workers=args.num_workers,
)

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )

logger = utils.MetricLogger()

done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
print(f"EPOCH {current_epoch}")
-        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
+        if args.distributed:
+            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
+            sampler.set_epoch(current_epoch)
+
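Why `set_epoch` matters only in distributed mode: `DistributedSampler` draws its shuffling permutation from a generator seeded with `seed + epoch`, so without the call every epoch replays the same order (a `RandomSampler` reshuffles on its own). Roughly what happens inside the sampler, simplified from PyTorch's implementation:

```python
import torch

# Simplified from torch.utils.data.distributed.DistributedSampler.__iter__:
seed, epoch, dataset_len = 0, 3, 10
g = torch.Generator()
g.manual_seed(seed + epoch)  # set_epoch(3) changes this seed, hence the order
indices = torch.randperm(dataset_len, generator=g).tolist()
print(indices)               # a different permutation for every epoch value
```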
train_one_epoch(
model=model,
optimizer=optimizer,
@@ -269,13 +298,19 @@ def main(args):
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {current_epoch} done. ", logger)

-        if args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+        if not args.distributed or args.rank == 0:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
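Saving a dict instead of a bare `state_dict` is what resolves the old TODO: the optimizer, scheduler, and epoch counter all travel with the weights, and the unsuffixed `{args.name}.pth` always mirrors the newest epoch-suffixed file. A sketch of inspecting such a checkpoint (the file name is hypothetical):

```python
import torch

checkpoint = torch.load("raft_experiment.pth", map_location="cpu")  # hypothetical file
print(sorted(checkpoint))   # ['args', 'epoch', 'model', 'optimizer', 'scheduler']
print(checkpoint["epoch"])  # last completed epoch; training resumes at epoch + 1
```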

if current_epoch % args.val_freq == 0 or done:
-            validate(model, args)
+            evaluate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")

return parser

references/optical_flow/utils.py (6 additions, 1 deletion)

@@ -256,7 +256,12 @@ def setup_ddp(args):
# if we're here, the script was called by run_with_submitit.py
args.local_rank = args.gpu
else:
raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
print("Not using distributed mode!")
args.distributed = False
args.world_size = 1
return

args.distributed = True

_redefine_print(is_main=(args.rank == 0))
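For context, `setup_ddp` picks a branch based on the environment variables that torchrun (or SLURM via run_with_submitit.py) exports; the new fallback turns a bare `python train.py` into a valid single-process run instead of an error. A sketch of the detection logic this diff extends, assumed to mirror the parts of `setup_ddp` not shown here:

```python
import os

def detect_distributed_sketch(args):
    # torchrun exports RANK / WORLD_SIZE / LOCAL_RANK for every worker.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.local_rank = int(os.environ["LOCAL_RANK"])
        args.distributed = True
    else:
        # plain `python train.py`: fall back to single-process mode
        print("Not using distributed mode!")
        args.distributed = False
        args.world_size = 1
```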
