10
10
import plotly .graph_objects as go
11
11
import torch
12
12
from plotly .subplots import make_subplots
13
- from pytorch3d .renderer import RayBundle , TexturesVertex , ray_bundle_to_ray_points
13
+ from pytorch3d .renderer import (
14
+ RayBundle ,
15
+ TexturesAtlas ,
16
+ TexturesVertex ,
17
+ ray_bundle_to_ray_points ,
18
+ )
14
19
from pytorch3d .renderer .camera_utils import camera_to_eye_at_up
15
20
from pytorch3d .renderer .cameras import CamerasBase
16
21
from pytorch3d .structures import Meshes , Pointclouds , join_meshes_as_scene
@@ -580,13 +585,19 @@ def _add_mesh_trace(
580
585
mesh = mesh .detach ().cpu ()
581
586
verts = mesh .verts_packed ()
582
587
faces = mesh .faces_packed ()
583
- # If mesh has vertex colors defined as texture , use vertex colors
588
+ # If mesh has vertex colors or face colors , use them
584
589
# for figure, otherwise use plotly's default colors.
585
590
verts_rgb = None
591
+ faces_rgb = None
586
592
if isinstance (mesh .textures , TexturesVertex ):
587
593
verts_rgb = mesh .textures .verts_features_packed ()
588
594
verts_rgb .clamp_ (min = 0.0 , max = 1.0 )
589
595
verts_rgb = torch .tensor (255.0 ) * verts_rgb
596
+ if isinstance (mesh .textures , TexturesAtlas ):
597
+ atlas = mesh .textures .atlas_packed ()
598
+ # If K==1
599
+ if atlas .shape [1 ] == 1 and atlas .shape [3 ] == 3 :
600
+ faces_rgb = atlas [:, 0 , 0 ]
590
601
591
602
# Reposition the unused vertices to be "inside" the object
592
603
# (i.e. they won't be visible in the plot).
@@ -602,6 +613,7 @@ def _add_mesh_trace(
602
613
y = verts [:, 1 ],
603
614
z = verts [:, 2 ],
604
615
vertexcolor = verts_rgb ,
616
+ facecolor = faces_rgb ,
605
617
i = faces [:, 0 ],
606
618
j = faces [:, 1 ],
607
619
k = faces [:, 2 ],
0 commit comments