Commit 17ac681

Merge branch 'main' into patch-1
2 parents d23d118 + 849d02b commit 17ac681

34 files changed: +720 −182 lines

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

.circleci/config.yml.in

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ jobs:
           descr: Install Python type check utilities
       - run:
           name: Check Python types statically
-          command: mypy --config-file mypy.ini
+          command: mypy --install-types --non-interactive --config-file mypy.ini

   unittest_torchhub:
     docker:

packaging/torchvision/meta.yaml

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ requirements:
   run:
     - python
     - defaults::numpy >=1.11
+    - requests
     - libpng
     - ffmpeg >=4.2  # [not win]
     - jpeg

references/optical_flow/README.md

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Optical flow reference training scripts

This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.

### RAFT Large

The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so the effective batch size is 16). The
rest of the hyper-parameters are exactly the same as in the original RAFT training
recipe from https://github.com/princeton-vl/RAFT.

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
    --dataset-root $dataset_root \
    --name $name_chairs \
    --model raft_large \
    --train-dataset chairs \
    --batch-size 2 \
    --lr 0.0004 \
    --weight-decay 0.0001 \
    --num-steps 100000 \
    --output-dir $chairs_dir
```

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
    --dataset-root $dataset_root \
    --name $name_things \
    --model raft_large \
    --train-dataset things \
    --batch-size 2 \
    --lr 0.000125 \
    --weight-decay 0.0001 \
    --num-steps 100000 \
    --freeze-batch-norm \
    --output-dir $things_dir \
    --resume $chairs_dir/$name_chairs.pth
```

### Evaluation

```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
```

This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel. Results may vary slightly depending on the batch size and
the number of GPUs. For the most accurate results, use 1 GPU and `--batch-size 1`:

```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
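
For readers unfamiliar with the metric: "epe" is the end-point error, the mean Euclidean distance between the predicted and ground-truth flow vectors. A minimal sketch of the computation, assuming `(2, H, W)` flow tensors (the helper name `epe` is illustrative, not from the repo):

```
import torch

def epe(flow_pred: torch.Tensor, flow_gt: torch.Tensor) -> torch.Tensor:
    # Per-pixel L2 distance over the (dx, dy) channels, averaged over all pixels.
    return torch.linalg.norm(flow_pred - flow_gt, dim=0).mean()
```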

references/optical_flow/train.py

Lines changed: 27 additions & 5 deletions
@@ -3,10 +3,16 @@
 from pathlib import Path

 import torch
+import torchvision.models.optical_flow
 import utils
 from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
 from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
-from torchvision.models.optical_flow import raft_large, raft_small
+
+try:
+    from torchvision.prototype import models as PM
+    from torchvision.prototype.models import optical_flow as PMOF
+except ImportError:
+    PM = PMOF = None


 def get_train_dataset(stage, dataset_root):
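
For context, the try/except added above is the usual guard for optional modules: the script still imports when the prototype area is missing, and only fails when a prototype-only code path is actually hit. A minimal sketch of the pattern (the helper and its message are illustrative, not part of the diff):

```
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None

def require_prototype_models():
    # Fail with a clear message instead of an AttributeError on None later on.
    if PM is None:
        raise RuntimeError("this option requires torchvision's prototype models")
```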
@@ -125,6 +131,13 @@ def inner_loop(blob):

 def validate(model, args):
     val_datasets = args.val_dataset or []
+
+    if args.weights:
+        weights = PM.get_weight(args.weights)
+        preprocessing = weights.transforms()
+    else:
+        preprocessing = OpticalFlowPresetEval()
+
     for name in val_datasets:
         if name == "kitti":
             # Kitti has different image sizes so we need to individually pad them, we can't batch.
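
For context, `get_weight` resolves a weights enum from its string name, and each weights entry bundles the `transforms()` it was trained with, so evaluation preprocessing always matches the checkpoint. A sketch of that flow, following the calls in the diff (the weights name string is hypothetical, for illustration only):

```
from torchvision.prototype import models as PM

weights = PM.get_weight("Raft_Large_Weights.C_T_V1")  # hypothetical enum name
preprocessing = weights.transforms()  # preprocessing bundled with the checkpoint
```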
@@ -134,14 +147,14 @@ def validate(model, args):
                 f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
             )

-            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
+            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
             _validate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
             for pass_name in ("clean", "final"):
                 val_dataset = Sintel(
-                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
+                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
                 _validate(
                     model,
@@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
 def main(args):
     utils.setup_ddp(args)

-    model = raft_small() if args.small else raft_large()
+    if args.weights:
+        model = PMOF.__dict__[args.model](weights=args.weights)
+    else:
+        model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
+
     model = model.to(args.local_rank)
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
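
For context, indexing the module's `__dict__` dispatches `--model` to the matching builder (`raft_large` or `raft_small`) without per-architecture branching; `getattr` on the module would work equally well. A minimal sketch of the same dispatch (the helper name is ours):

```
import torchvision.models.optical_flow

def build_model(name: str, pretrained: bool = False):
    # Resolve e.g. "raft_large" or "raft_small" to its builder function by name.
    builder = torchvision.models.optical_flow.__dict__[name]
    return builder(pretrained=pretrained)

model = build_model("raft_large")
```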

@@ -306,7 +323,12 @@ def get_args_parser(add_help=True):
         "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
     )

-    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
+    parser.add_argument(
+        "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
+    )
+    # TODO: resume, pretrained, and weights should be in an exclusive arg group
+    parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")

     parser.add_argument(
         "--num_flow_updates",

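Regarding the TODO in the hunk above: argparse can enforce that at most one of the conflicting flags is given. A sketch of what such an exclusive group could look like (not what this diff implements):

```
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--resume", type=str, help="checkpoint to resume from")
group.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
group.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
# argparse now errors out if more than one of --resume/--pretrained/--weights is passed.
```
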
setup.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ def write_version_file():

 requirements = [
     "numpy",
+    "requests",
     pytorch_dep,
 ]

test/builtin_dataset_mocks.py

Lines changed: 24 additions & 15 deletions
@@ -29,6 +29,19 @@
 DEFAULT_TEST_DECODER = object()


+class TestResource(datasets.utils.OnlineResource):
+    def __init__(self, *, dataset_name, dataset_config, **kwargs):
+        super().__init__(**kwargs)
+        self.dataset_name = dataset_name
+        self.dataset_config = dataset_config
+
+    def _download(self, _):
+        raise pytest.UsageError(
+            f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
+            f"but this file does not exist."
+        )
+
+
 class DatasetMocks:
     def __init__(self):
         self._mock_data_fns = {}

@@ -72,7 +85,7 @@ def _parse_mock_info(self, mock_info, *, name):
         )
         return mock_info

-    def _get(self, dataset, config):
+    def _get(self, dataset, config, root):
         name = dataset.info.name
         resources_and_mock_info = self._cache.get((name, config))
         if resources_and_mock_info:

@@ -87,20 +100,12 @@ def _get(self, dataset, config):
                 f"Did you register the mock data function with `@DatasetMocks.register_mock_data_fn`?"
             )

-        root = self._tmp_home / name
-        root.mkdir(exist_ok=True)
+        mock_resources = [
+            TestResource(dataset_name=name, dataset_config=config, file_name=resource.file_name)
+            for resource in dataset.resources(config)
+        ]
         mock_info = self._parse_mock_info(fakedata_fn(dataset.info, root, config), name=name)

-        mock_resources = []
-        for resource in dataset.resources(config):
-            path = root / resource.file_name
-            if not path.exists() and path.is_file():
-                raise pytest.UsageError(
-                    f"Dataset '{name}' requires the file {path.name} for {config}, but this file does not exist."
-                )
-
-            mock_resources.append(datasets.utils.LocalResource(path))
-
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info

@@ -109,9 +114,13 @@ def load(
     ) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
         config = dataset.info.make_config(split=split, **options)
-        resources, mock_info = self._get(dataset, config)
+
+        root = self._tmp_home / name
+        root.mkdir(exist_ok=True)
+        resources, mock_info = self._get(dataset, config, root)
+
         datapipe = dataset._make_datapipe(
-            [resource.to_datapipe() for resource in resources],
+            [resource.load(root) for resource in resources],
             config=config,
             decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
         )
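
For context, the refactor leans on the prototype resource API: a resource's `load(root)` looks for its file under `root` and falls back to `_download` only when the file is missing, which is why `TestResource` can turn a missing mock file into a loud `pytest.UsageError` instead of a network download. A condensed sketch of that contract (internals are assumptions; the real `load` returns a datapipe rather than a path):

```
import pathlib

class OnlineResourceSketch:
    # Illustrative stand-in for torchvision's prototype OnlineResource.
    def __init__(self, *, file_name):
        self.file_name = file_name

    def _download(self, root):
        raise NotImplementedError

    def load(self, root):
        # Resolve the file under root; trigger _download only if it is missing.
        path = pathlib.Path(root) / self.file_name
        if not path.exists():
            self._download(root)
        return path
```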

test/test_prototype_datasets_api.py

Lines changed: 9 additions & 8 deletions
@@ -211,10 +211,10 @@ def test_default_config(self):
             pytest.param(make_minimal_dataset_info().default_config, None, id="default"),
         ],
     )
-    def test_to_datapipe_config(self, config, kwarg):
+    def test_load_config(self, config, kwarg):
         dataset = self.DatasetMock()

-        dataset.to_datapipe("", config=kwarg)
+        dataset.load("", config=kwarg)

         dataset.resources.assert_called_with(config)

@@ -225,18 +225,19 @@ def test_missing_dependencies(self):
         dependency = "fake_dependency"
         dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,)))
         with pytest.raises(ModuleNotFoundError, match=dependency):
-            dataset.to_datapipe("root")
+            dataset.load("root")

     def test_resources(self, mocker):
-        resource_mock = mocker.Mock(spec=["to_datapipe"])
+        resource_mock = mocker.Mock(spec=["load"])
         sentinel = object()
-        resource_mock.to_datapipe.return_value = sentinel
+        resource_mock.load.return_value = sentinel
         dataset = self.DatasetMock(resources=[resource_mock])

         root = "root"
-        dataset.to_datapipe(root)
+        dataset.load(root)

-        resource_mock.to_datapipe.assert_called_with(root)
+        (call_args, _) = resource_mock.load.call_args
+        assert call_args[0] == root

         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel

@@ -245,7 +246,7 @@ def test_decoder(self):
         dataset = self.DatasetMock()

         sentinel = object()
-        dataset.to_datapipe("", decoder=sentinel)
+        dataset.load("", decoder=sentinel)

         (_, call_kwargs) = dataset._make_datapipe.call_args
         assert call_kwargs["decoder"] is sentinel
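
For context, switching from `assert_called_with(root)` to inspecting `call_args` pins down only the first positional argument, so the test keeps passing if `load` later gains extra arguments. With `unittest.mock`, `call_args` unpacks into the positional and keyword arguments of the last call; a minimal standalone example:

```
from unittest import mock

resource = mock.Mock(spec=["load"])
resource.load("root", preprocess=True)  # an extra kwarg no longer breaks the assertion

(call_args, call_kwargs) = resource.load.call_args
assert call_args[0] == "root"
assert call_kwargs == {"preprocess": True}
```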
