Skip to content

Commit fbceff1

Browse files
authored
Merge pull request #1 from classner/classner-patch-#1352
Pulsar unified interface update
2 parents db7c80b + d66c809 commit fbceff1

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

pytorch3d/renderer/points/pulsar/unified.py

+37-27
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _extract_intrinsics( # noqa: C901
206206
"The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
207207
f"{kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}."
208208
)
209-
sensor_width = max_x - min_x
209+
sensor_width = (max_x - min_x) * (self.renderer._renderer.width / self.renderer._renderer.height)
210210
if not sensor_width > 0.0:
211211
raise ValueError(
212212
f"The orthographic camera must have positive size! Is: {sensor_width}." # noqa: B950
@@ -220,24 +220,25 @@ def _extract_intrinsics( # noqa: C901
220220
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
221221
cloud_idx
222222
]
223-
if (
224-
focal_length_conf.numel() == 2
225-
and focal_length_conf[0] * self.renderer._renderer.width
226-
- focal_length_conf[1] * self.renderer._renderer.height
227-
> 1e-5
228-
):
223+
if torch.any(focal_length_conf <= 0.0):
224+
raise ValueError(f"Pulsar requires focal lengths > 0.0. Provided: {focal_length_conf}.")
225+
if cameras.in_ndc():
226+
focal_length_conf *= self.renderer._renderer.height / 2.0
227+
if (focal_length_conf.numel() == 2 and abs(focal_length_conf[0] - focal_length_conf[1]) > 1e-5):
229228
raise ValueError(
230229
"Pulsar only supports a single focal length! "
231230
"Provided: %s." % (str(focal_length_conf))
232231
)
233232
if focal_length_conf.numel() == 2:
234-
sensor_width = 2.0 / focal_length_conf[0]
233+
focal_length_px = focal_length_conf[0]
235234
else:
236235
if focal_length_conf.numel() != 1:
237236
raise ValueError(
238237
"Focal length not parsable: %s." % (str(focal_length_conf))
239238
)
240-
sensor_width = 2.0 / focal_length_conf
239+
focal_length_px = focal_length_conf
240+
focal_length_px /= self.renderer._renderer.width / 2.0
241+
sensor_width = 2.0 / focal_length_px
241242
if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
242243
raise ValueError(
243244
"pulsar needs znear and zfar values for "
@@ -248,16 +249,19 @@ def _extract_intrinsics( # noqa: C901
248249
zfar = kwargs["zfar"][cloud_idx]
249250
principal_point_x = (
250251
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
251-
* 0.5
252-
* self.renderer._renderer.width
253252
)
254253
principal_point_y = (
255254
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
256-
* 0.5
257-
* self.renderer._renderer.height
258255
)
256+
if cameras.in_ndc():
257+
principal_point_x *= 0.5 * self.renderer._renderer.width * (self.renderer._renderer.height / self.renderer._renderer.width)
258+
principal_point_y *= -0.5 * self.renderer._renderer.height
259+
else:
260+
principal_point_x = self.renderer._renderer.width / 2.0 - principal_point_x
261+
principal_point_y -= self.renderer._renderer.height / 2.0
259262
else:
260263
if not isinstance(cameras, PerspectiveCameras):
264+
# This currently means FoVPerspectiveCameras.
261265
# Create a virtual focal length that is closer than znear.
262266
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
263267
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
@@ -266,7 +270,10 @@ def _extract_intrinsics( # noqa: C901
266270
afov = kwargs.get("fov", cameras.fov)[cloud_idx]
267271
if kwargs.get("degrees", cameras.degrees):
268272
afov *= math.pi / 180.0
269-
sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
273+
aspect_ratio = kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
274+
if aspect_ratio != 1.0:
275+
raise ValueError(f"Pulsar only supports aspect ration 1.0! Provided: {aspect_ratio}.")
276+
sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length * (self.renderer._renderer.width / self.renderer._renderer.height)
270277
if not (
271278
kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
272279
- self.renderer._renderer.width / self.renderer._renderer.height
@@ -286,10 +293,13 @@ def _extract_intrinsics( # noqa: C901
286293
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
287294
cloud_idx
288295
]
296+
if torch.any(focal_length_conf <= 0.0):
297+
raise ValueError(f"Pulsar requires focal lengths > 0.0. Provided: {focal_length_conf}.")
298+
if cameras.in_ndc():
299+
focal_length_conf *= self.renderer._renderer.height / 2.0
289300
if (
290301
focal_length_conf.numel() == 2
291-
and focal_length_conf[0] * self.renderer._renderer.width
292-
- focal_length_conf[1] * self.renderer._renderer.height
302+
and abs(focal_length_conf[0] - focal_length_conf[1])
293303
> 1e-5
294304
):
295305
raise ValueError(
@@ -312,6 +322,7 @@ def _extract_intrinsics( # noqa: C901
312322
"Focal length not parsable: %s." % (str(focal_length_conf))
313323
)
314324
focal_length_px = focal_length_conf
325+
focal_length_px /= self.renderer._renderer.width / 2.0
315326
focal_length = torch.tensor(
316327
[
317328
znear - 1e-6,
@@ -322,14 +333,16 @@ def _extract_intrinsics( # noqa: C901
322333
sensor_width = focal_length / focal_length_px * 2.0
323334
principal_point_x = (
324335
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
325-
* 0.5
326-
* self.renderer._renderer.width
327336
)
328337
principal_point_y = (
329338
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
330-
* 0.5
331-
* self.renderer._renderer.height
332339
)
340+
if cameras.in_ndc():
341+
principal_point_x *= 0.5 * self.renderer._renderer.width * (self.renderer._renderer.height / self.renderer._renderer.width)
342+
principal_point_y *= -0.5 * self.renderer._renderer.height
343+
else:
344+
principal_point_x = self.renderer._renderer.width / 2.0 - principal_point_x
345+
principal_point_y -= self.renderer._renderer.height / 2.0
333346
focal_length = _ensure_float_tensor(focal_length, device)
334347
sensor_width = _ensure_float_tensor(sensor_width, device)
335348
principal_point_x = _ensure_float_tensor(principal_point_x, device)
@@ -373,7 +386,7 @@ def _extract_extrinsics(
373386
return cam_pos, cam_rot
374387

375388
def _get_vert_rad(
376-
self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx
389+
self, vert_pos, cam_pos, orthogonal_projection, focal_length, sensor_width, kwargs, cloud_idx
377390
) -> torch.Tensor:
378391
"""
379392
Get point radiuses.
@@ -403,12 +416,8 @@ def _get_vert_rad(
403416
)
404417
else:
405418
point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
406-
vert_rad = raster_rad / focal_length.to(vert_pos.device) * point_dists
407-
if isinstance(self.rasterizer.cameras, PerspectiveCameras):
408-
# NDC normalization happens through adjusted focal length.
409-
pass
410-
else:
411-
vert_rad = vert_rad / 2.0 # NDC normalization.
419+
vert_rad = raster_rad / focal_length.to(vert_pos.device) * point_dists * sensor_width
420+
vert_rad = vert_rad / 2.0 # NDC normalization.
412421
return vert_rad
413422

414423
# point_clouds is not typed to avoid a cyclic dependency.
@@ -503,6 +512,7 @@ def forward(self, point_clouds, **kwargs) -> torch.Tensor:
503512
cam_pos,
504513
orthogonal_projection,
505514
focal_length,
515+
sensor_width,
506516
kwargs,
507517
cloud_idx,
508518
)

0 commit comments

Comments
 (0)