Skip to content

Commit 48e2f23

Browse files
authored
Add pretrained weights for raft_small from original paper (#5070)
1 parent 4cacf5a commit 48e2f23

File tree

2 files changed

+20
-14
lines changed
  • torchvision

2 files changed

+20
-14
lines changed

torchvision/models/optical_flow/raft.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
)
2121

2222

23-
_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
23+
_MODELS_URLS = {
24+
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
25+
# TODO: change to V2 once we upload our own weights
26+
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
27+
}
2428

2529

2630
class ResidualBlock(nn.Module):
@@ -641,8 +645,6 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
641645
nn.Module: The model.
642646
643647
"""
644-
if pretrained:
645-
raise ValueError("No checkpoint is available for raft_small")
646648

647649
return _raft(
648650
arch="raft_small",

torchvision/prototype/models/optical_flow/raft.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,19 @@ class Raft_Large_Weights(WeightsEnum):
7878

7979

8080
class Raft_Small_Weights(WeightsEnum):
81-
pass
82-
# C_T_V1 = Weights(
83-
# url="", # TODO
84-
# transforms=RaftEval,
85-
# meta={
86-
# "recipe": "",
87-
# "epe": -1234,
88-
# },
89-
# )
90-
# default = C_T_V1
81+
C_T_V1 = Weights(
82+
# Chairs + Things, ported from original paper repo (raft-small.pth)
83+
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
84+
transforms=RaftEval,
85+
meta={
86+
**_COMMON_META,
87+
"recipe": "https://github.com/princeton-vl/RAFT",
88+
"sintel_train_cleanpass_epe": 2.1231,
89+
"sintel_train_finalpass_epe": 3.2790,
90+
},
91+
)
92+
93+
default = C_T_V1 # TODO: Change to V2 once we upload our own weights
9194

9295

9396
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
@@ -140,7 +143,8 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
140143
return model
141144

142145

143-
@handle_legacy_interface(weights=("pretrained", None))
146+
# TODO: change to V2 once we upload our own weights
147+
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V1))
144148
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
145149
"""RAFT "small" model from
146150
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

0 commit comments

Comments
 (0)