Skip to content

Commit 94f321f

Browse files
davnov134facebook-github-bot
authored andcommitted
render_flyaround bugfix
Summary: Fixes a bug which would crash render_flyaround anytime visualize_preds_keys is adjusted Reviewed By: shapovalov Differential Revision: D41124462 fbshipit-source-id: 127045a91a055909f8bd56c8af81afac02c00f60
1 parent 35f8cb9 commit 94f321f

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

pytorch3d/implicitron/models/visualization/render_flyaround.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@
1010
import math
1111
import os
1212
import random
13-
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
13+
from typing import (
14+
Any,
15+
Dict,
16+
Iterable,
17+
List,
18+
Optional,
19+
Sequence,
20+
Tuple,
21+
TYPE_CHECKING,
22+
Union,
23+
)
1424

1525
import numpy as np
1626
import torch
@@ -180,7 +190,7 @@ def render_flyaround(
180190
preds.update(net_input) # merge everything into one big dict
181191

182192
# Render the predictions to images
183-
rendered_pred = _images_from_preds(preds)
193+
rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys)
184194
preds_total.append(rendered_pred)
185195

186196
# show the preds every 5% of the export iterations
@@ -223,17 +233,20 @@ def _load_whole_dataset(
223233
return next(iter(load_all_dataloader))
224234

225235

226-
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
227-
imout = {}
228-
for k in (
236+
def _images_from_preds(
237+
preds: Dict[str, Any],
238+
extract_keys: Iterable[str] = (
229239
"image_rgb",
230240
"images_render",
231241
"fg_probability",
232242
"masks_render",
233243
"depths_render",
234244
"depth_map",
235245
"_all_source_images",
236-
):
246+
),
247+
) -> Dict[str, torch.Tensor]:
248+
imout = {}
249+
for k in extract_keys:
237250
if k == "_all_source_images" and "image_rgb" in preds:
238251
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
239252
v = _stack_images(src_ims, None)[None]
@@ -343,6 +356,9 @@ def _generate_prediction_videos(
343356
# init a video writer for each predicted key
344357
vws = {}
345358
for k in predicted_keys:
359+
if k not in preds[0]:
360+
logger.warn(f"Cannot generate video for prediction key '{k}'")
361+
continue
346362
cache_dir = (
347363
None
348364
if video_frames_dir is None
@@ -355,13 +371,15 @@ def _generate_prediction_videos(
355371
)
356372

357373
for rendered_pred in tqdm(preds):
358-
for k in predicted_keys:
374+
for k in vws:
359375
vws[k].write_frame(
360376
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
361377
resize=resize,
362378
)
363379

364380
for k in predicted_keys:
381+
if k not in vws:
382+
continue
365383
vws[k].get_video()
366384
logger.info(f"Generated {vws[k].out_path}.")
367385
if viz is not None:

0 commit comments

Comments
 (0)