Skip to content

Commit 2a1de3b

Browse files
bottlerfacebook-github-bot
authored andcommitted
move LinearWithRepeat to pytorch3d
Summary: Move this simple layer from the NeRF project into pytorch3d. Reviewed By: shapovalov Differential Revision: D34126972 fbshipit-source-id: a9c6d6c3c1b662c1b844ea5d1b982007d4df83e6
1 parent ef21a6f commit 2a1de3b

File tree

6 files changed

+75
-8
lines changed

6 files changed

+75
-8
lines changed

projects/nerf/nerf/implicit_function.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
from typing import Tuple
88

99
import torch
10+
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
1011
from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points
1112

12-
from .linear_with_repeat import LinearWithRepeat
13-
1413

1514
def _xavier_init(linear):
1615
"""

projects/nerf/nerf/linear_with_repeat.py renamed to pytorch3d/common/linear_with_repeat.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import math
78
from typing import Tuple
89

910
import torch
1011
import torch.nn.functional as F
12+
from torch.nn import Parameter, init
1113

1214

13-
class LinearWithRepeat(torch.nn.Linear):
15+
class LinearWithRepeat(torch.nn.Module):
1416
"""
1517
if x has shape (..., k, n1)
1618
and y has shape (..., n2)
@@ -50,6 +52,40 @@ class LinearWithRepeat(torch.nn.Linear):
5052
and sent that through the Linear.
5153
"""
5254

55+
def __init__(
56+
self,
57+
in_features: int,
58+
out_features: int,
59+
bias: bool = True,
60+
device=None,
61+
dtype=None,
62+
) -> None:
63+
"""
64+
Copied from torch.nn.Linear.
65+
"""
66+
factory_kwargs = {"device": device, "dtype": dtype}
67+
super().__init__()
68+
self.in_features = in_features
69+
self.out_features = out_features
70+
self.weight = Parameter(
71+
torch.empty((out_features, in_features), **factory_kwargs)
72+
)
73+
if bias:
74+
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
75+
else:
76+
self.register_parameter("bias", None)
77+
self.reset_parameters()
78+
79+
def reset_parameters(self) -> None:
80+
"""
81+
Copied from torch.nn.Linear.
82+
"""
83+
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
84+
if self.bias is not None:
85+
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
86+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
87+
init.uniform_(self.bias, -bound, bound)
88+
5389
def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
5490
n1 = input[0].shape[-1]
5591
output1 = F.linear(input[0], self.weight[:, :n1], self.bias)

pytorch3d/renderer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@
7373
from .utils import (
7474
TensorProperties,
7575
convert_to_tensors_and_broadcast,
76-
ndc_to_grid_sample_coords,
7776
ndc_grid_sample,
77+
ndc_to_grid_sample_coords,
7878
)
7979

8080

pytorch3d/renderer/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import copy
99
import inspect
1010
import warnings
11-
from typing import Any, Optional, Union, Tuple
11+
from typing import Any, Optional, Tuple, Union
1212

1313
import numpy as np
1414
import torch
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from common_testing import TestCaseMixin
11+
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
12+
13+
14+
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
15+
def setUp(self) -> None:
16+
super().setUp()
17+
torch.manual_seed(42)
18+
19+
def test_simple(self):
20+
x = torch.rand(4, 6, 7, 3)
21+
y = torch.rand(4, 6, 4)
22+
23+
linear = torch.nn.Linear(7, 8)
24+
torch.nn.init.xavier_uniform_(linear.weight.data)
25+
linear.bias.data.uniform_()
26+
equivalent = torch.cat([x, y.unsqueeze(-2).expand(4, 6, 7, 4)], dim=-1)
27+
expected = linear.forward(equivalent)
28+
29+
linear_with_repeat = LinearWithRepeat(7, 8)
30+
linear_with_repeat.load_state_dict(linear.state_dict())
31+
actual = linear_with_repeat.forward((x, y))
32+
self.assertClose(actual, expected, rtol=1e-4)

tests/test_rendering_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
from common_testing import TestCaseMixin
1313
from pytorch3d.ops import eyes
1414
from pytorch3d.renderer import (
15-
PerspectiveCameras,
1615
AlphaCompositor,
17-
PointsRenderer,
16+
PerspectiveCameras,
1817
PointsRasterizationSettings,
1918
PointsRasterizer,
19+
PointsRenderer,
2020
)
2121
from pytorch3d.renderer.utils import (
2222
TensorProperties,
23-
ndc_to_grid_sample_coords,
2423
ndc_grid_sample,
24+
ndc_to_grid_sample_coords,
2525
)
2626
from pytorch3d.structures import Pointclouds
2727

0 commit comments

Comments
 (0)