10
10
import math
11
11
import os
12
12
import random
13
- from typing import Any , Dict , List , Optional , Sequence , Tuple , TYPE_CHECKING , Union
13
+ from typing import (
14
+ Any ,
15
+ Dict ,
16
+ Iterable ,
17
+ List ,
18
+ Optional ,
19
+ Sequence ,
20
+ Tuple ,
21
+ TYPE_CHECKING ,
22
+ Union ,
23
+ )
14
24
15
25
import numpy as np
16
26
import torch
@@ -180,7 +190,7 @@ def render_flyaround(
180
190
preds .update (net_input ) # merge everything into one big dict
181
191
182
192
# Render the predictions to images
183
- rendered_pred = _images_from_preds (preds )
193
+ rendered_pred = _images_from_preds (preds , extract_keys = visualize_preds_keys )
184
194
preds_total .append (rendered_pred )
185
195
186
196
# show the preds every 5% of the export iterations
@@ -223,17 +233,20 @@ def _load_whole_dataset(
223
233
return next (iter (load_all_dataloader ))
224
234
225
235
226
- def _images_from_preds (preds : Dict [ str , Any ]) -> Dict [ str , torch . Tensor ]:
227
- imout = {}
228
- for k in (
236
+ def _images_from_preds (
237
+ preds : Dict [ str , Any ],
238
+ extract_keys : Iterable [ str ] = (
229
239
"image_rgb" ,
230
240
"images_render" ,
231
241
"fg_probability" ,
232
242
"masks_render" ,
233
243
"depths_render" ,
234
244
"depth_map" ,
235
245
"_all_source_images" ,
236
- ):
246
+ ),
247
+ ) -> Dict [str , torch .Tensor ]:
248
+ imout = {}
249
+ for k in extract_keys :
237
250
if k == "_all_source_images" and "image_rgb" in preds :
238
251
src_ims = preds ["image_rgb" ][1 :].cpu ().detach ().clone ()
239
252
v = _stack_images (src_ims , None )[None ]
@@ -343,6 +356,9 @@ def _generate_prediction_videos(
343
356
# init a video writer for each predicted key
344
357
vws = {}
345
358
for k in predicted_keys :
359
+ if k not in preds [0 ]:
360
+ logger .warn (f"Cannot generate video for prediction key '{ k } '" )
361
+ continue
346
362
cache_dir = (
347
363
None
348
364
if video_frames_dir is None
@@ -355,13 +371,15 @@ def _generate_prediction_videos(
355
371
)
356
372
357
373
for rendered_pred in tqdm (preds ):
358
- for k in predicted_keys :
374
+ for k in vws :
359
375
vws [k ].write_frame (
360
376
rendered_pred [k ][0 ].clip (0.0 , 1.0 ).detach ().cpu ().numpy (),
361
377
resize = resize ,
362
378
)
363
379
364
380
for k in predicted_keys :
381
+ if k not in vws :
382
+ continue
365
383
vws [k ].get_video ()
366
384
logger .info (f"Generated { vws [k ].out_path } ." )
367
385
if viz is not None :
0 commit comments