Skip to content

Commit a56c949

Browse files
authored
Merge branch 'main' into bugfix/broken_import
2 parents dc71339 + 849d02b commit a56c949

File tree

5 files changed

+157
-33
lines changed

5 files changed

+157
-33
lines changed

references/optical_flow/README.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Optical flow reference training scripts
2+
3+
This folder contains reference training scripts for optical flow.
4+
They serve as a log of how to train specific models, so as to provide baseline
5+
training and evaluation scripts to quickly bootstrap research.
6+
7+
8+
### RAFT Large
9+
10+
The RAFT large model was trained on Flying Chairs and then on Flying Things.
11+
Both training stages used 8 A100 GPUs and a per-GPU batch size of 2 (so the effective batch size is 16). The
12+
rest of the hyper-parameters are exactly the same as the original RAFT training
13+
recipe from https://github.com/princeton-vl/RAFT.
14+
15+
```
16+
torchrun --nproc_per_node 8 --nnodes 1 train.py \
17+
--dataset-root $dataset_root \
18+
--name $name_chairs \
19+
--model raft_large \
20+
--train-dataset chairs \
21+
--batch-size 2 \
22+
--lr 0.0004 \
23+
--weight-decay 0.0001 \
24+
--num-steps 100000 \
25+
--output-dir $chairs_dir
26+
```
27+
28+
```
29+
torchrun --nproc_per_node 8 --nnodes 1 train.py \
30+
--dataset-root $dataset_root \
31+
--name $name_things \
32+
--model raft_large \
33+
--train-dataset things \
34+
--batch-size 2 \
35+
--lr 0.000125 \
36+
--weight-decay 0.0001 \
37+
--num-steps 100000 \
38+
--freeze-batch-norm \
39+
--output-dir $things_dir \
40+
--resume $chairs_dir/$name_chairs.pth
41+
```
42+
43+
44+
### Evaluation
45+
46+
```
47+
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
48+
```
49+
50+
This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
51+
final pass of Sintel. Results may vary slightly depending on the batch size and
52+
the number of GPUs. For the most accurate results, use 1 GPU and `--batch-size 1`:
53+
54+
```
55+
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
56+
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
57+
```

references/optical_flow/train.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@
33
from pathlib import Path
44

55
import torch
6+
import torchvision.models.optical_flow
67
import utils
78
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
89
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
9-
from torchvision.models.optical_flow import raft_large, raft_small
10+
11+
try:
12+
from torchvision.prototype import models as PM
13+
from torchvision.prototype.models import optical_flow as PMOF
14+
except ImportError:
15+
PM = PMOF = None
1016

1117

1218
def get_train_dataset(stage, dataset_root):
@@ -125,6 +131,13 @@ def inner_loop(blob):
125131

126132
def validate(model, args):
127133
val_datasets = args.val_dataset or []
134+
135+
if args.weights:
136+
weights = PM.get_weight(args.weights)
137+
preprocessing = weights.transforms()
138+
else:
139+
preprocessing = OpticalFlowPresetEval()
140+
128141
for name in val_datasets:
129142
if name == "kitti":
130143
# Kitti has different image sizes so we need to individually pad them, we can't batch.
@@ -134,14 +147,14 @@ def validate(model, args):
134147
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
135148
)
136149

137-
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
150+
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
138151
_validate(
139152
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
140153
)
141154
elif name == "sintel":
142155
for pass_name in ("clean", "final"):
143156
val_dataset = Sintel(
144-
root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
157+
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
145158
)
146159
_validate(
147160
model,
@@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
187200
def main(args):
188201
utils.setup_ddp(args)
189202

190-
model = raft_small() if args.small else raft_large()
203+
if args.weights:
204+
model = PMOF.__dict__[args.model](weights=args.weights)
205+
else:
206+
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
207+
191208
model = model.to(args.local_rank)
192209
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
193210

@@ -306,7 +323,12 @@ def get_args_parser(add_help=True):
306323
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
307324
)
308325

309-
parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
326+
parser.add_argument(
327+
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
328+
)
329+
# TODO: resume, pretrained, and weights should be in an exclusive arg group
330+
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
331+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
310332

311333
parser.add_argument(
312334
"--num_flow_updates",

test/test_prototype_models.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
9191
+ TM.get_models_from_module(models.detection)
9292
+ TM.get_models_from_module(models.quantization)
9393
+ TM.get_models_from_module(models.segmentation)
94-
+ TM.get_models_from_module(models.video),
94+
+ TM.get_models_from_module(models.video)
95+
+ TM.get_models_from_module(models.optical_flow),
9596
)
9697
def test_schema_meta_validation(model_fn):
9798
classification_fields = ["size", "categories", "acc@1", "acc@5"]
@@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
102103
"quantization": classification_fields + ["backend", "quantization", "unquantized"],
103104
"segmentation": ["categories", "mIoU", "acc"],
104105
"video": classification_fields,
106+
"optical_flow": [],
105107
}
106108
module_name = model_fn.__module__.split(".")[-2]
107109
fields = set(defaults["all"] + defaults[module_name])
@@ -201,13 +203,18 @@ def test_old_vs_new_factory(model_fn, dev):
201203
if module_name == "detection":
202204
x = [x]
203205

206+
if module_name == "optical_flow":
207+
args = [x, x] # RAFT model requires img1, img2 as input
208+
else:
209+
args = [x]
210+
204211
# compare with new model builder parameterized in the old fashion way
205212
try:
206213
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
207214
model_new = _build_model(model_fn, **kwargs).to(device=dev)
208215
except ModuleNotFoundError:
209216
pytest.skip(f"Model '{model_name}' not available in both modules.")
210-
torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
217+
torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)
211218

212219

213220
def test_smoke():

torchvision/models/optical_flow/raft.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.nn.modules.instancenorm import InstanceNorm2d
99
from torchvision.ops import ConvNormActivation
1010

11+
from ..._internally_replaced_utils import load_state_dict_from_url
1112
from ...utils import _log_api_usage_once
1213
from ._utils import grid_sample, make_coords_grid, upsample_flow
1314

@@ -19,6 +20,9 @@
1920
)
2021

2122

23+
_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
24+
25+
2226
class ResidualBlock(nn.Module):
2327
"""Slightly modified Residual block with extra relu and biases."""
2428

@@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
474478
hidden_state = torch.tanh(hidden_state)
475479
context = F.relu(context)
476480

477-
coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
478-
coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
481+
coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
482+
coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
479483

480484
flow_predictions = []
481485
for _ in range(num_flow_updates):
@@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
496500

497501
def _raft(
498502
*,
503+
arch=None,
504+
pretrained=False,
505+
progress=False,
499506
# Feature encoder
500507
feature_encoder_layers,
501508
feature_encoder_block,
@@ -560,14 +567,19 @@ def _raft(
560567
multiplier=0.25, # See comment in MaskPredictor about this
561568
)
562569

563-
return RAFT(
570+
model = RAFT(
564571
feature_encoder=feature_encoder,
565572
context_encoder=context_encoder,
566573
corr_block=corr_block,
567574
update_block=update_block,
568575
mask_predictor=mask_predictor,
569576
**kwargs, # not really needed, all params should be consumed by now
570577
)
578+
if pretrained:
579+
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
580+
model.load_state_dict(state_dict)
581+
582+
return model
571583

572584

573585
def raft_large(*, pretrained=False, progress=True, **kwargs):
@@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
584596
nn.Module: The model.
585597
"""
586598

587-
if pretrained:
588-
raise ValueError("No checkpoint is available for raft_large")
589-
590599
return _raft(
600+
arch="raft_large",
601+
pretrained=pretrained,
602+
progress=progress,
591603
# Feature encoder
592604
feature_encoder_layers=(64, 64, 96, 128, 256),
593605
feature_encoder_block=ResidualBlock,
@@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
629641
nn.Module: The model.
630642
631643
"""
632-
633644
if pretrained:
634645
raise ValueError("No checkpoint is available for raft_small")
635646

636647
return _raft(
648+
arch="raft_small",
649+
pretrained=pretrained,
650+
progress=progress,
637651
# Feature encoder
638652
feature_encoder_layers=(32, 32, 64, 96, 128),
639653
feature_encoder_block=BottleneckBlock,

torchvision/prototype/models/optical_flow/raft.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
from torch.nn.modules.instancenorm import InstanceNorm2d
55
from torchvision.models.optical_flow import RAFT
66
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
7-
8-
# from torchvision.prototype.transforms import RaftEval
7+
from torchvision.prototype.transforms import RaftEval
8+
from torchvision.transforms.functional import InterpolationMode
99

1010
from .._api import WeightsEnum
11-
12-
# from .._api import Weights
11+
from .._api import Weights
1312
from .._utils import handle_legacy_interface
1413

1514

@@ -22,17 +21,33 @@
2221
)
2322

2423

24+
_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}
25+
26+
2527
class Raft_Large_Weights(WeightsEnum):
26-
pass
27-
# C_T_V1 = Weights(
28-
# # Chairs + Things
29-
# url="",
30-
# transforms=RaftEval,
31-
# meta={
32-
# "recipe": "",
33-
# "epe": -1234,
34-
# },
35-
# )
28+
C_T_V1 = Weights(
29+
# Chairs + Things, ported from original paper repo (raft-things.pth)
30+
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
31+
transforms=RaftEval,
32+
meta={
33+
**_COMMON_META,
34+
"recipe": "https://github.com/princeton-vl/RAFT",
35+
"sintel_train_cleanpass_epe": 1.4411,
36+
"sintel_train_finalpass_epe": 2.7894,
37+
},
38+
)
39+
40+
C_T_V2 = Weights(
41+
# Chairs + Things
42+
url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
43+
transforms=RaftEval,
44+
meta={
45+
**_COMMON_META,
46+
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
47+
"sintel_train_cleanpass_epe": 1.3822,
48+
"sintel_train_finalpass_epe": 2.7161,
49+
},
50+
)
3651

3752
# C_T_SKHT_V1 = Weights(
3853
# # Chairs + Things + Sintel fine-tuning, i.e.:
@@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum):
5974
# },
6075
# )
6176

62-
# default = C_T_V1
77+
default = C_T_V2
6378

6479

6580
class Raft_Small_Weights(WeightsEnum):
@@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum):
7590
# default = C_T_V1
7691

7792

78-
@handle_legacy_interface(weights=("pretrained", None))
93+
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
7994
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
8095
"""RAFT model from
8196
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
8297
8398
Args:
84-
weights(Raft_Large_weights, optinal): TODO not implemented yet
99+
weights(Raft_Large_weights, optional): pretrained weights to use.
85100
progress (bool): If True, displays a progress bar of the download to stderr
86101
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
87102
to override any default.
@@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
92107

93108
weights = Raft_Large_Weights.verify(weights)
94109

95-
return _raft(
110+
model = _raft(
96111
# Feature encoder
97112
feature_encoder_layers=(64, 64, 96, 128, 256),
98113
feature_encoder_block=ResidualBlock,
@@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
119134
**kwargs,
120135
)
121136

137+
if weights is not None:
138+
model.load_state_dict(weights.get_state_dict(progress=progress))
139+
140+
return model
141+
122142

123143
@handle_legacy_interface(weights=("pretrained", None))
124144
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
@@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
138158

139159
weights = Raft_Small_Weights.verify(weights)
140160

141-
return _raft(
161+
model = _raft(
142162
# Feature encoder
143163
feature_encoder_layers=(32, 32, 64, 96, 128),
144164
feature_encoder_block=BottleneckBlock,
@@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
164184
use_mask_predictor=False,
165185
**kwargs,
166186
)
187+
188+
if weights is not None:
189+
model.load_state_dict(weights.get_state_dict(progress=progress))
190+
return model

0 commit comments

Comments
 (0)