Skip to content

Commit 9446d91

Browse files
EmGarrfacebook-github-bot
authored andcommitted
Avoid to keep in memory lengths and bins for ImplicitronRayBundle
Summary: Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes: - init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor - lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory. Reviewed By: shapovalov Differential Revision: D46686094 fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
1 parent 3d011a9 commit 9446d91

File tree

5 files changed

+102
-60
lines changed

5 files changed

+102
-60
lines changed

pytorch3d/implicitron/models/renderer/base.py

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from __future__ import annotations
88

9-
import dataclasses
10-
119
from abc import ABC, abstractmethod
1210
from dataclasses import dataclass, field
1311
from enum import Enum
@@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
2927
FULL_GRID = "full_grid"
3028

3129

32-
@dataclasses.dataclass
3330
class ImplicitronRayBundle:
3431
"""
3532
Parametrizes points along projection rays by storing ray `origins`,
@@ -69,53 +66,58 @@ class ImplicitronRayBundle:
6966
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
7067
pixel_radii_2d: An optional tensor of shape `(..., 1)`
7168
base radii of the conical frustums.
69+
70+
Raises:
71+
ValueError: If either bins or lengths are not provided.
72+
ValueError: If bins is provided and the last dim is inferior or equal to 1.
7273
"""
7374

74-
origins: torch.Tensor
75-
directions: torch.Tensor
76-
lengths: torch.Tensor
77-
xys: torch.Tensor
78-
camera_ids: Optional[torch.LongTensor] = None
79-
camera_counts: Optional[torch.LongTensor] = None
80-
bins: Optional[torch.Tensor] = None
81-
pixel_radii_2d: Optional[torch.Tensor] = None
82-
83-
@classmethod
84-
def from_bins(
85-
cls,
75+
def __init__(
76+
self,
8677
origins: torch.Tensor,
8778
directions: torch.Tensor,
88-
bins: torch.Tensor,
79+
lengths: Optional[torch.Tensor],
8980
xys: torch.Tensor,
90-
**kwargs,
91-
) -> "ImplicitronRayBundle":
92-
"""
93-
Creates a new instance from bins instead of lengths.
94-
95-
Attributes:
96-
origins: A tensor of shape `(..., 3)` denoting the
97-
origins of the sampling rays in world coords.
98-
directions: A tensor of shape `(..., 3)` containing the direction
99-
vectors of sampling rays in world coords. They don't have to be normalized;
100-
they define unit vectors in the respective 1D coordinate systems; see
101-
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
102-
bins: A tensor of shape `(..., num_points_per_ray + 1)`
103-
containing the bins at which the rays are sampled. In this case
104-
lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
105-
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
106-
kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
107-
Returns:
108-
An instance of ImplicitronRayBundle.
109-
"""
110-
111-
if bins.shape[-1] <= 1:
81+
camera_ids: Optional[torch.LongTensor] = None,
82+
camera_counts: Optional[torch.LongTensor] = None,
83+
bins: Optional[torch.Tensor] = None,
84+
pixel_radii_2d: Optional[torch.Tensor] = None,
85+
):
86+
if bins is not None and bins.shape[-1] <= 1:
11287
raise ValueError(
11388
"The last dim of bins must be at least superior or equal to 2."
11489
)
115-
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
116-
lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
11790

118-
return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
91+
if bins is None and lengths is None:
92+
raise ValueError(
93+
"Please set either bins or lengths to initialize an ImplicitronRayBundle."
94+
)
95+
96+
self.origins = origins
97+
self.directions = directions
98+
self._lengths = lengths if bins is None else None
99+
self.xys = xys
100+
self.bins = bins
101+
self.pixel_radii_2d = pixel_radii_2d
102+
self.camera_ids = camera_ids
103+
self.camera_counts = camera_counts
104+
105+
@property
106+
def lengths(self) -> torch.Tensor:
107+
if self.bins is not None:
108+
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
109+
# pyre-ignore
110+
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
111+
return self._lengths
112+
113+
@lengths.setter
114+
def lengths(self, value):
115+
if self.bins is not None:
116+
raise ValueError(
117+
"If the bins attribute is not None you cannot set the lengths attribute."
118+
)
119+
else:
120+
self._lengths = value
119121

120122
def is_packed(self) -> bool:
121123
"""

pytorch3d/implicitron/models/renderer/lstm_renderer.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import dataclasses
7+
import copy
88
import logging
99
from typing import List, Optional, Tuple
1010

@@ -102,12 +102,11 @@ def forward(
102102
)
103103

104104
# jitter the initial depths
105-
ray_bundle_t = dataclasses.replace(
106-
ray_bundle,
107-
lengths=(
108-
ray_bundle.lengths
109-
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
110-
),
105+
106+
ray_bundle_t = copy.copy(ray_bundle)
107+
ray_bundle_t.lengths = (
108+
ray_bundle.lengths
109+
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
111110
)
112111

113112
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]

pytorch3d/implicitron/models/renderer/ray_point_refiner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import copy
8+
79
import torch
810
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
911
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
@@ -106,14 +108,13 @@ def forward(
106108
z_vals = z_samples
107109
# Resort by depth.
108110
z_vals, _ = torch.sort(z_vals, dim=-1)
109-
110-
kwargs_ray = dict(vars(input_ray_bundle))
111+
ray_bundle = copy.copy(input_ray_bundle)
111112
if input_ray_bundle.bins is None:
112-
kwargs_ray["lengths"] = z_vals
113-
return ImplicitronRayBundle(**kwargs_ray)
114-
kwargs_ray["bins"] = z_vals
115-
del kwargs_ray["lengths"]
116-
return ImplicitronRayBundle.from_bins(**kwargs_ray)
113+
ray_bundle.lengths = z_vals
114+
else:
115+
ray_bundle.bins = z_vals
116+
117+
return ray_bundle
117118

118119

119120
def apply_blurpool_on_weights(weights) -> torch.Tensor:

pytorch3d/implicitron/models/renderer/ray_sampler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,12 @@ def forward(
236236
elif self.cast_ray_bundle_as_cone:
237237
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
238238
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
239-
return ImplicitronRayBundle.from_bins(
239+
return ImplicitronRayBundle(
240240
directions=ray_bundle.directions,
241241
origins=ray_bundle.origins,
242-
bins=ray_bundle.lengths,
242+
lengths=None,
243243
xys=ray_bundle.xys,
244+
bins=ray_bundle.lengths,
244245
pixel_radii_2d=pixel_radii_2d,
245246
)
246247

tests/implicitron/test_models_renderer_base.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,62 @@
2525
class TestRendererBase(TestCaseMixin, unittest.TestCase):
2626
def test_implicitron_from_bins(self) -> None:
2727
bins = torch.randn(2, 3, 4, 5)
28-
ray_bundle = ImplicitronRayBundle.from_bins(
28+
ray_bundle = ImplicitronRayBundle(
2929
origins=None,
3030
directions=None,
31+
lengths=None,
3132
xys=None,
3233
bins=bins,
3334
)
3435
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
3536
self.assertClose(ray_bundle.bins, bins)
3637

38+
def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
39+
self,
40+
) -> None:
41+
with self.assertRaises(ValueError) as context:
42+
ray_bundle = ImplicitronRayBundle(
43+
origins=torch.rand(2, 3, 4, 3),
44+
directions=torch.rand(2, 3, 4, 3),
45+
lengths=None,
46+
xys=torch.rand(2, 3, 4, 2),
47+
bins=torch.rand(2, 3, 4, 1),
48+
)
49+
ray_bundle.lengths = torch.empty(2)
50+
self.assertEqual(
51+
str(context.exception),
52+
"If the bins attribute is not None you cannot set the lengths attribute.",
53+
)
54+
3755
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
38-
with self.assertRaises(ValueError):
39-
ImplicitronRayBundle.from_bins(
56+
with self.assertRaises(ValueError) as context:
57+
ImplicitronRayBundle(
4058
origins=torch.rand(2, 3, 4, 3),
4159
directions=torch.rand(2, 3, 4, 3),
60+
lengths=None,
4261
xys=torch.rand(2, 3, 4, 2),
4362
bins=torch.rand(2, 3, 4, 1),
4463
)
64+
self.assertEqual(
65+
str(context.exception),
66+
"The last dim of bins must be at least superior or equal to 2.",
67+
)
68+
69+
def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
70+
self,
71+
) -> None:
72+
with self.assertRaises(ValueError) as context:
73+
ImplicitronRayBundle(
74+
origins=torch.rand(2, 3, 4, 3),
75+
directions=torch.rand(2, 3, 4, 3),
76+
lengths=None,
77+
xys=torch.rand(2, 3, 4, 2),
78+
bins=None,
79+
)
80+
self.assertEqual(
81+
str(context.exception),
82+
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
83+
)
4584

4685
def test_conical_frustum_to_gaussian(self) -> None:
4786
origins = torch.zeros(3, 3, 3)

0 commit comments

Comments
 (0)