Skip to content

Commit 3c07750

Browse files
committed
Add our weights for raft_small on C+T
1 parent 48e2f23 commit 3c07750

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

references/optical_flow/README.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,18 @@ torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-siz
4848
```
4949

5050
This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
51-
final pass of Sintel. Results may vary slightly depending on the batch size and
52-
the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`:
51+
final pass of Sintel-train. Results may vary slightly depending on the batch
52+
size and the number of GPUs. For the most accurate resuts use 1 GPU and
53+
`--batch-size 1`:
5354

5455
```
5556
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
5657
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
5758
```
59+
60+
You can also evaluate on Kitti train:
61+
62+
```
63+
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
64+
Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679
65+
```

torchvision/models/optical_flow/raft.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222

2323
_MODELS_URLS = {
2424
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
25-
# TODO: change to V2 once we upload our own weights
26-
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
25+
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
2726
}
2827

2928

@@ -591,7 +590,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
591590
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
592591
593592
Args:
594-
pretrained (bool): TODO not implemented yet
593+
pretrained (bool): Whether to use pretrained weights.
595594
progress (bool): If True, displays a progress bar of the download to stderr
596595
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
597596
to override any default.
@@ -636,7 +635,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
636635
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
637636
638637
Args:
639-
pretrained (bool): TODO not implemented yet
638+
pretrained (bool): Whether to use pretrained weights.
640639
progress (bool): If True, displays a progress bar of the download to stderr
641640
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
642641
to override any default.

torchvision/prototype/models/optical_flow/raft.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class Raft_Large_Weights(WeightsEnum):
3434
"recipe": "https://github.com/princeton-vl/RAFT",
3535
"sintel_train_cleanpass_epe": 1.4411,
3636
"sintel_train_finalpass_epe": 2.7894,
37+
"kitti_train_per_image_epe": 5.0172,
38+
"kitti_train_f1-all": 17.4506,
3739
},
3840
)
3941

@@ -46,6 +48,8 @@ class Raft_Large_Weights(WeightsEnum):
4648
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
4749
"sintel_train_cleanpass_epe": 1.3822,
4850
"sintel_train_finalpass_epe": 2.7161,
51+
"kitti_train_per_image_epe": 4.5118,
52+
"kitti_train_f1-all": 16.0679,
4953
},
5054
)
5155

@@ -87,10 +91,25 @@ class Raft_Small_Weights(WeightsEnum):
8791
"recipe": "https://github.com/princeton-vl/RAFT",
8892
"sintel_train_cleanpass_epe": 2.1231,
8993
"sintel_train_finalpass_epe": 3.2790,
94+
"kitti_train_per_image_epe": 7.6557,
95+
"kitti_train_f1-all": 25.2801,
96+
},
97+
)
98+
C_T_V2 = Weights(
99+
# Chairs + Things
100+
url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
101+
transforms=RaftEval,
102+
meta={
103+
**_COMMON_META,
104+
"recipe": "https://github.com/princeton-vl/RAFT",
105+
"sintel_train_cleanpass_epe": 1.9901,
106+
"sintel_train_finalpass_epe": 3.2831,
107+
"kitti_train_per_image_epe": 7.5978,
108+
"kitti_train_f1-all": 25.2369,
90109
},
91110
)
92111

93-
default = C_T_V1 # TODO: Change to V2 once we upload our own weights
112+
default = C_T_V2
94113

95114

96115
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
@@ -143,14 +162,13 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
143162
return model
144163

145164

146-
# TODO: change to V2 once we upload our own weights
147-
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V1))
165+
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
148166
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
149167
"""RAFT "small" model from
150168
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
151169
152170
Args:
153-
weights(Raft_Small_weights, optinal): TODO not implemented yet
171+
weights(Raft_Small_weights, optional): pretrained weights to use.
154172
progress (bool): If True, displays a progress bar of the download to stderr
155173
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
156174
to override any default.

0 commit comments

Comments
 (0)