diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index ff851b6382e..dda68d73721 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -20,7 +20,11 @@ ) -_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"} +_MODELS_URLS = { + "raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + # TODO: change to V2 once we upload our own weights + "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", +} class ResidualBlock(nn.Module): @@ -641,8 +645,6 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): nn.Module: The model. """ - if pretrained: - raise ValueError("No checkpoint is available for raft_small") return _raft( arch="raft_small", diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 4fc7e962864..b1b5fcbe911 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -78,16 +78,19 @@ class Raft_Large_Weights(WeightsEnum): class Raft_Small_Weights(WeightsEnum): - pass - # C_T_V1 = Weights( - # url="", # TODO - # transforms=RaftEval, - # meta={ - # "recipe": "", - # "epe": -1234, - # }, - # ) - # default = C_T_V1 + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-small.pth) + url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + transforms=RaftEval, + meta={ + **_COMMON_META, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 2.1231, + "sintel_train_finalpass_epe": 3.2790, + }, + ) + + default = C_T_V1 # TODO: Change to V2 once we upload our own weights @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) @@ -140,7 +143,8 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * return model -@handle_legacy_interface(weights=("pretrained", None)) +# TODO: change to V2 once we upload our own weights +@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V1)) def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): """RAFT "small" model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_.