Skip to content

Commit 2628fb5

Browse files
davnov134facebook-github-bot
authored andcommitted
Testing script
Summary: Implements the test script of NeRF. Reviewed By: nikhilaravi Differential Revision: D25684450 fbshipit-source-id: 739169d9df706795814912bb9a15e2e65ac92df8
1 parent dc28b61 commit 2628fb5

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

projects/nerf/configs/fern.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ test:
1515
scene_center: [0.0, 0.0, -2.0]
1616
n_frames: 100
1717
fps: 20
18+
trajectory_scale: 1.0
1819
optimizer:
1920
max_epochs: 37500
2021
lr: 0.0005

projects/nerf/configs/lego.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ test:
1515
scene_center: [0.0, 0.0, 0.0]
1616
n_frames: 100
1717
fps: 20
18+
trajectory_scale: 0.2
1819
optimizer:
1920
max_epochs: 20000
2021
lr: 0.0005

projects/nerf/configs/pt3logo.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ test:
1515
scene_center: [0.0, 0.0, 0.0]
1616
n_frames: 100
1717
fps: 20
18+
trajectory_scale: 0.2
1819
optimizer:
1920
max_epochs: 100000
2021
lr: 0.0005

projects/nerf/test_nerf.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3+
import os
4+
import warnings
5+
6+
import hydra
7+
import numpy as np
8+
import torch
9+
from nerf.dataset import get_nerf_datasets, trivial_collate
10+
from nerf.eval_video_utils import generate_eval_video_cameras
11+
from nerf.nerf_renderer import RadianceFieldRenderer
12+
from nerf.stats import Stats
13+
from omegaconf import DictConfig
14+
from PIL import Image
15+
16+
CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
17+
18+
19+
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
20+
def main(cfg: DictConfig):
21+
22+
# Device on which to run.
23+
if torch.cuda.is_available():
24+
device = "cuda"
25+
else:
26+
warnings.warn(
27+
"Please note that although executing on CPU is supported,"
28+
+ "the testing is unlikely to finish in reasonable time."
29+
)
30+
device = "cpu"
31+
32+
# Initialize the Radiance Field model.
33+
model = RadianceFieldRenderer(
34+
image_size=cfg.data.image_size,
35+
n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
36+
n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray,
37+
n_rays_per_image=cfg.raysampler.n_rays_per_image,
38+
min_depth=cfg.raysampler.min_depth,
39+
max_depth=cfg.raysampler.max_depth,
40+
stratified=cfg.raysampler.stratified,
41+
stratified_test=cfg.raysampler.stratified_test,
42+
chunk_size_test=cfg.raysampler.chunk_size_test,
43+
n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
44+
n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
45+
n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
46+
n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
47+
n_layers_xyz=cfg.implicit_function.n_layers_xyz,
48+
density_noise_std=cfg.implicit_function.density_noise_std,
49+
)
50+
51+
# Move the model to the relevant device.
52+
model.to(device)
53+
54+
# Resume from the checkpoint.
55+
checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
56+
if not os.path.isfile(checkpoint_path):
57+
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
58+
59+
print(f"Loading checkpoint {checkpoint_path}.")
60+
loaded_data = torch.load(checkpoint_path)
61+
# Do not load the cached xy grid.
62+
# - this allows to set an arbitrary evaluation image size.
63+
state_dict = {
64+
k: v
65+
for k, v in loaded_data["model"].items()
66+
if "_grid_raysampler._xy_grid" not in k
67+
}
68+
model.load_state_dict(state_dict, strict=False)
69+
70+
# Load the test data.
71+
if cfg.test.mode == "evaluation":
72+
_, _, test_dataset = get_nerf_datasets(
73+
dataset_name=cfg.data.dataset_name,
74+
image_size=cfg.data.image_size,
75+
)
76+
elif cfg.test.mode == "export_video":
77+
train_dataset, _, _ = get_nerf_datasets(
78+
dataset_name=cfg.data.dataset_name,
79+
image_size=cfg.data.image_size,
80+
)
81+
test_dataset = generate_eval_video_cameras(
82+
train_dataset,
83+
trajectory_type=cfg.test.trajectory_type,
84+
up=cfg.test.up,
85+
scene_center=cfg.test.scene_center,
86+
n_eval_cams=cfg.test.n_frames,
87+
trajectory_scale=cfg.test.trajectory_scale,
88+
)
89+
# store the video in directory (checkpoint_file - extension + '_video')
90+
export_dir = os.path.splitext(checkpoint_path)[0] + "_video"
91+
os.makedirs(export_dir, exist_ok=True)
92+
else:
93+
raise ValueError(f"Unknown test mode {cfg.test_mode}.")
94+
95+
# Init the test dataloader.
96+
test_dataloader = torch.utils.data.DataLoader(
97+
test_dataset,
98+
batch_size=1,
99+
shuffle=False,
100+
num_workers=0,
101+
collate_fn=trivial_collate,
102+
)
103+
104+
if cfg.test.mode == "evaluation":
105+
# Init the test stats object.
106+
eval_stats = ["mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"]
107+
stats = Stats(eval_stats)
108+
stats.new_epoch()
109+
elif cfg.test.mode == "export_video":
110+
# Init the frame buffer.
111+
frame_paths = []
112+
113+
# Set the model to the eval mode.
114+
model.eval()
115+
116+
# Run the main testing loop.
117+
for batch_idx, test_batch in enumerate(test_dataloader):
118+
test_image, test_camera, camera_idx = test_batch[0].values()
119+
if test_image is not None:
120+
test_image = test_image.to(device)
121+
test_camera = test_camera.to(device)
122+
123+
# Activate eval mode of the model (allows to do a full rendering pass).
124+
model.eval()
125+
with torch.no_grad():
126+
test_nerf_out, test_metrics = model(
127+
None, # we do not use pre-cached cameras
128+
test_camera,
129+
test_image,
130+
)
131+
132+
if cfg.test.mode == "evaluation":
133+
# Update stats with the validation metrics.
134+
stats.update(test_metrics, stat_set="test")
135+
stats.print(stat_set="test")
136+
137+
elif cfg.test.mode == "export_video":
138+
# Store the video frame.
139+
frame = test_nerf_out["rgb_fine"][0].detach().cpu()
140+
frame_path = os.path.join(export_dir, f"frame_{batch_idx:05d}.png")
141+
print(f"Writing {frame_path}.")
142+
Image.fromarray((frame.numpy() * 255.0).astype(np.uint8)).save(frame_path)
143+
frame_paths.append(frame_path)
144+
145+
if cfg.test.mode == "evaluation":
146+
print(f"Final evaluation metrics on '{cfg.data.dataset_name}':")
147+
for stat in eval_stats:
148+
stat_value = stats.stats["test"][stat].get_epoch_averages()[0]
149+
print(f"{stat:15s}: {stat_value:1.4f}")
150+
151+
elif cfg.test.mode == "export_video":
152+
# Convert the exported frames to a video.
153+
video_path = os.path.join(export_dir, "video.mp4")
154+
ffmpeg_bin = "ffmpeg"
155+
frame_regexp = os.path.join(export_dir, "frame_%05d.png")
156+
ffmcmd = (
157+
"%s -r %d -i %s -vcodec h264 -f mp4 -y -b 2000k -pix_fmt yuv420p %s"
158+
% (ffmpeg_bin, cfg.test.fps, frame_regexp, video_path)
159+
)
160+
ret = os.system(ffmcmd)
161+
if ret != 0:
162+
raise RuntimeError("ffmpeg failed!")
163+
164+
165+
if __name__ == "__main__":
166+
main()

0 commit comments

Comments
 (0)