Skip to content

Commit ad8907d

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
ImplicitronRayBundle
Summary: new implicitronRayBundle with added cameraIDs and camera counts. Added to enable a single raybundle inside Implicitron and easier extension in the future. Since RayBundle is named tuple and RayBundleHeterogeneous is dataclass and RayBundleHeterogeneous cannot inherit RayBundle. So if there was no ImplicitronRayBundle every function that uses RayBundle now would have to use Union[RayBundle, RaybundleHeterogeneous] which is confusing and unecessary complicated. Reviewed By: bottler, kjchalup Differential Revision: D39262999 fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
1 parent 6ae863f commit ad8907d

18 files changed

+259
-100
lines changed

docs/tutorials/implicitron_volumes.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,9 @@
145145
"from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
146146
"from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
147147
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
148-
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
148+
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n",
149149
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
150150
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
151-
"from pytorch3d.renderer import RayBundle\n",
152151
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
153152
"from pytorch3d.structures import Volumes\n",
154153
"from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
@@ -393,7 +392,7 @@
393392
"\n",
394393
" def forward(\n",
395394
" self,\n",
396-
" ray_bundle: RayBundle,\n",
395+
" ray_bundle: ImplicitronRayBundle,\n",
397396
" fun_viewpool=None,\n",
398397
" global_code=None,\n",
399398
" ):\n",

pytorch3d/implicitron/models/generic_model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
RegularizationMetricsBase,
2323
ViewMetricsBase,
2424
)
25+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
2526
from pytorch3d.implicitron.tools import image_utils, vis_utils
2627
from pytorch3d.implicitron.tools.config import (
2728
expand_args_fields,
@@ -30,7 +31,8 @@
3031
)
3132
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
3233
from pytorch3d.implicitron.tools.utils import cat_dataclass
33-
from pytorch3d.renderer import RayBundle, utils as rend_utils
34+
from pytorch3d.renderer import utils as rend_utils
35+
3436
from pytorch3d.renderer.cameras import CamerasBase
3537
from visdom import Visdom
3638

@@ -387,7 +389,7 @@ def safe_slice_targets(
387389
)
388390

389391
# (1) Sample rendering rays with the ray sampler.
390-
ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29]
392+
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
391393
target_cameras,
392394
evaluation_mode,
393395
mask=mask_crop[:n_targets]
@@ -568,14 +570,14 @@ def visualize(
568570
def _render(
569571
self,
570572
*,
571-
ray_bundle: RayBundle,
573+
ray_bundle: ImplicitronRayBundle,
572574
chunked_inputs: Dict[str, torch.Tensor],
573575
sampling_mode: RenderSamplingMode,
574576
**kwargs,
575577
) -> RendererOutput:
576578
"""
577579
Args:
578-
ray_bundle: A `RayBundle` object containing the parametrizations of the
580+
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
579581
sampled rendering rays.
580582
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
581583
SignedDistanceFunctionRenderer requires "object_mask", shape
@@ -899,7 +901,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
899901

900902
def _chunk_generator(
901903
chunk_size: int,
902-
ray_bundle: RayBundle,
904+
ray_bundle: ImplicitronRayBundle,
903905
chunked_inputs: Dict[str, torch.Tensor],
904906
tqdm_trigger_threshold: int,
905907
*args,
@@ -932,7 +934,7 @@ def _chunk_generator(
932934

933935
for start_idx in iter:
934936
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
935-
ray_bundle_chunk = RayBundle(
937+
ray_bundle_chunk = ImplicitronRayBundle(
936938
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
937939
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
938940
:, start_idx:end_idx

pytorch3d/implicitron/models/implicit_function/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from abc import ABC, abstractmethod
88
from typing import Optional
99

10+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
11+
1012
from pytorch3d.implicitron.tools.config import ReplaceableBase
1113
from pytorch3d.renderer.cameras import CamerasBase
12-
from pytorch3d.renderer.implicit import RayBundle
1314

1415

1516
class ImplicitFunctionBase(ABC, ReplaceableBase):
@@ -20,7 +21,7 @@ def __init__(self):
2021
def forward(
2122
self,
2223
*,
23-
ray_bundle: RayBundle,
24+
ray_bundle: ImplicitronRayBundle,
2425
fun_viewpool=None,
2526
camera: Optional[CamerasBase] = None,
2627
global_code=None,

pytorch3d/implicitron/models/implicit_function/idr_feature_field.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from typing import Optional, Tuple
77

88
import torch
9+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
910
from pytorch3d.implicitron.tools.config import registry
10-
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
11+
from pytorch3d.renderer.implicit import HarmonicEmbedding
12+
1113
from torch import nn
1214

1315
from .base import ImplicitFunctionBase
@@ -127,7 +129,7 @@ def __post_init__(self):
127129
def forward(
128130
self,
129131
*,
130-
ray_bundle: Optional[RayBundle] = None,
132+
ray_bundle: Optional[ImplicitronRayBundle] = None,
131133
rays_points_world: Optional[torch.Tensor] = None,
132134
fun_viewpool=None,
133135
global_code=None,

pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
import torch
1111
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
12+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
1213
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
13-
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
14+
from pytorch3d.renderer import ray_bundle_to_ray_points
1415
from pytorch3d.renderer.cameras import CamerasBase
1516
from pytorch3d.renderer.implicit import HarmonicEmbedding
1617

@@ -130,7 +131,7 @@ def allows_multiple_passes() -> bool:
130131
def forward(
131132
self,
132133
*,
133-
ray_bundle: RayBundle,
134+
ray_bundle: ImplicitronRayBundle,
134135
fun_viewpool=None,
135136
camera: Optional[CamerasBase] = None,
136137
global_code=None,
@@ -144,7 +145,7 @@ def forward(
144145
RGB color and opacity respectively.
145146
146147
Args:
147-
ray_bundle: A RayBundle object containing the following variables:
148+
ray_bundle: An ImplicitronRayBundle object containing the following variables:
148149
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
149150
origins of the sampling rays in world coords.
150151
directions: A tensor of shape `(minibatch, ..., 3)`
@@ -165,11 +166,12 @@ def forward(
165166
"""
166167
# We first convert the ray parametrizations to world
167168
# coordinates with `ray_bundle_to_ray_points`.
169+
# pyre-ignore[6]
168170
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
169171
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
170172

171173
embeds = create_embeddings_for_implicit_function(
172-
xyz_world=ray_bundle_to_ray_points(ray_bundle),
174+
xyz_world=rays_points_world,
173175
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
174176
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
175177
xyz_embedding_function=self.harmonic_embedding_xyz

pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import torch
77
from omegaconf import DictConfig
88
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
9+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
910
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
1011
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
11-
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
12+
from pytorch3d.renderer import ray_bundle_to_ray_points
1213
from pytorch3d.renderer.cameras import CamerasBase
1314
from pytorch3d.renderer.implicit import HarmonicEmbedding
1415

@@ -68,15 +69,15 @@ def __post_init__(self):
6869

6970
def forward(
7071
self,
71-
ray_bundle: RayBundle,
72+
ray_bundle: ImplicitronRayBundle,
7273
fun_viewpool=None,
7374
camera: Optional[CamerasBase] = None,
7475
global_code=None,
7576
**kwargs,
7677
):
7778
"""
7879
Args:
79-
ray_bundle: A RayBundle object containing the following variables:
80+
ray_bundle: An ImplicitronRayBundle object containing the following variables:
8081
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
8182
origins of the sampling rays in world coords.
8283
directions: A tensor of shape `(minibatch, ..., 3)`
@@ -96,10 +97,11 @@ def forward(
9697
"""
9798
# We first convert the ray parametrizations to world
9899
# coordinates with `ray_bundle_to_ray_points`.
100+
# pyre-ignore[6]
99101
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
100102

101103
embeds = create_embeddings_for_implicit_function(
102-
xyz_world=ray_bundle_to_ray_points(ray_bundle),
104+
xyz_world=rays_points_world,
103105
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
104106
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
105107
xyz_embedding_function=self._harmonic_embedding,
@@ -175,15 +177,15 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
175177
def forward(
176178
self,
177179
raymarch_features: torch.Tensor,
178-
ray_bundle: RayBundle,
180+
ray_bundle: ImplicitronRayBundle,
179181
camera: Optional[CamerasBase] = None,
180182
**kwargs,
181183
):
182184
"""
183185
Args:
184186
raymarch_features: Features from the raymarching network of shape
185187
`(minibatch, ..., self.in_features)`
186-
ray_bundle: A RayBundle object containing the following variables:
188+
ray_bundle: An ImplicitronRayBundle object containing the following variables:
187189
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
188190
origins of the sampling rays in world coords.
189191
directions: A tensor of shape `(minibatch, ..., 3)`
@@ -297,7 +299,7 @@ def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]
297299

298300
def forward(
299301
self,
300-
ray_bundle: RayBundle,
302+
ray_bundle: ImplicitronRayBundle,
301303
fun_viewpool=None,
302304
camera: Optional[CamerasBase] = None,
303305
global_code=None,
@@ -350,7 +352,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
350352
def forward(
351353
self,
352354
*,
353-
ray_bundle: RayBundle,
355+
ray_bundle: ImplicitronRayBundle,
354356
fun_viewpool=None,
355357
camera: Optional[CamerasBase] = None,
356358
global_code=None,
@@ -410,7 +412,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
410412
def forward(
411413
self,
412414
*,
413-
ray_bundle: RayBundle,
415+
ray_bundle: ImplicitronRayBundle,
414416
fun_viewpool=None,
415417
camera: Optional[CamerasBase] = None,
416418
global_code=None,

pytorch3d/implicitron/models/implicit_function/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
import torch.nn.functional as F
1212
from pytorch3d.common.compat import prod
13+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
1314
from pytorch3d.renderer import ray_bundle_to_ray_points
1415
from pytorch3d.renderer.cameras import CamerasBase
15-
from pytorch3d.renderer.implicit import RayBundle
1616

1717

1818
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
@@ -190,15 +190,15 @@ def interpolate_volume(
190190

191191

192192
def get_rays_points_world(
193-
ray_bundle: Optional[RayBundle] = None,
193+
ray_bundle: Optional[ImplicitronRayBundle] = None,
194194
rays_points_world: Optional[torch.Tensor] = None,
195195
) -> torch.Tensor:
196196
"""
197197
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
198198
and raises error if both are defined.
199199
200200
Args:
201-
ray_bundle: A RayBundle object or None
201+
ray_bundle: An ImplicitronRayBundle object or None
202202
rays_points_world: A torch.Tensor representing ray points converted to
203203
world coordinates
204204
Returns:
@@ -213,5 +213,6 @@ def get_rays_points_world(
213213
if rays_points_world is not None:
214214
return rays_points_world
215215
if ray_bundle is not None:
216+
# pyre-ignore[6]
216217
return ray_bundle_to_ray_points(ray_bundle)
217218
raise ValueError("ray_bundle and rays_points_world cannot both be None")

pytorch3d/implicitron/models/renderer/base.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from __future__ import annotations
88

9+
import dataclasses
10+
911
from abc import ABC, abstractmethod
1012
from dataclasses import dataclass, field
1113
from enum import Enum
@@ -25,6 +27,38 @@ class RenderSamplingMode(Enum):
2527
FULL_GRID = "full_grid"
2628

2729

30+
@dataclasses.dataclass
31+
class ImplicitronRayBundle:
32+
"""
33+
Parametrizes points along projection rays by storing ray `origins`,
34+
`directions` vectors and `lengths` at which the ray-points are sampled.
35+
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
36+
Note that `directions` don't have to be normalized; they define unit vectors
37+
in the respective 1D coordinate systems; see documentation for
38+
:func:`ray_bundle_to_ray_points` for the conversion formula.
39+
40+
camera_ids: A tensor of shape (N, ) which indicates which camera
41+
was used to sample the rays. `N` is the number of different
42+
sampled cameras.
43+
camera_counts: A tensor of shape (N, ) which how many times the
44+
coresponding camera in `camera_ids` was sampled.
45+
`sum(camera_counts)==minibatch`
46+
"""
47+
48+
origins: torch.Tensor
49+
directions: torch.Tensor
50+
lengths: torch.Tensor
51+
xys: torch.Tensor
52+
camera_ids: Optional[torch.Tensor] = None
53+
camera_counts: Optional[torch.Tensor] = None
54+
55+
def is_packed(self) -> bool:
56+
"""
57+
Returns whether the ImplicitronRayBundle carries data in packed state
58+
"""
59+
return self.camera_ids is not None and self.camera_counts is not None
60+
61+
2862
@dataclass
2963
class RendererOutput:
3064
"""
@@ -85,7 +119,7 @@ def requires_object_mask(self) -> bool:
85119
@abstractmethod
86120
def forward(
87121
self,
88-
ray_bundle,
122+
ray_bundle: ImplicitronRayBundle,
89123
implicit_functions: List[ImplicitFunctionWrapper],
90124
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
91125
**kwargs,
@@ -95,7 +129,7 @@ def forward(
95129
that returns an instance of RendererOutput.
96130
97131
Args:
98-
ray_bundle: A RayBundle object containing the following variables:
132+
ray_bundle: An ImplicitronRayBundle object containing the following variables:
99133
origins: A tensor of shape (minibatch, ..., 3) denoting
100134
the origins of the rendering rays.
101135
directions: A tensor of shape (minibatch, ..., 3)
@@ -108,6 +142,12 @@ def forward(
108142
xys: A tensor of shape
109143
(minibatch, ..., 2) containing the
110144
xy locations of each ray's pixel in the NDC screen space.
145+
camera_ids: A tensor of shape (N, ) which indicates which camera
146+
was used to sample the rays. `N` is the number of different
147+
sampled cameras.
148+
camera_counts: A tensor of shape (N, ) which how many times the
149+
coresponding camera in `camera_ids` was sampled.
150+
`sum(camera_counts)==minibatch`
111151
implicit_functions: List of ImplicitFunctionWrappers which define the
112152
implicit function methods to be used. Most Renderers only allow
113153
a single implicit function. Currently, only the

0 commit comments

Comments
 (0)