
Commit 63ba74f

megluyaga authored and facebook-github-bot committed

Return R2N2 voxel coordinates

Summary: Return R2N2's voxel coordinates.
Reviewed By: nikhilaravi
Differential Revision: D22462530
fbshipit-source-id: a995cfa0957b2561eb3b0f4591cb1db42170bc68

1 parent 326e4cc commit 63ba74f
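In practice, the flag added by this commit is passed at dataset construction time, and the voxels come back on the model dict from __getitem__. A minimal usage sketch (the paths and splits-file name are hypothetical; assumes ShapeNet, the R2N2 renderings, and the ShapeNetVox32 folder are downloaded locally):

from pytorch3d.datasets import R2N2

# Hypothetical local paths; point these at your ShapeNet/R2N2 downloads.
SHAPENET_DIR = "/data/shapenet"
R2N2_DIR = "/data/r2n2"
SPLITS_FILE = "/data/r2n2/splits.json"

dataset = R2N2(
    "train",
    SHAPENET_DIR,
    R2N2_DIR,
    SPLITS_FILE,
    return_voxels=True,  # the new flag introduced by this commit
)

model = dataset[0]
# One (D, D, D) voxel grid per returned view, stacked along dim 0.
print(model["voxels"].shape)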

File tree

7 files changed: +602 additions, -135 deletions

pytorch3d/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .r2n2 import R2N2, BlenderCamera
+from .r2n2 import R2N2, BlenderCamera, collate_batched_R2N2, render_cubified_voxels
 from .shapenet import ShapeNetCore
 from .utils import collate_batched_meshes
 

pytorch3d/datasets/r2n2/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .r2n2 import R2N2, BlenderCamera
+from .r2n2 import R2N2
+from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels
 
 
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
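With the re-exports above, the new batching helper plugs straight into a DataLoader. A hedged sketch (paths are hypothetical, as before; collate_batched_R2N2 is the collate function this commit makes importable from pytorch3d.datasets):

from torch.utils.data import DataLoader

from pytorch3d.datasets import R2N2, collate_batched_R2N2

dataset = R2N2(
    "val", "/data/shapenet", "/data/r2n2", "/data/r2n2/splits.json",
    return_voxels=True,
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_batched_R2N2)

batch = next(iter(loader))
# The collated batch is a dict whose keys mirror the per-model dict
# (e.g. "voxels", "R", "T", "K").
print(batch.keys())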

pytorch3d/datasets/r2n2/r2n2.py

Lines changed: 75 additions & 51 deletions
@@ -10,20 +10,32 @@
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.datasets.utils import compute_extrinsic_matrix
 from pytorch3d.io import load_obj
 from pytorch3d.renderer import HardPhongShader
-from pytorch3d.renderer.cameras import CamerasBase
-from pytorch3d.transforms import Transform3d
 from tabulate import tabulate
 
+from .utils import (
+    BlenderCamera,
+    align_bbox,
+    compute_extrinsic_matrix,
+    read_binvox_coords,
+    voxelize,
+)
 
-SYNSET_DICT_DIR = Path(__file__).resolve().parent
 
-# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
-r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
-t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
-k = np.expand_dims(np.eye(4), axis=0)  # (1, 4, 4)
+SYNSET_DICT_DIR = Path(__file__).resolve().parent
+MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
+VOXEL_SIZE = 128
+# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase:
+# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
+BLENDER_INTRINSIC = torch.tensor(
+    [
+        [2.1875, 0.0, 0.0, 0.0],
+        [0.0, 2.1875, 0.0, 0.0],
+        [0.0, 0.0, -1.002002, -0.2002002],
+        [0.0, 0.0, -1.0, 0.0],
+    ]
+)
 
 
 class R2N2(ShapeNetBase):
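The BLENDER_INTRINSIC constant introduced here is composed with each view's extrinsic later in __getitem__ (P = BLENDER_INTRINSIC.mm(RT)) before voxelization. A small illustrative sketch of that composition, with an identity extrinsic standing in for a real pose:

import torch

# Constant as added in this hunk.
BLENDER_INTRINSIC = torch.tensor(
    [
        [2.1875, 0.0, 0.0, 0.0],
        [0.0, 2.1875, 0.0, 0.0],
        [0.0, 0.0, -1.002002, -0.2002002],
        [0.0, 0.0, -1.0, 0.0],
    ]
)

RT = torch.eye(4)  # placeholder; the real code builds RT with compute_extrinsic_matrix
P = BLENDER_INTRINSIC.mm(RT)  # full projection, exactly as __getitem__ does below

# Project one homogeneous world point and dehomogenize.
point = torch.tensor([0.1, 0.2, -1.0, 1.0])
clip = P.mv(point)
ndc = clip[:3] / clip[3]
print(ndc)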
@@ -42,6 +54,7 @@ def __init__(
         r2n2_dir,
         splits_file,
         return_all_views: bool = True,
+        return_voxels: bool = False,
     ):
         """
         Store each object's synset id and models id in the given directories.
@@ -54,6 +67,8 @@ def __init__(
             return_all_views (bool): Indicator of whether or not to load all the views in
                 the split. If set to False, one of the views in the split will be randomly
                 selected and loaded.
+            return_voxels (bool): Indicator of whether or not to return voxels as a tensor
+                of shape (D, D, D) where D is the number of voxels along each dimension.
         """
         super().__init__()
         self.shapenet_dir = shapenet_dir
@@ -83,6 +98,16 @@ def __init__(
             ) % (r2n2_dir)
             warnings.warn(msg)
 
+        self.return_voxels = return_voxels
+        # Check if the folder containing voxel coordinates is included in r2n2_dir.
+        if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")):
+            self.return_voxels = False
+            msg = (
+                "ShapeNetVox32 not found in %s. Voxel coordinates will "
+                "be skipped when returning models."
+            ) % (r2n2_dir)
+            warnings.warn(msg)
+
         synset_set = set()
         # Store lists of views of each model in a list.
         self.views_per_model_list = []
@@ -173,6 +198,8 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
             - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
             - T: Translation matrix of shape (V, 3), where V is number of views returned.
             - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
+            - voxels: Voxels of shape (D, D, D), where D is the number of voxels along each
+                dimension.
         """
         if isinstance(model_idx, tuple):
             model_idx, view_idxs = model_idx
@@ -208,6 +235,7 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
         model["label"] = self.synset_dict[model["synset_id"]]
 
         model["images"] = None
+        images, Rs, Ts, voxel_RTs = [], [], [], []
         # Retrieve R2N2's renderings if required.
         if self.return_images:
             rendering_path = path.join(
@@ -217,12 +245,9 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
                 model["model_id"],
                 "rendering",
             )
-
             # Read metadata file to obtain params for calibration matrices.
             with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
                 metadata_lines = f.readlines()
-
-            images, Rs, Ts = [], [], []
             for i in model_views:
                 # Read image.
                 image_path = path.join(rendering_path, "%02d.png" % i)
@@ -234,9 +259,13 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
                 azim, elev, yaw, dist_ratio, fov = [
                     float(v) for v in metadata_lines[i].strip().split(" ")
                 ]
-                R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
+                dist = dist_ratio * MAX_CAMERA_DISTANCE
+                # Extrinsic matrix before transformation to PyTorch3D world space.
+                RT = compute_extrinsic_matrix(azim, elev, dist)
+                R, T = self._compute_camera_calibration(RT)
                 Rs.append(R)
                 Ts.append(T)
+                voxel_RTs.append(RT)
 
             # Intrinsic matrix extracted from Blender with slight modification to work with
             # PyTorch3D world space. Taken from meshrcnn codebase:
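compute_extrinsic_matrix, now imported from .utils, turns the metadata's spherical pose (azim, elev, dist) into a 4x4 world-to-camera matrix. For intuition, here is a generic look-at construction of such an extrinsic; this is a sketch of the idea only, not the library's exact implementation or axis conventions:

import math

import torch

def look_at_extrinsic(azim_deg: float, elev_deg: float, dist: float) -> torch.Tensor:
    """Generic 4x4 world-to-camera matrix for a camera on a sphere looking at the origin."""
    az, el = math.radians(azim_deg), math.radians(elev_deg)
    # Camera center on a sphere of radius `dist` (assumes elev != +/-90).
    eye = torch.tensor(
        [
            dist * math.cos(el) * math.cos(az),
            dist * math.cos(el) * math.sin(az),
            dist * math.sin(el),
        ]
    )
    forward = -eye / eye.norm()  # camera looks at the origin
    up = torch.tensor([0.0, 0.0, 1.0])
    right = torch.cross(forward, up, dim=0)
    right = right / right.norm()
    true_up = torch.cross(right, forward, dim=0)
    # Rows are the camera axes; the last column moves the eye to the origin.
    RT = torch.eye(4)
    RT[0, :3], RT[1, :3], RT[2, :3] = right, true_up, -forward
    RT[:3, 3] = -RT[:3, :3].mv(eye)
    return RT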
@@ -254,27 +283,48 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
             model["T"] = torch.stack(Ts)
             model["K"] = K.expand(len(model_views), 4, 4)
 
+        voxels_list = []
+        # Read voxels if required.
+        voxel_path = path.join(
+            self.r2n2_dir,
+            "ShapeNetVox32",
+            model["synset_id"],
+            model["model_id"],
+            "model.binvox",
+        )
+        if self.return_voxels:
+            if not path.isfile(voxel_path):
+                msg = "Voxel file not found for model %s from category %s."
+                raise FileNotFoundError(msg % (model["model_id"], model["synset_id"]))
+
+            with open(voxel_path, "rb") as f:
+                # Read voxel coordinates as a tensor of shape (N, 3).
+                voxel_coords = read_binvox_coords(f)
+            # Align voxels to the same coordinate system as mesh verts.
+            voxel_coords = align_bbox(voxel_coords, model["verts"])
+            for RT in voxel_RTs:
+                # Compute projection matrix.
+                P = BLENDER_INTRINSIC.mm(RT)
+                # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D).
+                voxels = voxelize(voxel_coords, P, VOXEL_SIZE)
+                voxels_list.append(voxels)
+            model["voxels"] = torch.stack(voxels_list)
+
         return model
 
-    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
+    def _compute_camera_calibration(self, RT):
         """
-        Helper function for calculating rotation and translation matrices from azimuth
-        angle, elevation and distance ratio.
+        Helper function for calculating rotation and translation matrices from ShapeNet
+        to camera transformation and ShapeNet to PyTorch3D transformation.
 
         Args:
-            azim: Rotation about the z-axis, in degrees.
-            elev: Rotation above the xy-plane, in degrees.
-            dist_ratio: Ratio of distance from the origin to the maximum camera distance.
+            RT: Extrinsic matrix that performs ShapeNet world view to camera view
+                transformation.
 
         Returns:
-            - R: Rotation matrix of shape (3, 3).
-            - T: Translation matrix of shape (3).
+            R: Rotation matrix of shape (3, 3).
+            T: Translation matrix of shape (3).
         """
-        # Retrive R,T,K of the selected view(s) by reading the metadata.
-        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
-        dist = dist_ratio * MAX_CAMERA_DISTANCE
-        RT = compute_extrinsic_matrix(azim, elev, dist)
-
         # Transform the mesh vertices from shapenet world to pytorch3d world.
         shapenet_to_pytorch3d = torch.tensor(
             [
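read_binvox_coords, align_bbox, and voxelize all live in the new r2n2/utils.py, which is among the changed files not expanded in this excerpt. For orientation: align_bbox rescales and shifts the (N, 3) voxel coordinates so that their bounding box matches that of the mesh vertices. A hypothetical sketch of that idea (not the actual utils implementation):

import torch

def align_bbox_sketch(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Map `src` points (N, 3) so their per-axis bounding box matches `tgt`'s (M, 3)."""
    src_min, src_max = src.min(dim=0).values, src.max(dim=0).values
    tgt_min, tgt_max = tgt.min(dim=0).values, tgt.max(dim=0).values
    scale = (tgt_max - tgt_min) / (src_max - src_min)
    return (src - src_min) * scale + tgt_min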
@@ -285,9 +335,7 @@ def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
             ],
             dtype=torch.float32,
         )
-        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)
         RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
-
         # Extract rotation and translation matrices from RT.
         R = RT[:3, :3]
         T = RT[3, :3]
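After the transpose, RT is in row-vector form: R is the top-left 3x3 block and T is the fourth row, so points transform as x_cam = x_world @ R + T, matching PyTorch3D's camera convention. A quick worked check with a hypothetical RT:

import torch

RT = torch.eye(4)  # stand-in for the (4, 4) row-major matrix above
RT[3, :3] = torch.tensor([0.0, 0.0, 2.0])  # hypothetical translation row

R = RT[:3, :3]  # rotation block, extracted as in the diff
T = RT[3, :3]   # translation row, extracted as in the diff

x_world = torch.tensor([[1.0, 0.0, 0.0]])  # one point in row-vector form, (1, 3)
x_cam = x_world @ R + T  # -> tensor([[1., 0., 2.]])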
@@ -348,27 +396,3 @@ def render(
         return super().render(
             idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
         )
-
-
-class BlenderCamera(CamerasBase):
-    """
-    Camera for rendering objects with calibration matrices from the R2N2 dataset
-    (which uses Blender for rendering the views for each model).
-    """
-
-    def __init__(self, R=r, T=t, K=k, device="cpu"):
-        """
-        Args:
-            R: Rotation matrix of shape (N, 3, 3).
-            T: Translation matrix of shape (N, 3).
-            K: Intrinsic matrix of shape (N, 4, 4).
-            device: torch.device or str.
-        """
-        # The initializer formats all inputs to torch tensors and broadcasts
-        # all the inputs to have the same batch dimension where necessary.
-        super().__init__(device=device, R=R, T=T, K=K)
-
-    def get_projection_transform(self, **kwargs) -> Transform3d:
-        transform = Transform3d(device=self.device)
-        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
-        return transform
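BlenderCamera is not deleted here; it moves to r2n2/utils.py (another of the changed files not expanded in this excerpt) and remains importable from pytorch3d.datasets, as the first file in this diff shows. A hedged sketch of feeding the per-view calibration back into it, assuming a model dict loaded as in the first example:

from pytorch3d.datasets import BlenderCamera

# model = dataset[idx]; R, T, K each hold one entry per returned view.
camera = BlenderCamera(R=model["R"], T=model["T"], K=model["K"])

# Per get_projection_transform above, the projection is K applied as-is.
proj = camera.get_projection_transform()

The new render_cubified_voxels helper (also exported at the top of this diff) can then turn model["voxels"] into cubified meshes and render them; its exact arguments live in the unexpanded utils.py.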
