diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py
index 1816e767a..0865547a6 100644
--- a/pytorch3d/renderer/blending.py
+++ b/pytorch3d/renderer/blending.py
@@ -3,7 +3,7 @@
 
 import numpy as np
-from typing import NamedTuple
+from typing import NamedTuple, Sequence
 import torch
 
 
 # Example functions for blending the top K colors per pixel using the outputs
@@ -15,7 +15,7 @@
 class BlendParams(NamedTuple):
     sigma: float = 1e-4
     gamma: float = 1e-4
-    background_color = (1.0, 1.0, 1.0)
+    background_color: Sequence = (1.0, 1.0, 1.0)
 
 
 def hard_rgb_blend(colors, fragments) -> torch.Tensor:
diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index 4e13ad555..dead3ea81 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -3,7 +3,7 @@
 
 import math
 import numpy as np
-from typing import Tuple
+from typing import Tuple, Sequence
 import torch
 import torch.nn.functional as F
 
@@ -1003,10 +1003,11 @@ def look_at_rotation(
 
 
 def look_at_view_transform(
-    dist,
-    elev,
-    azim,
+    dist=1.0,
+    elev=0.0,
+    azim=0.0,
     degrees: bool = True,
+    eye: Sequence = None,
     at=((0, 0, 0),),  # (1, 3)
     up=((0, 1, 0),),  # (1, 3)
     device="cpu",
@@ -1025,10 +1026,12 @@ def look_at_view_transform(
             reference vector at (1, 0, 0) on the reference plane.
         dist, elem and azim can be of shape (1), (N).
         degrees: boolean flag to indicate if the elevation and azimuth
-            angles are specified in degrees or raidans.
+            angles are specified in degrees or radians.
+        eye: the position of the camera(s) in world coordinates. If eye is not
+            None, it will override the camera position derived from dist, elev, azim.
         up: the direction of the x axis in the world coordinate system.
         at: the position of the object(s) in world coordinates.
-        up and at can be of shape (1, 3) or (N, 3).
+        eye, up and at can be of shape (1, 3) or (N, 3).
 
     Returns:
         2-element tuple containing
@@ -1039,11 +1042,19 @@ def look_at_view_transform(
     References:
     [0] https://www.scratchapixel.com
     """
-    broadcasted_args = convert_to_tensors_and_broadcast(
-        dist, elev, azim, at, up, device=device
-    )
-    dist, elev, azim, at, up = broadcasted_args
-    C = camera_position_from_spherical_angles(dist, elev, azim, device=device)
+
+    if eye is not None:
+        broadcasted_args = convert_to_tensors_and_broadcast(
+            eye, at, up, device=device)
+        eye, at, up = broadcasted_args
+        C = eye
+    else:
+        broadcasted_args = convert_to_tensors_and_broadcast(
+            dist, elev, azim, at, up, device=device)
+        dist, elev, azim, at, up = broadcasted_args
+        C = camera_position_from_spherical_angles(
+            dist, elev, azim, degrees=degrees, device=device)
+
     R = look_at_rotation(C, at, up, device=device)
     T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0]
     return R, T
diff --git a/tests/test_cameras.py b/tests/test_cameras.py
index fd7f33be8..096c1dfa7 100644
--- a/tests/test_cameras.py
+++ b/tests/test_cameras.py
@@ -39,6 +39,7 @@
     camera_position_from_spherical_angles,
     get_world_to_view_transform,
     look_at_rotation,
+    look_at_view_transform,
 )
 from pytorch3d.transforms import Transform3d
 from pytorch3d.transforms.so3 import so3_exponential_map
@@ -117,6 +118,34 @@ def setUp(self) -> None:
         torch.manual_seed(42)
         np.random.seed(42)
 
+    def test_look_at_view_transform_from_eye_point_tuple(self):
+        dist = math.sqrt(2)
+        elev = math.pi / 4
+        azim = 0.0
+        eye = ((0.0, 1.0, 1.0),)
+        # using passed values for dist, elev, azim
+        R, t = look_at_view_transform(dist, elev, azim, degrees=False)
+        # using other values for dist, elev, azim - eye overrides
+        R_eye, t_eye = look_at_view_transform(dist=3, elev=2, azim=1, eye=eye)
+        # using only eye value
+        R_eye_only, t_eye_only = look_at_view_transform(eye=eye)
+        self.assertTrue(torch.allclose(R, R_eye, atol=2e-7))
+        self.assertTrue(torch.allclose(t, t_eye, atol=2e-7))
+        self.assertTrue(torch.allclose(R, R_eye_only, atol=2e-7))
+        self.assertTrue(torch.allclose(t, t_eye_only, atol=2e-7))
+
+    def test_look_at_view_transform_default_values(self):
+        dist = 1.0
+        elev = 0.0
+        azim = 0.0
+        # Using passed values for dist, elev, azim
+        R, t = look_at_view_transform(dist, elev, azim)
+        # Using default dist=1.0, elev=0.0, azim=0.0
+        R_default, t_default = look_at_view_transform()
+        # test default = passed = expected
+        self.assertTrue(torch.allclose(R, R_default, atol=2e-7))
+        self.assertTrue(torch.allclose(t, t_default, atol=2e-7))
+
     def test_camera_position_from_angles_python_scalar(self):
         dist = 2.7
         elev = 90.0
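
As the new tests illustrate, an explicit eye takes precedence over the spherical dist/elev/azim parametrization, and both describe the same camera when they agree. Below is a minimal usage sketch of the patched behavior, assuming look_at_view_transform is re-exported from pytorch3d.renderer and that OpenGLPerspectiveCameras (one possible consumer of the returned R, T) is available:

import math

import torch
from pytorch3d.renderer import OpenGLPerspectiveCameras, look_at_view_transform

# Spherical parametrization: distance sqrt(2), elevation 45 degrees, azimuth 0,
# which places the camera at (0, 1, 1) in world coordinates.
R, T = look_at_view_transform(dist=math.sqrt(2), elev=45.0, azim=0.0)

# Explicit parametrization: when eye is given, dist/elev/azim are ignored.
R_eye, T_eye = look_at_view_transform(eye=((0.0, 1.0, 1.0),))

assert torch.allclose(R, R_eye, atol=2e-7)
assert torch.allclose(T, T_eye, atol=2e-7)

# The returned (R, T) pair feeds into a camera as usual.
cameras = OpenGLPerspectiveCameras(device="cpu", R=R_eye, T=T_eye)

Having eye silently override dist/elev/azim, rather than raising on conflicting arguments, keeps the signature backward compatible: existing positional calls with dist, elev, azim behave exactly as before.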