#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import os
import warnings

import hydra
import numpy as np
import torch
from nerf.dataset import get_nerf_datasets, trivial_collate
from nerf.eval_video_utils import generate_eval_video_cameras
from nerf.nerf_renderer import RadianceFieldRenderer
from nerf.stats import Stats
from omegaconf import DictConfig
from PIL import Image

CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")


@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            "the testing is unlikely to finish in a reasonable time."
        )
        device = "cpu"

    # Initialize the Radiance Field model.
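    # (The renderer runs a coarse and a fine ray-sampling pass, hence the
    # separate `n_pts_per_ray` / `n_pts_per_ray_fine` settings below.)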
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
    )

    # Move the model to the relevant device.
    model.to(device)

    # Resume from the checkpoint.
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if not os.path.isfile(checkpoint_path):
        raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")

    print(f"Loading checkpoint {checkpoint_path}.")
    # Map the checkpoint tensors onto the chosen device so that checkpoints
    # trained on GPU can also be loaded on a CPU-only machine.
    loaded_data = torch.load(checkpoint_path, map_location=device)
    # Do not load the cached xy grid;
    # this allows setting an arbitrary evaluation image size.
    state_dict = {
        k: v
        for k, v in loaded_data["model"].items()
        if "_grid_raysampler._xy_grid" not in k
    }
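    # strict=False because the xy-grid buffers filtered out above are
    # intentionally missing from the loaded state dict.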
    model.load_state_dict(state_dict, strict=False)

    # Load the test data.
    if cfg.test.mode == "evaluation":
        _, _, test_dataset = get_nerf_datasets(
            dataset_name=cfg.data.dataset_name,
            image_size=cfg.data.image_size,
        )
    elif cfg.test.mode == "export_video":
        train_dataset, _, _ = get_nerf_datasets(
            dataset_name=cfg.data.dataset_name,
            image_size=cfg.data.image_size,
        )
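        # Generate novel viewpoints along a virtual camera trajectory; the
        # trajectory type, up vector and scene center come from the test config.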
        test_dataset = generate_eval_video_cameras(
            train_dataset,
            trajectory_type=cfg.test.trajectory_type,
            up=cfg.test.up,
            scene_center=cfg.test.scene_center,
            n_eval_cams=cfg.test.n_frames,
            trajectory_scale=cfg.test.trajectory_scale,
        )
        # Store the video frames in a directory named
        # "<checkpoint file without extension>_video".
        export_dir = os.path.splitext(checkpoint_path)[0] + "_video"
        os.makedirs(export_dir, exist_ok=True)
    else:
        raise ValueError(f"Unknown test mode {cfg.test.mode}.")

    # Init the test dataloader.
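    # `trivial_collate` is assumed to return the batch list as-is (no tensor
    # stacking), which is why the loop below indexes into `test_batch[0]`.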
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    if cfg.test.mode == "evaluation":
        # Init the test stats object.
        eval_stats = ["mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"]
        stats = Stats(eval_stats)
        stats.new_epoch()
    elif cfg.test.mode == "export_video":
        # Init the frame buffer.
        frame_paths = []

    # Set the model to eval mode.
    model.eval()

    # Run the main testing loop.
    for batch_idx, test_batch in enumerate(test_dataloader):
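        # NOTE: this unpacking relies on the insertion order of the sample
        # dict keys (assumed to be image, camera, camera_idx).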
        test_image, test_camera, camera_idx = test_batch[0].values()
        if test_image is not None:
            test_image = test_image.to(device)
        test_camera = test_camera.to(device)

        # Activate eval mode of the model (enables a full rendering pass).
        model.eval()
        with torch.no_grad():
            test_nerf_out, test_metrics = model(
                None,  # we do not use pre-cached cameras
                test_camera,
                test_image,
            )

        if cfg.test.mode == "evaluation":
            # Update stats with the validation metrics.
            stats.update(test_metrics, stat_set="test")
            stats.print(stat_set="test")

        elif cfg.test.mode == "export_video":
            # Store the video frame.
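            # The fine-pass render is assumed to hold RGB values in [0, 1];
            # scale to 8-bit for PNG export.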
            frame = test_nerf_out["rgb_fine"][0].detach().cpu()
            frame_path = os.path.join(export_dir, f"frame_{batch_idx:05d}.png")
            print(f"Writing {frame_path}.")
            Image.fromarray((frame.numpy() * 255.0).astype(np.uint8)).save(frame_path)
            frame_paths.append(frame_path)

    if cfg.test.mode == "evaluation":
        print(f"Final evaluation metrics on '{cfg.data.dataset_name}':")
        for stat in eval_stats:
            stat_value = stats.stats["test"][stat].get_epoch_averages()[0]
            print(f"{stat:15s}: {stat_value:1.4f}")

    elif cfg.test.mode == "export_video":
        # Convert the exported frames to a video.
        video_path = os.path.join(export_dir, "video.mp4")
        ffmpeg_bin = "ffmpeg"
        frame_regexp = os.path.join(export_dir, "frame_%05d.png")
        # `-pix_fmt yuv420p` keeps the output playable in common video players.
        ffmcmd = (
            f"{ffmpeg_bin} -r {cfg.test.fps} -i {frame_regexp}"
            f" -vcodec h264 -f mp4 -y -b:v 2000k -pix_fmt yuv420p {video_path}"
        )
        ret = os.system(ffmcmd)
        if ret != 0:
            raise RuntimeError("ffmpeg failed!")


if __name__ == "__main__":
    main()
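
# Example invocations (Hydra override syntax; assumes the default config
# `configs/lego.yaml` and a trained checkpoint at `cfg.checkpoint_path`):
#   python test_nerf.py test.mode=evaluation
#   python test_nerf.py test.mode=export_video test.n_frames=100 test.fps=20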