@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):


 @torch.no_grad()
-def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
+def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
     """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

     We process as many samples as possible with ddp, and process the rest on a single worker.
     """
     batch_size = batch_size or args.batch_size
+    device = torch.device(args.device)

     model.eval()

-    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    else:
+        sampler = torch.utils.data.SequentialSampler(val_dataset)
+
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         sampler=sampler,
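The hunk above picks the evaluation sampler based on the run mode: sharded and order-preserving under DDP, plain sequential otherwise. A minimal standalone sketch of that selection, assuming a toy dataset (`make_eval_sampler` and `dummy_dataset` are illustrative names, not from the patch):

```python
import torch

def make_eval_sampler(dataset, distributed: bool):
    if distributed:
        # Each rank sees a disjoint shard; drop_last=True keeps shards equal-sized,
        # which is why leftover samples are handled separately on rank 0.
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False, drop_last=True)
    # Single process: plain sequential iteration over the whole dataset.
    return torch.utils.data.SequentialSampler(dataset)

dummy_dataset = torch.utils.data.TensorDataset(torch.arange(10).float())
sampler = make_eval_sampler(dummy_dataset, distributed=False)
print(list(sampler))  # [0, 1, ..., 9]
```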
@@ -88,7 +93,7 @@ def inner_loop(blob):
         image1, image2, flow_gt = blob[:3]
         valid_flow_mask = None if len(blob) == 3 else blob[-1]

-        image1, image2 = image1.cuda(), image2.cuda()
+        image1, image2 = image1.to(device), image2.to(device)

         padder = utils.InputPadder(image1.shape, mode=padder_mode)
         image1, image2 = padder.pad(image1, image2)
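`.to(device)` generalizes the previous `.cuda()` calls so the same inner loop runs on CPU. A runnable sketch of the device-agnostic transfer; the availability-based fallback is an assumption of this example, while the script itself takes the device from `args.device`:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image1 = torch.rand(1, 3, 368, 768)
image2 = torch.rand(1, 3, 368, 768)
# Works unchanged on either device, unlike .cuda() which hard-requires a GPU.
image1, image2 = image1.to(device), image2.to(device)
print(image1.device)  # cuda:0 or cpu, depending on availability
```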
@@ -115,21 +120,22 @@ def inner_loop(blob):
         inner_loop(blob)
         num_processed_samples += blob[0].shape[0]  # batch size

-    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
-    print(
-        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
-        "Going to process the remaining samples individually, if any."
-    )
+    if args.distributed:
+        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+        print(
+            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
+            "Going to process the remaining samples individually, if any."
+        )
+        if args.rank == 0:  # we only need to process the rest on a single worker
+            for i in range(num_processed_samples, len(val_dataset)):
+                inner_loop(val_dataset[i])

-    if args.rank == 0:  # we only need to process the rest on a single worker
-        for i in range(num_processed_samples, len(val_dataset)):
-            inner_loop(val_dataset[i])
+        logger.synchronize_between_processes()

-    logger.synchronize_between_processes()
     print(header, logger)


-def validate(model, args):
+def evaluate(model, args):
     val_datasets = args.val_dataset or []

     if args.prototype:
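Under DDP with `drop_last=True`, each rank only sees an equal-sized shard, so the per-rank sample counts must be summed before the rank-0 remainder pass. A plausible sketch of what a `reduce_across_processes` helper does (the real one lives in the reference `utils`; this version is an assumption):

```python
import torch
import torch.distributed as dist

def reduce_across_processes(val):
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process runs: nothing to reduce.
        return val
    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)  # defaults to SUM across all ranks
    return int(t.item())
```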
@@ -145,21 +151,21 @@ def validate(model, args):
         if name == "kitti":
             # Kitti has different image sizes so we need to individually pad them, we can't batch.
             # see comment in InputPadder
-            if args.batch_size != 1 and args.rank == 0:
+            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                 warnings.warn(
                     f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                 )

             val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
-            _validate(
+            _evaluate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
             for pass_name in ("clean", "final"):
                 val_dataset = Sintel(
                     root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
-                _validate(
+                _evaluate(
                     model,
                     args,
                     val_dataset,
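The Kitti branch insists on `batch_size=1` because Kitti frames vary in spatial size, and differently sized tensors cannot be stacked into one batch. A quick demonstration (the sizes below are merely examples of the variation):

```python
import torch

a = torch.rand(3, 370, 1224)  # one plausible Kitti frame size
b = torch.rand(3, 376, 1241)  # another, slightly different
try:
    torch.stack([a, b])
except RuntimeError as err:
    print(f"stacking fails: {err}")
```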
@@ -172,11 +178,12 @@ def validate(model, args):


 def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
+    device = torch.device(args.device)
     for data_blob in logger.log_every(train_loader):

         optimizer.zero_grad()

-        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
         flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

         loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
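`utils.sequence_loss` weights the per-iteration flow predictions with the `gamma` passed above. A hedged sketch of a RAFT-style sequence loss for orientation; the actual helper may differ in details:

```python
import torch

def sequence_loss(flow_preds, flow_gt, valid_mask=None, gamma=0.8):
    n = len(flow_preds)
    loss = 0.0
    for i, pred in enumerate(flow_preds):
        # Later refinement iterations get exponentially larger weights.
        weight = gamma ** (n - i - 1)
        abs_diff = (pred - flow_gt).abs()
        if valid_mask is not None:
            abs_diff = abs_diff * valid_mask[:, None]  # broadcast over the 2 flow channels
        loss = loss + weight * abs_diff.mean()
    return loss

preds = [torch.zeros(1, 2, 4, 4) for _ in range(3)]
gt = torch.ones(1, 2, 4, 4)
print(sequence_loss(preds, gt))  # tensor(2.4400): 0.8**2 + 0.8 + 1
```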
@@ -200,36 +207,68 @@ def main(args):
         raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     utils.setup_ddp(args)

+    if args.distributed and args.device == "cpu":
+        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
+    device = torch.device(args.device)
+
     if args.prototype:
         model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
     else:
         model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

-    model = model.to(args.local_rank)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+    if args.distributed:
+        model = model.to(args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model

     if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])

     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
-        validate(model, args)
+        evaluate(model, args)
         return

     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True

     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)

-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    else:
+        sampler = torch.utils.data.RandomSampler(train_dataset)

-    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         sampler=sampler,
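The `steps_per_epoch` passed to `OneCycleLR` reflects `drop_last` sharding: each of the `world_size` processes consumes `batch_size` samples per optimizer step. A quick check of the arithmetic with made-up numbers:

```python
from math import ceil

len_train_dataset = 22_872  # hypothetical dataset size
world_size, batch_size = 4, 2
# Every step, the whole job advances by world_size * batch_size samples.
steps_per_epoch = ceil(len_train_dataset / (world_size * batch_size))
print(steps_per_epoch)  # 2859
```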
@@ -238,25 +277,15 @@ def main(args):
         num_workers=args.num_workers,
     )

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()

     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
+        if args.distributed:
+            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
+            sampler.set_epoch(current_epoch)

-        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
         train_one_epoch(
             model=model,
             optimizer=optimizer,
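`DistributedSampler` derives its shuffle from `seed + epoch`, so without `set_epoch` every epoch would replay the same order; in the non-distributed branch `RandomSampler` reshuffles on its own. A runnable demonstration; `num_replicas`/`rank` are passed explicitly here only so the demo runs without an initialized process group:

```python
import torch

dataset = torch.utils.data.TensorDataset(torch.arange(8).float())
sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
order_epoch0 = list(sampler)
sampler.set_epoch(1)
order_epoch1 = list(sampler)
print(order_epoch0 != order_epoch1)  # True (with very high probability)
```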
@@ -269,13 +298,19 @@ def main(args):
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)

-        if args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+        if not args.distributed or args.rank == 0:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

         if current_epoch % args.val_freq == 0 or done:
-            validate(model, args)
+            evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
                 utils.freeze_batch_norm(model.module)
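The checkpoint now carries optimizer, scheduler, and epoch state, which is what makes `--resume` restart mid-schedule. A self-contained sketch that round-trips the same format with toy stand-ins for the real model, optimizer, and scheduler (the `args` entry is omitted here):

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-4, total_steps=100)

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epoch": 3,
}
torch.save(checkpoint, "toy_checkpoint.pth")

restored = torch.load("toy_checkpoint.pth", map_location="cpu")
model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizer"])
scheduler.load_state_dict(restored["scheduler"])
print("resume at epoch", restored["epoch"] + 1)
```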
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")

     return parser
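A small usage sketch for the new flag, assuming the parser's other options all have defaults (`get_args_parser` is the function extended above):

```python
# Hypothetical usage from within the training script's module.
parser = get_args_parser()
args = parser.parse_args(["--device", "cpu"])
print(args.device)  # "cpu"
```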