Skip to content

Commit 7f2f95f

Browse files
gkioxarifacebook-github-bot
authored andcommitted
detach for meshes, pointclouds, textures
Summary: Add `detach` for Meshes, Pointclouds, Textures Reviewed By: nikhilaravi Differential Revision: D23070418 fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
1 parent 5852b74 commit 7f2f95f

File tree

6 files changed

+283
-8
lines changed

6 files changed

+283
-8
lines changed

pytorch3d/renderer/mesh/textures.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ def clone(self):
242242
"""
243243
raise NotImplementedError()
244244

245+
def detach(self):
246+
"""
247+
Each texture class should implement a method
248+
to detach all necessary internal tensors.
249+
"""
250+
raise NotImplementedError()
251+
245252
def __getitem__(self, index):
246253
"""
247254
Each texture class should implement a method
@@ -388,6 +395,8 @@ def __repr__(self):
388395

389396
def clone(self):
390397
tex = self.__class__(atlas=self.atlas_padded().clone())
398+
if self._atlas_list is not None:
399+
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
391400
num_faces = (
392401
self._num_faces_per_mesh.clone()
393402
if torch.is_tensor(self._num_faces_per_mesh)
@@ -397,6 +406,19 @@ def clone(self):
397406
tex._num_faces_per_mesh = num_faces
398407
return tex
399408

409+
def detach(self):
410+
tex = self.__class__(atlas=self.atlas_padded().detach())
411+
if self._atlas_list is not None:
412+
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
413+
num_faces = (
414+
self._num_faces_per_mesh.detach()
415+
if torch.is_tensor(self._num_faces_per_mesh)
416+
else self._num_faces_per_mesh
417+
)
418+
tex.valid = self.valid.detach()
419+
tex._num_faces_per_mesh = num_faces
420+
return tex
421+
400422
def __getitem__(self, index):
401423
props = ["atlas_list", "_num_faces_per_mesh"]
402424
new_props = self._getitem(index, props=props)
@@ -656,6 +678,12 @@ def clone(self):
656678
self.faces_uvs_padded().clone(),
657679
self.verts_uvs_padded().clone(),
658680
)
681+
if self._maps_list is not None:
682+
tex._maps_list = [m.clone() for m in self._maps_list]
683+
if self._verts_uvs_list is not None:
684+
tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list]
685+
if self._faces_uvs_list is not None:
686+
tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list]
659687
num_faces = (
660688
self._num_faces_per_mesh.clone()
661689
if torch.is_tensor(self._num_faces_per_mesh)
@@ -665,6 +693,27 @@ def clone(self):
665693
tex.valid = self.valid.clone()
666694
return tex
667695

696+
def detach(self):
697+
tex = self.__class__(
698+
self.maps_padded().detach(),
699+
self.faces_uvs_padded().detach(),
700+
self.verts_uvs_padded().detach(),
701+
)
702+
if self._maps_list is not None:
703+
tex._maps_list = [m.detach() for m in self._maps_list]
704+
if self._verts_uvs_list is not None:
705+
tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]
706+
if self._faces_uvs_list is not None:
707+
tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list]
708+
num_faces = (
709+
self._num_faces_per_mesh.detach()
710+
if torch.is_tensor(self._num_faces_per_mesh)
711+
else self._num_faces_per_mesh
712+
)
713+
tex._num_faces_per_mesh = num_faces
714+
tex.valid = self.valid.detach()
715+
return tex
716+
668717
def __getitem__(self, index):
669718
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
670719
new_props = self._getitem(index, props)
@@ -892,8 +941,8 @@ def __init__(
892941
has a D dimensional feature vector.
893942
894943
Args:
895-
verts_features: (N, V, D) tensor giving a feature vector with
896-
artbitrary dimensions for each vertex.
944+
verts_features: list of (Vi, D) or (N, V, D) tensor giving a feature
945+
vector with artbitrary dimensions for each vertex.
897946
"""
898947
if isinstance(verts_features, (tuple, list)):
899948
correct_shape = all(
@@ -948,15 +997,28 @@ def clone(self):
948997
tex = self.__class__(self.verts_features_padded().clone())
949998
if self._verts_features_list is not None:
950999
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
951-
num_faces = (
1000+
num_verts = (
9521001
self._num_verts_per_mesh.clone()
9531002
if torch.is_tensor(self._num_verts_per_mesh)
9541003
else self._num_verts_per_mesh
9551004
)
956-
tex._num_verts_per_mesh = num_faces
1005+
tex._num_verts_per_mesh = num_verts
9571006
tex.valid = self.valid.clone()
9581007
return tex
9591008

1009+
def detach(self):
1010+
tex = self.__class__(self.verts_features_padded().detach())
1011+
if self._verts_features_list is not None:
1012+
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
1013+
num_verts = (
1014+
self._num_verts_per_mesh.detach()
1015+
if torch.is_tensor(self._num_verts_per_mesh)
1016+
else self._num_verts_per_mesh
1017+
)
1018+
tex._num_verts_per_mesh = num_verts
1019+
tex.valid = self.valid.detach()
1020+
return tex
1021+
9601022
def __getitem__(self, index):
9611023
props = ["verts_features_list", "_num_verts_per_mesh"]
9621024
new_props = self._getitem(index, props)

pytorch3d/structures/meshes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,28 @@ def clone(self):
11381138
other.textures = self.textures.clone()
11391139
return other
11401140

1141+
def detach(self):
1142+
"""
1143+
Detach Meshes object. All internal tensors are detached individually.
1144+
1145+
Returns:
1146+
new Meshes object.
1147+
"""
1148+
verts_list = self.verts_list()
1149+
faces_list = self.faces_list()
1150+
new_verts_list = [v.detach() for v in verts_list]
1151+
new_faces_list = [f.detach() for f in faces_list]
1152+
other = self.__class__(verts=new_verts_list, faces=new_faces_list)
1153+
for k in self._INTERNAL_TENSORS:
1154+
v = getattr(self, k)
1155+
if torch.is_tensor(v):
1156+
setattr(other, k, v.detach())
1157+
1158+
# Textures is not a tensor but has a detach method
1159+
if self.textures is not None:
1160+
other.textures = self.textures.detach()
1161+
return other
1162+
11411163
def to(self, device, copy: bool = False):
11421164
"""
11431165
Match functionality of torch.Tensor.to()

pytorch3d/structures/pointclouds.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,42 @@ def clone(self):
655655
setattr(other, k, v.clone())
656656
return other
657657

658+
def detach(self):
659+
"""
660+
Detach Pointclouds object. All internal tensors are detached
661+
individually.
662+
663+
Returns:
664+
new Pointclouds object.
665+
"""
666+
# instantiate new pointcloud with the representation which is not None
667+
# (either list or tensor) to save compute.
668+
new_points, new_normals, new_features = None, None, None
669+
if self._points_list is not None:
670+
new_points = [v.detach() for v in self.points_list()]
671+
normals_list = self.normals_list()
672+
features_list = self.features_list()
673+
if normals_list is not None:
674+
new_normals = [n.detach() for n in normals_list]
675+
if features_list is not None:
676+
new_features = [f.detach() for f in features_list]
677+
elif self._points_padded is not None:
678+
new_points = self.points_padded().detach()
679+
normals_padded = self.normals_padded()
680+
features_padded = self.features_padded()
681+
if normals_padded is not None:
682+
new_normals = self.normals_padded().detach()
683+
if features_padded is not None:
684+
new_features = self.features_padded().detach()
685+
other = self.__class__(
686+
points=new_points, normals=new_normals, features=new_features
687+
)
688+
for k in self._INTERNAL_TENSORS:
689+
v = getattr(self, k)
690+
if torch.is_tensor(v):
691+
setattr(other, k, v.detach())
692+
return other
693+
658694
def to(self, device, copy: bool = False):
659695
"""
660696
Match functionality of torch.Tensor.to()

tests/test_meshes.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def init_mesh(
2020
max_f: int = 300,
2121
lists_to_tensors: bool = False,
2222
device: str = "cpu",
23+
requires_grad: bool = False,
2324
):
2425
"""
2526
Function to generate a Meshes object of N meshes with
@@ -57,7 +58,12 @@ def init_mesh(
5758

5859
# Generate the actual vertices and faces.
5960
for i in range(num_meshes):
60-
verts = torch.rand((v[i], 3), dtype=torch.float32, device=device)
61+
verts = torch.rand(
62+
(v[i], 3),
63+
dtype=torch.float32,
64+
device=device,
65+
requires_grad=requires_grad,
66+
)
6167
faces = torch.randint(
6268
v[i], size=(f[i], 3), dtype=torch.int64, device=device
6369
)
@@ -353,6 +359,26 @@ def test_clone(self):
353359
self.assertSeparate(new_mesh.faces_padded(), mesh.faces_padded())
354360
self.assertSeparate(new_mesh.edges_packed(), mesh.edges_packed())
355361

362+
def test_detach(self):
363+
N = 5
364+
mesh = TestMeshes.init_mesh(N, 10, 100, requires_grad=True)
365+
for force in [0, 1]:
366+
if force:
367+
# force mesh to have computed attributes
368+
mesh.verts_packed()
369+
mesh.edges_packed()
370+
mesh.verts_padded()
371+
372+
new_mesh = mesh.detach()
373+
374+
self.assertFalse(new_mesh.verts_packed().requires_grad)
375+
self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
376+
self.assertTrue(new_mesh.verts_padded().requires_grad == False)
377+
self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
378+
for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
379+
self.assertTrue(newv.requires_grad == False)
380+
self.assertClose(newv, v)
381+
356382
def test_laplacian_packed(self):
357383
def naive_laplacian_packed(meshes):
358384
verts_packed = meshes.verts_packed()

tests/test_pointclouds.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def init_cloud(
2424
with_normals: bool = True,
2525
with_features: bool = True,
2626
min_points: int = 0,
27+
requires_grad: bool = False,
2728
):
2829
"""
2930
Function to generate a Pointclouds object of N meshes with
@@ -49,16 +50,31 @@ def init_cloud(
4950
p.fill_(p[0])
5051

5152
points_list = [
52-
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
53+
torch.rand(
54+
(i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
55+
)
56+
for i in p
5357
]
5458
normals_list, features_list = None, None
5559
if with_normals:
5660
normals_list = [
57-
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
61+
torch.rand(
62+
(i, 3),
63+
device=device,
64+
dtype=torch.float32,
65+
requires_grad=requires_grad,
66+
)
67+
for i in p
5868
]
5969
if with_features:
6070
features_list = [
61-
torch.rand((i, channels), device=device, dtype=torch.float32) for i in p
71+
torch.rand(
72+
(i, channels),
73+
device=device,
74+
dtype=torch.float32,
75+
requires_grad=requires_grad,
76+
)
77+
for i in p
6278
]
6379

6480
if lists_to_tensors:
@@ -382,6 +398,39 @@ def test_clone_tensor(self):
382398

383399
self.assertCloudsEqual(clouds, new_clouds)
384400

401+
def test_detach(self):
402+
N = 5
403+
for lists_to_tensors in (True, False):
404+
clouds = self.init_cloud(
405+
N, 100, 5, lists_to_tensors=lists_to_tensors, requires_grad=True
406+
)
407+
for force in (False, True):
408+
if force:
409+
clouds.points_packed()
410+
411+
new_clouds = clouds.detach()
412+
413+
for cloud in new_clouds.points_list():
414+
self.assertTrue(cloud.requires_grad == False)
415+
for normal in new_clouds.normals_list():
416+
self.assertTrue(normal.requires_grad == False)
417+
for feats in new_clouds.features_list():
418+
self.assertTrue(feats.requires_grad == False)
419+
420+
for attrib in [
421+
"points_packed",
422+
"normals_packed",
423+
"features_packed",
424+
"points_padded",
425+
"normals_padded",
426+
"features_padded",
427+
]:
428+
self.assertTrue(
429+
getattr(new_clouds, attrib)().requires_grad == False
430+
)
431+
432+
self.assertCloudsEqual(clouds, new_clouds)
433+
385434
def assertCloudsEqual(self, cloud1, cloud2):
386435
N = len(cloud1)
387436
self.assertEqual(N, len(cloud2))

0 commit comments

Comments
 (0)