Skip to content

Add pretrained weights on Chairs and Things for raft_large #5060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions references/optical_flow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Optical flow reference training scripts

This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.


### RAFT Large

The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
rest of the hyper-parameters are exactly the same as the original RAFT training
recipe from https://github.com/princeton-vl/RAFT.

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_chairs \
--model raft_large \
--train-dataset chairs \
--batch-size 2 \
--lr 0.0004 \
--weight-decay 0.0001 \
--num-steps 100000 \
--output-dir $chairs_dir
```

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model raft_large \
--train-dataset things \
--batch-size 2 \
--lr 0.000125 \
--weight-decay 0.0001 \
--num-steps 100000 \
--freeze-batch-norm \
--output-dir $things_dir\
--resume $chairs_dir/$name_chairs.pth
```


### Evaluation

```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
```

This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel. Results may vary slightly depending on the batch size and
the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`:

```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
32 changes: 27 additions & 5 deletions references/optical_flow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
from pathlib import Path

import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small

try:
from torchvision.prototype import models as PM
from torchvision.prototype.models import optical_flow as PMOF
except ImportError:
PM = PMOF = None


def get_train_dataset(stage, dataset_root):
Expand Down Expand Up @@ -125,6 +131,13 @@ def inner_loop(blob):

def validate(model, args):
val_datasets = args.val_dataset or []

if args.weights:
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = OpticalFlowPresetEval()

for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes so we need to individually pad them, we can't batch.
Expand All @@ -134,14 +147,14 @@ def validate(model, args):
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)

val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_validate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_validate(
model,
Expand Down Expand Up @@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
def main(args):
utils.setup_ddp(args)

model = raft_small() if args.small else raft_large()
if args.weights:
model = PMOF.__dict__[args.model](weights=args.weights)
else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

Expand Down Expand Up @@ -306,7 +323,12 @@ def get_args_parser(add_help=True):
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)

parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
parser.add_argument(
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
)
# TODO: resume, pretrained, and weights should be in an exclusive arg group
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")

parser.add_argument(
"--num_flow_updates",
Expand Down
11 changes: 9 additions & 2 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
)
def test_schema_meta_validation(model_fn):
classification_fields = ["size", "categories", "acc@1", "acc@5"]
Expand All @@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
"quantization": classification_fields + ["backend", "quantization", "unquantized"],
"segmentation": ["categories", "mIoU", "acc"],
"video": classification_fields,
"optical_flow": [],
}
module_name = model_fn.__module__.split(".")[-2]
fields = set(defaults["all"] + defaults[module_name])
Expand Down Expand Up @@ -201,13 +203,18 @@ def test_old_vs_new_factory(model_fn, dev):
if module_name == "detection":
x = [x]

if module_name == "optical_flow":
args = [x, x] # RAFT model requires img1, img2 as input
else:
args = [x]

# compare with new model builder parameterized in the old fashion way
try:
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
except ModuleNotFoundError:
pytest.skip(f"Model '{model_name}' not available in both modules.")
torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)


def test_smoke():
Expand Down
28 changes: 21 additions & 7 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.ops import ConvNormActivation

from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once
from ._utils import grid_sample, make_coords_grid, upsample_flow

Expand All @@ -19,6 +20,9 @@
)


_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once PR is merged I will upload this to manifold

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: all current models use model_urls



class ResidualBlock(nn.Module):
"""Slightly modified Residual block with extra relu and biases."""

Expand Down Expand Up @@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
hidden_state = torch.tanh(hidden_state)
context = F.relu(context)

coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)

flow_predictions = []
for _ in range(num_flow_updates):
Expand All @@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):

def _raft(
*,
arch=None,
pretrained=False,
progress=False,
# Feature encoder
feature_encoder_layers,
feature_encoder_block,
Expand Down Expand Up @@ -560,14 +567,19 @@ def _raft(
multiplier=0.25, # See comment in MaskPredictor about this
)

return RAFT(
model = RAFT(
feature_encoder=feature_encoder,
context_encoder=context_encoder,
corr_block=corr_block,
update_block=update_block,
mask_predictor=mask_predictor,
**kwargs, # not really needed, all params should be consumed by now
)
if pretrained:
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
model.load_state_dict(state_dict)

return model


def raft_large(*, pretrained=False, progress=True, **kwargs):
Expand All @@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.
"""

if pretrained:
raise ValueError("No checkpoint is available for raft_large")

return _raft(
arch="raft_large",
pretrained=pretrained,
progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
Expand Down Expand Up @@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.

"""

if pretrained:
raise ValueError("No checkpoint is available for raft_small")

return _raft(
arch="raft_small",
pretrained=pretrained,
progress=progress,
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
Expand Down
62 changes: 43 additions & 19 deletions torchvision/prototype/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock

# from torchvision.prototype.transforms import RaftEval
from torchvision.prototype.transforms import RaftEval
from torchvision.transforms.functional import InterpolationMode

from .._api import WeightsEnum

# from .._api import Weights
from .._api import Weights
from .._utils import handle_legacy_interface


Expand All @@ -22,17 +21,33 @@
)


_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}


class Raft_Large_Weights(WeightsEnum):
pass
# C_T_V1 = Weights(
# # Chairs + Things
# url="",
# transforms=RaftEval,
# meta={
# "recipe": "",
# "epe": -1234,
# },
# )
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-things.pth)
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
transforms=RaftEval,
meta={
**_COMMON_META,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 1.4411,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to rename one of them as the default epe? This will allow you to add the metric in the schema of meta-data for optical flow models. It's also worth considering introducing a dictionary entry in the meta-data that holds other epe values for for different datasets etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to rename one of them as the default epe?

Unfortunately no, because the rest of the weights will be trained on sintel, so reporting the epe on the trainset would not be relevant

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to have a dict or something else to properly keep track of the other metrics though - ultimately I think it would make sense to also have 1px, 3px etc. I think we'll have a better idea of what it should look like once the rest of the weights are available

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, no strong opinions. You could dump all the metrics in an epe dictionary. Then you would be able to include this on the schema. Up to you.

"sintel_train_finalpass_epe": 2.7894,
},
)

C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
transforms=RaftEval,
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_train_cleanpass_epe": 1.3822,
"sintel_train_finalpass_epe": 2.7161,
},
)

# C_T_SKHT_V1 = Weights(
# # Chairs + Things + Sintel fine-tuning, i.e.:
Expand All @@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum):
# },
# )

# default = C_T_V1
default = C_T_V2


class Raft_Small_Weights(WeightsEnum):
Expand All @@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum):
# default = C_T_V1


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Args:
weights(Raft_Large_weights, optinal): TODO not implemented yet
weights(Raft_Large_weights, optional): pretrained weights to use.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
Expand All @@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *

weights = Raft_Large_Weights.verify(weights)

return _raft(
model = _raft(
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
Expand All @@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
**kwargs,
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model


@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
Expand All @@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *

weights = Raft_Small_Weights.verify(weights)

return _raft(
model = _raft(
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
Expand All @@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
use_mask_predictor=False,
**kwargs,
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model