@@ -34,6 +34,8 @@ class Raft_Large_Weights(WeightsEnum):
34
34
"recipe" : "https://github.com/princeton-vl/RAFT" ,
35
35
"sintel_train_cleanpass_epe" : 1.4411 ,
36
36
"sintel_train_finalpass_epe" : 2.7894 ,
37
+ "kitti_train_per_image_epe" : 5.0172 ,
38
+ "kitti_train_f1-all" : 17.4506 ,
37
39
},
38
40
)
39
41
@@ -46,6 +48,8 @@ class Raft_Large_Weights(WeightsEnum):
46
48
"recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
47
49
"sintel_train_cleanpass_epe" : 1.3822 ,
48
50
"sintel_train_finalpass_epe" : 2.7161 ,
51
+ "kitti_train_per_image_epe" : 4.5118 ,
52
+ "kitti_train_f1-all" : 16.0679 ,
49
53
},
50
54
)
51
55
@@ -87,10 +91,25 @@ class Raft_Small_Weights(WeightsEnum):
87
91
"recipe" : "https://github.com/princeton-vl/RAFT" ,
88
92
"sintel_train_cleanpass_epe" : 2.1231 ,
89
93
"sintel_train_finalpass_epe" : 3.2790 ,
94
+ "kitti_train_per_image_epe" : 7.6557 ,
95
+ "kitti_train_f1-all" : 25.2801 ,
96
+ },
97
+ )
98
+ C_T_V2 = Weights (
99
+ # Chairs + Things
100
+ url = "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth" ,
101
+ transforms = RaftEval ,
102
+ meta = {
103
+ ** _COMMON_META ,
104
+ "recipe" : "https://github.com/princeton-vl/RAFT" ,
105
+ "sintel_train_cleanpass_epe" : 1.9901 ,
106
+ "sintel_train_finalpass_epe" : 3.2831 ,
107
+ "kitti_train_per_image_epe" : 7.5978 ,
108
+ "kitti_train_f1-all" : 25.2369 ,
90
109
},
91
110
)
92
111
93
- default = C_T_V1 # TODO: Change to V2 once we upload our own weights
112
+ default = C_T_V2
94
113
95
114
96
115
@handle_legacy_interface (weights = ("pretrained" , Raft_Large_Weights .C_T_V2 ))
@@ -143,14 +162,13 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
143
162
return model
144
163
145
164
146
- # TODO: change to V2 once we upload our own weights
147
- @handle_legacy_interface (weights = ("pretrained" , Raft_Small_Weights .C_T_V1 ))
165
+ @handle_legacy_interface (weights = ("pretrained" , Raft_Small_Weights .C_T_V2 ))
148
166
def raft_small (* , weights : Optional [Raft_Small_Weights ] = None , progress = True , ** kwargs ):
149
167
"""RAFT "small" model from
150
168
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
151
169
152
170
Args:
153
- weights(Raft_Small_weights, optinal ): TODO not implemented yet
171
+ weights(Raft_Small_weights, optional ): pretrained weights to use.
154
172
progress (bool): If True, displays a progress bar of the download to stderr
155
173
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
156
174
to override any default.
0 commit comments