diff --git a/docs/examples/.gitignore b/docs/examples/.gitignore
new file mode 100644
index 000000000..aab52d906
--- /dev/null
+++ b/docs/examples/.gitignore
@@ -0,0 +1 @@
+*.png
\ No newline at end of file
diff --git a/docs/examples/pulsar_opencv.py b/docs/examples/pulsar_opencv.py
new file mode 100755
index 000000000..9d21621ae
--- /dev/null
+++ b/docs/examples/pulsar_opencv.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+"""
+This example demonstrates the use of OpenCV-style camera parameters with the
+plain pulsar interface.
+"""
+import logging
+
+import cv2
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import pytorch3d.renderer.points.pulsar as pulsar
+import torch
+from pytorch3d.renderer.points.pulsar.camera import opencv2pulsar
+
+
+matplotlib.use("Agg")
+
+
+def cli():
+    """
+    Basic example for the OpenCV-to-Pulsar conversion.
+
+    Writes to `opencv2pulsar.png`.
+    """
+    # ~~~~~~~~~~~~~~~~~
+    # sample 3d points
+    points3d = np.array(
+        [
+            [0.128826, -0.347764, 1.62346],
+            [0.136779, -0.197784, 1.833705],
+            [-0.038932, -0.189967, 1.830946],
+            [-0.084399, 0.105825, 1.878489],
+            [-0.082497, 0.358484, 1.809373],
+            [0.310953, -0.203041, 1.828439],
+            [0.363599, 0.086033, 1.858132],
+            [0.347989, 0.34087, 1.802693],
+            [0.136886, 0.3853, 1.835586],
+        ]
+    )
+    n_spheres = len(points3d)
+
+    # ~~~~~~~~~~~~~~~~~
+    # camera params
+    zfar = 10.0
+    znear = 0.1
+    h = 1024
+    w = 1024
+    f = 1127.64
+    cx = 516.12
+    cy = 510.58
+
+    K = np.eye(3)
+    K[0, 2] = cx
+    K[1, 2] = cy
+    K[0, 0] = f
+    K[1, 1] = f
+
+    rvec = np.array(
+        [[-0.051111404817219305, -2.6377198366878027, -0.28602826581257784]]
+    )
+    C = np.array([[-0.482771, -0.400003, 3.479192]]).transpose()
+
+    # OpenCV extrinsics: R is world-to-camera, tvec = -R @ C.
+    R = cv2.Rodrigues(rvec)[0]
+    tvec = -R @ C
+
+    # ~~~~~~~~~~~~~~~~~
+    # OpenCV projection
+    distCoef = np.zeros((5,))
+    points2d_opencv, _ = cv2.projectPoints(points3d, rvec, tvec, K, distCoef)
+    points2d_opencv = np.squeeze(points2d_opencv)
+
+    # ~~~~~~~~~~~~~~~~~
+    # Pulsar projection
+    cam_params = opencv2pulsar(K, R, tvec, h, w)
+
+    # We're working with a default left-handed system here.
+    renderer = pulsar.Renderer(w, h, n_spheres, right_handed_system=False)
+
+    pos = torch.from_numpy(points3d).float().cpu()
+
+    col = torch.zeros((n_spheres, 3)).cpu().float()
+    col[:, 0] = 1.0
+    rad = torch.ones((n_spheres,)).cpu().float() * 0.02
+    image_pulsar = renderer(
+        pos,
+        col,
+        rad,
+        cam_params,
+        1.0e-1,  # Renderer blending parameter gamma, in [1., 1e-5].
+        max_depth=zfar,  # Maximum depth.
+        min_depth=znear,
+    )
+    image_pulsar = (image_pulsar.cpu().numpy() * 255).astype("uint8")
+
+    # Flip the image vertically (reverse the row axis).
+    image_pulsar = image_pulsar[::-1, :, :]
+
+    # ~~~~~~~~~~~~~~~~~
+    # Plot the Pulsar rendering with the OpenCV projections overlaid.
+    fig = plt.figure(figsize=(8, 8))
+    ax = fig.add_subplot(111)
+    ax.set_xlim([0, w])
+    ax.set_ylim([h, 0])
+    ax.imshow(image_pulsar)
+    ax.scatter(points2d_opencv[:, 0], points2d_opencv[:, 1], color="blue", alpha=0.5)
+
+    plt.tight_layout()
+    plt.savefig("opencv2pulsar.png")
+    plt.close()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    cli()
diff --git a/pytorch3d/renderer/points/pulsar/camera.py b/pytorch3d/renderer/points/pulsar/camera.py
new file mode 100644
index 000000000..f19e7ea89
--- /dev/null
+++ b/pytorch3d/renderer/points/pulsar/camera.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+"""Pulsar renderer camera utils.
+"""
+import numpy as np
+import torch
+from pytorch3d.transforms.so3 import so3_log_map
+
+
+def opencv2pulsar(
+    K: torch.Tensor,
+    R: torch.Tensor,
+    T: torch.Tensor,
+    h: int,
+    w: int,
+    znear: float = 0.1,
+) -> torch.Tensor:
+    """
+    Convert OpenCV-style camera parameters to Pulsar-style camera parameters.
+
+    !!IMPORTANT!!
+    * Pulsar does NOT support different focal lengths for x and y yet, so
+      we simply take the average of fx and fy.
+    * The Pulsar renderer MUST use a left-handed coordinate system.
+    * The resulting image will be horizontally flipped; this has to be
+      addressed AFTER rendering by the user.
+
+    Args:
+        * K: intrinsic camera parameters. [Bx]3x3.
+            [[fx, 0, cx],
+             [0, fy, cy],
+             [0,  0,  1]]
+        * R: camera rotation (world-to-camera, OpenCV convention). [Bx]3x3.
+        * T: camera translation (world-to-camera, OpenCV convention). [Bx]3x1.
+        * h: image height.
+        * w: image width.
+        * znear: defines the near clipping plane.
+    """
+    # Users may pass numpy arrays rather than torch tensors.
+    if isinstance(K, np.ndarray):
+        K = torch.from_numpy(K).float()
+    if isinstance(R, np.ndarray):
+        R = torch.from_numpy(R).float()
+    if isinstance(T, np.ndarray):
+        T = torch.from_numpy(T).float()
+
+    device = K.device
+
+    # Test whether the data is batched using `K`; assume that all passed
+    # parameters are either all batched or not batched at all.
+    input_is_not_batched = len(K.size()) == 2
+    if input_is_not_batched:
+        K = K.unsqueeze(0)
+        R = R.unsqueeze(0)
+        T = T.unsqueeze(0)
+    if len(T.size()) == 2:
+        T = T.unsqueeze(2)  # Make T a column vector.
+
+    # Verify parameters.
+    assert h > 0 and w > 0, "height and width must be positive but are: %d, %d" % (h, w)
+    assert (
+        K.size(1) == 3 and K.size(2) == 3
+    ), "Incorrect intrinsic shape: expected 3x3 but got %dx%d" % (K.size(1), K.size(2))
+    assert (
+        R.size(1) == 3 and R.size(2) == 3
+    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (R.size(1), R.size(2))
+    assert (
+        T.size(1) == 3 and T.size(2) == 1
+    ), "Incorrect T shape: expected 3x1 but got %dx%d" % (T.size(1), T.size(2))
+
+    batch_size = K.size(0)
+
+    fx = K[:, 0, 0].unsqueeze(1)
+    fy = K[:, 1, 1].unsqueeze(1)
+    f = (fx + fy) / 2
+
+    # Normalize f into normalized device coordinates (NDC).
+    focal_length_px = f / w
+
+    # Convert into focal_length and sensor_width: fix the focal length just
+    # inside the near clipping plane and derive the sensor width so that
+    # focal_length / sensor_width matches the normalized pixel focal length.
+    focal_length = torch.tensor(
+        [
+            [
+                znear - 1e-5,
+            ]
+        ],
+        dtype=torch.float32,
+        device=device,
+    ).repeat(batch_size, 1)
+    sensor_width = focal_length / focal_length_px
+
+    cx = K[:, 0, 2].unsqueeze(1)
+    cy = K[:, 1, 2].unsqueeze(1)
+
+    # Convert the principal point offset into a centered offset.
+    cx = -(cx - w / 2)
+    cy = cy - h / 2
+
+    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
+
+    R_trans = R.permute(0, 2, 1)
+
+    # Camera center in world coordinates: -R^T @ T.
+    cam_pos = -torch.bmm(R_trans, T).squeeze(2)
+
+    # Camera rotation as an axis-angle vector.
+    cam_rot = so3_log_map(R_trans)
+
+    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
+
+    if input_is_not_batched:
+        # Un-batch the parameters.
+        cam_params = cam_params[0]
+
+    return cam_params
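# The following is a small, hedged usage sketch and is NOT part of the patch
# above. It shows what `opencv2pulsar` returns and how the entries map onto
# Pulsar's camera parameter layout, as implied by the code in camera.py. The
# intrinsics and extrinsics below are made-up illustrative values.
import torch
from pytorch3d.renderer.points.pulsar.camera import opencv2pulsar

h, w = 480, 640
K = torch.tensor([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
R = torch.eye(3)                         # world-to-camera rotation (OpenCV convention)
T = torch.tensor([[0.0], [0.0], [2.0]])  # world-to-camera translation

cam_params = opencv2pulsar(K, R, T, h, w)
print(cam_params.shape)  # torch.Size([10]) for unbatched inputs

cam_pos = cam_params[0:3]      # camera center in world coordinates (-R^T @ T)
cam_rot = cam_params[3:6]      # axis-angle rotation (so3_log_map of R^T)
focal_length = cam_params[6]   # placed just inside znear (znear - 1e-5)
sensor_width = cam_params[7]   # chosen so focal_length / sensor_width * w == (fx + fy) / 2
cx_off, cy_off = cam_params[8], cam_params[9]  # centered principal point offsets

# Batched inputs (Bx3x3 K and R, Bx3x1 T) should yield a Bx10 tensor instead,
# since the function only un-batches the result when the inputs were unbatched.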