|
36 | 36 |
|
37 | 37 | import numpy as np |
38 | 38 | import torch |
| 39 | +from pytorch3d.common.datatypes import Device |
39 | 40 | from pytorch3d.renderer.camera_utils import join_cameras_as_batch |
40 | 41 | from pytorch3d.renderer.cameras import ( |
41 | 42 | camera_position_from_spherical_angles, |
@@ -149,14 +150,17 @@ def ndc_to_screen_points_naive(points, imsize): |
149 | 150 |
|
150 | 151 |
|
151 | 152 | def init_random_cameras( |
152 | | - cam_type: typing.Type[CamerasBase], batch_size: int, random_z: bool = False |
| 153 | + cam_type: typing.Type[CamerasBase], |
| 154 | + batch_size: int, |
| 155 | + random_z: bool = False, |
| 156 | + device: Device = "cpu", |
153 | 157 | ): |
154 | 158 | cam_params = {} |
155 | 159 | T = torch.randn(batch_size, 3) * 0.03 |
156 | 160 | if not random_z: |
157 | 161 | T[:, 2] = 4 |
158 | 162 | R = so3_exp_map(torch.randn(batch_size, 3) * 3.0) |
159 | | - cam_params = {"R": R, "T": T} |
| 163 | + cam_params = {"R": R, "T": T, "device": device} |
160 | 164 | if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras): |
161 | 165 | cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 |
162 | 166 | cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] |
@@ -613,15 +617,33 @@ def test_unproject_points(self, batch_size=50, num_points=100): |
613 | 617 | self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4)) |
614 | 618 |
|
615 | 619 | @staticmethod |
616 | | - def unproject_points(cam_type, batch_size=50, num_points=100): |
| 620 | + def unproject_points( |
| 621 | + cam_type, batch_size=50, num_points=100, device: Device = "cpu" |
| 622 | + ): |
617 | 623 | """ |
618 | 624 | Checks that an unprojection of a randomly projected point cloud |
619 | 625 | stays the same. |
620 | 626 | """ |
| 627 | + if device == "cuda": |
| 628 | + device = torch.device("cuda:0") |
| 629 | + else: |
| 630 | + device = torch.device("cpu") |
| 631 | + |
| 632 | + str2cls = { # noqa |
| 633 | + "OpenGLOrthographicCameras": OpenGLOrthographicCameras, |
| 634 | + "OpenGLPerspectiveCameras": OpenGLPerspectiveCameras, |
| 635 | + "SfMOrthographicCameras": SfMOrthographicCameras, |
| 636 | + "SfMPerspectiveCameras": SfMPerspectiveCameras, |
| 637 | + "FoVOrthographicCameras": FoVOrthographicCameras, |
| 638 | + "FoVPerspectiveCameras": FoVPerspectiveCameras, |
| 639 | + "OrthographicCameras": OrthographicCameras, |
| 640 | + "PerspectiveCameras": PerspectiveCameras, |
| 641 | + "FishEyeCameras": FishEyeCameras, |
| 642 | + } |
621 | 643 |
|
622 | 644 | def run_cameras(): |
623 | 645 | # init the cameras |
624 | | - cameras = init_random_cameras(cam_type, batch_size) |
| 646 | + cameras = init_random_cameras(str2cls[cam_type], batch_size, device=device) |
625 | 647 | # xyz - the ground truth point cloud |
626 | 648 | xyz = torch.randn(num_points, 3) * 0.3 |
627 | 649 | xyz = cameras.unproject_points(xyz, scaled_depth_input=True) |
@@ -666,15 +688,33 @@ def test_project_points_screen(self, batch_size=50, num_points=100): |
666 | 688 | self.assertClose(xyz_project_screen, xyz_project_screen_naive, atol=1e-4) |
667 | 689 |
|
668 | 690 | @staticmethod |
669 | | - def transform_points(cam_type, batch_size=50, num_points=100): |
| 691 | + def transform_points( |
| 692 | + cam_type, batch_size=50, num_points=100, device: Device = "cpu" |
| 693 | + ): |
670 | 694 | """ |
671 | 695 | Checks that an unprojection of a randomly projected point cloud |
672 | 696 | stays the same. |
673 | 697 | """ |
674 | 698 |
|
| 699 | + if device == "cuda": |
| 700 | + device = torch.device("cuda:0") |
| 701 | + else: |
| 702 | + device = torch.device("cpu") |
| 703 | + str2cls = { # noqa |
| 704 | + "OpenGLOrthographicCameras": OpenGLOrthographicCameras, |
| 705 | + "OpenGLPerspectiveCameras": OpenGLPerspectiveCameras, |
| 706 | + "SfMOrthographicCameras": SfMOrthographicCameras, |
| 707 | + "SfMPerspectiveCameras": SfMPerspectiveCameras, |
| 708 | + "FoVOrthographicCameras": FoVOrthographicCameras, |
| 709 | + "FoVPerspectiveCameras": FoVPerspectiveCameras, |
| 710 | + "OrthographicCameras": OrthographicCameras, |
| 711 | + "PerspectiveCameras": PerspectiveCameras, |
| 712 | + "FishEyeCameras": FishEyeCameras, |
| 713 | + } |
| 714 | + |
675 | 715 | def run_cameras(): |
676 | 716 | # init the cameras |
677 | | - cameras = init_random_cameras(cam_type, batch_size) |
| 717 | + cameras = init_random_cameras(str2cls[cam_type], batch_size, device=device) |
678 | 718 | # xyz - the ground truth point cloud |
679 | 719 | xy = torch.randn(num_points, 2) * 2.0 - 1.0 |
680 | 720 | z = torch.randn(num_points, 1) * 3.0 + 1.0 |
|
0 commit comments