 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
-from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
 from pytorch3d.implicitron.tools.image_utils import mask_background
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth, iou, rgb_l1
 from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
 from pytorch3d.implicitron.tools.vis_utils import make_depth_image
-from pytorch3d.renderer.camera_utils import join_cameras_as_batch
-from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.renderer.cameras import PerspectiveCameras
 from pytorch3d.vis.plotly_vis import plot_scene
 from tabulate import tabulate

@@ -149,7 +147,6 @@ def eval_batch(
     visualize: bool = False,
     visualize_visdom_env: str = "eval_debug",
     break_after_visualising: bool = True,
-    source_cameras: Optional[CamerasBase] = None,
 ) -> Dict[str, Any]:
     """
     Produce performance metrics for a single batch of new-view synthesis
@@ -171,8 +168,6 @@ def eval_batch(
             ground truth.
         lpips_model: A pre-trained model for evaluating the LPIPS metric.
         visualize: If True, visualizes the results to Visdom.
-        source_cameras: A list of all training cameras for evaluating the
-            difficulty of the target views.

     Returns:
         results: A dictionary holding evaluation metrics.
@@ -365,16 +360,7 @@ def eval_batch(
     # convert all metrics to floats
     results = {k: float(v) for k, v in results.items()}

-    if source_cameras is None:
-        # pyre-fixme[16]: Optional has no attribute __getitem__
-        source_cameras = frame_data.camera[torch.where(is_known)[0]]
-
     results["meta"] = {
-        # calculate the camera difficulties and add to results
-        "camera_difficulty": calculate_camera_difficulties(
-            frame_data.camera[0],
-            source_cameras,
-        )[0].item(),
         # store the size of the batch (corresponds to n_src_views+1)
         "batch_size": int(is_known.numel()),
         # store the type of the target frame
@@ -406,33 +392,6 @@ def average_per_batch_results(
     }


-def calculate_camera_difficulties(
-    cameras_target: CamerasBase,
-    cameras_source: CamerasBase,
-) -> torch.Tensor:
-    """
-    Calculate the difficulties of the target cameras, given a set of known
-    cameras `cameras_source`.
-
-    Returns:
-        a tensor of shape (len(cameras_target),)
-    """
-    ious = [
-        volumetric_camera_overlaps(
-            join_cameras_as_batch(
-                # pyre-fixme[6]: Expected `CamerasBase` for 1st param but got
-                #  `Optional[pytorch3d.renderer.utils.TensorProperties]`.
-                [cameras_target[cami], cameras_source.to(cameras_target.device)]
-            )
-        )[0, :]
-        for cami in range(cameras_target.R.shape[0])
-    ]
-    camera_difficulties = torch.stack(
-        [_reduce_camera_iou_overlap(iou[1:]) for iou in ious]
-    )
-    return camera_difficulties
-
-
 def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
     """
     Calculate the final camera difficulty by computing the average of the
@@ -458,8 +417,7 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
 def summarize_nvs_eval_results(
     per_batch_eval_results: List[Dict[str, Any]],
     is_multisequence: bool,
-    camera_difficulty_bin_breaks: Tuple[float, float],
-):
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Compile the per-batch evaluation results `per_batch_eval_results` into
     a set of aggregate metrics. The produced metrics depend on is_multisequence.
@@ -482,19 +440,12 @@ def summarize_nvs_eval_results(
     batch_sizes = torch.tensor(
         [r["meta"]["batch_size"] for r in per_batch_eval_results]
     ).long()
-    camera_difficulty = torch.tensor(
-        [r["meta"]["camera_difficulty"] for r in per_batch_eval_results]
-    ).float()
+
     is_train = is_train_frame([r["meta"]["frame_type"] for r in per_batch_eval_results])

     # init the result database dict
     results = []

-    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
-        camera_difficulty_bin_breaks
-    )
-    n_diff_edges = diff_bin_edges.numel()
-
     # add per set averages
     for SET in eval_sets:
         if SET is None:
@@ -504,26 +455,17 @@ def summarize_nvs_eval_results(
             ok_set = is_train == int(SET == "train")
             set_name = SET

-        # eval each difficulty bin, including a full average result (diff_bin=None)
-        for diff_bin in [None, *list(range(n_diff_edges - 1))]:
-            if diff_bin is None:
-                # average over all results
-                in_bin = ok_set
-                diff_bin_name = "all"
-            else:
-                b1, b2 = diff_bin_edges[diff_bin : (diff_bin + 2)]
-                in_bin = ok_set & (camera_difficulty > b1) & (camera_difficulty <= b2)
-                diff_bin_name = diff_bin_names[diff_bin]
-            bin_results = average_per_batch_results(
-                per_batch_eval_results, idx=torch.where(in_bin)[0]
-            )
-            results.append(
-                {
-                    "subset": set_name,
-                    "subsubset": f"diff={diff_bin_name}",
-                    "metrics": bin_results,
-                }
-            )
+        # average over all results
+        bin_results = average_per_batch_results(
+            per_batch_eval_results, idx=torch.where(ok_set)[0]
+        )
+        results.append(
+            {
+                "subset": set_name,
+                "subsubset": "diff=all",
+                "metrics": bin_results,
+            }
+        )

     if is_multisequence:
         # split based on n_src_views
@@ -552,7 +494,7 @@ def _get_flat_nvs_metric_key(result, metric_name) -> str:
     return metric_key


-def flatten_nvs_results(results):
+def flatten_nvs_results(results) -> Dict[str, Any]:
     """
     Takes input `results` list of dicts of the form::

@@ -571,7 +513,6 @@ def flatten_nvs_results(results):
             'subset=train/test/...|subsubset=src=1/src=2/...': nvs_eval_metrics,
             ...
         }
-
     """
     results_flat = {}
     for result in results:
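For orientation, below is a minimal, hedged sketch of how the slimmed-down aggregation entry point is called after this change. It relies only on the signatures visible in the diff and assumes `eval_batch` and `summarize_nvs_eval_results` from this file are in scope; `test_loader`, `lpips_model`, and the unpacked result names are placeholders, not the library's own.

```python
from typing import Any, Dict, List

# Hypothetical driver loop: `test_loader` and `lpips_model` stand in for
# whatever the surrounding evaluation pipeline already provides.
per_batch_eval_results: List[Dict[str, Any]] = []
for frame_data, implicitron_render in test_loader:
    # source_cameras is no longer accepted; eval_batch evaluates the batch as-is.
    per_batch_eval_results.append(
        eval_batch(frame_data, implicitron_render, lpips_model=lpips_model)
    )

# The aggregation no longer takes camera_difficulty_bin_breaks; each per-set
# average is reported under the single "diff=all" sub-bucket.
summary, summary_flat = summarize_nvs_eval_results(
    per_batch_eval_results,
    is_multisequence=False,
)
# The diff only guarantees a Tuple[Dict[str, Any], Dict[str, Any]] return value;
# the names `summary` and `summary_flat` are illustrative guesses.
```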