Incorrect rendering of object compared to Ground Truth image on LineMod Dataset #934

Closed
manasburagohain opened this issue Nov 13, 2021 · 10 comments
Labels
how to How to use PyTorch3D in my project

Comments

@manasburagohain

manasburagohain commented Nov 13, 2021

🐛 Bugs / Unexpected behaviors

Given a ground truth image and a ground truth pose, I attempt to render the object with the same pose to align with GT image.

Initially, setting the camera pose to the GT pose results in the following rendering:

test_no_change

However, since PyTorch3D follows a different axes convention, I attempted to use the fix suggested in #522 by negating f to -f, which results in the following render:

test_negate_f

Additionally, I attempted to use the fix mentioned in #294 by rotating about z by 180 degrees, which results in the following render:

test_rotate_z

As we can see, neither fix results in the rendered object exactly aligning with the GT pose. This has been the case across multiple GT poses and images.

Instructions To Reproduce the Issue:

Please include the following (depending on what the issue is):

  1. Any changes you made (git diff) or code you wrote
import os
from random import randint  # used below when cropping the background

import numpy as np
import cv2
import torch
from pytorch3d.structures import Meshes
from pytorch3d.io import IO
from pytorch3d.renderer import (
    look_at_view_transform,
    look_at_rotation,
    OpenGLPerspectiveCameras,
    PerspectiveCameras,
    PointLights,
    DirectionalLights,
    AmbientLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    TexturesVertex
)

#Params
K = np.array([[572.4114, 0.,         325.2611],
            [0.,        573.57043,  242.04899],
            [0.,        0.,         1.]])

f_x, f_y = K[0,0], K[1,1]
p_x, p_y = K[0,2], K[1,2]
h = 480
w = 640

#Load mesh
device = torch.device("cuda:0")
mesh = IO().load_mesh("mesh.ply").to(device)

#GT Pose for instance 176
R=  torch.tensor([0.66307002, 0.74850100, 0.00921593, 0.50728703, -0.44026601, -0.74082798, -0.55045301, 0.49589601, -0.67163098])
T= torch.tensor([42.36749640, 1.84263252, 768.28001229])

#Apply fix #294
RT = torch.zeros((4,4))
RT[3,3] = 1
RT[:3,:3] = R.reshape(3,3)
RT[:3,3] = T/1000.0

Rz = torch.tensor([[-1, 0, 0, 0],
                   [0, -1, 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]]).float()

RT = torch.matmul(Rz, RT)

R = RT[:3,:3].reshape(1,3,3)
T = RT[:3,3].reshape(1,3)

f = torch.tensor((f_x, f_y), dtype=torch.float32).unsqueeze(0)
p = torch.tensor((p_x, p_y), dtype=torch.float32).unsqueeze(0)
img_size= torch.tensor((h, w), dtype=torch.float32).unsqueeze(0)

lights = AmbientLights(device=device)

camera = PerspectiveCameras(
    R=R, T=T,
    focal_length=f,
    principal_point=p,
    image_size=((h, w),),
    device=device,
    in_ndc=False)

# Set Renderer Parameters
raster_settings = RasterizationSettings(
    image_size=(h,w), 
    blur_radius=0.0, 
    faces_per_pixel=1,
    max_faces_per_bin = mesh.faces_packed().shape[0],
    perspective_correct = True
)

rasterizer = MeshRasterizer(
    cameras=camera,
    raster_settings=raster_settings
)

renderer = MeshRenderer(
    rasterizer,
    shader=SoftPhongShader(
        device=device, 
        cameras=camera,
        lights=lights,
    )
)

# Generate rendered image
target_images= renderer(mesh, cameras=camera, lights=lights)

img = target_images[0, ..., :3]

imgray = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2GRAY)
ret, mask = cv2.threshold(imgray, 1, 255, 0)
bg = None
while bg is None or len(bg.shape) < 3:
    bg_pth = "bg.png"
    bg = cv2.cvtColor(cv2.imread(bg_pth), cv2.COLOR_BGR2RGB)
    if len(bg.shape) < 3:
        bg = None
        continue
    bg_h, bg_w, _ = bg.shape
    if bg_h < h:
        new_w = int(float(h) / bg_h * bg_w)
        bg = cv2.resize(bg, (new_w, h))
    bg_h, bg_w, _ = bg.shape
    if bg_w < w:
        new_h = int(float(w) / bg_w * bg_h)
        bg = cv2.resize(bg, (w, new_h))
    bg_h, bg_w, _ = bg.shape
    if bg_h > h:
        sh = randint(0, bg_h - h)
        bg = bg[sh : sh + h, :, :]
    bg_h, bg_w, _ = bg.shape
    if bg_w > w:
        sw = randint(0, bg_w - w)
        bg = bg[:, sw : sw + w, :]

    bg = cv2.cvtColor(bg, cv2.COLOR_RGB2BGR)

msk_3c = np.repeat(mask[:, :, None], 3, axis=2)

msk_3c = msk_3c>0

rgb = img.cpu().numpy() * msk_3c.astype(bg.dtype) + bg * (msk_3c == 0).astype(bg.dtype)

cv2.imwrite('test.jpg', rgb)
  2. The exact command(s) you ran:
    python test.py

Files Used:

GT/Background image: bg

Mesh: https://drive.google.com/file/d/1kIw9dypvxsJ-xgI573RL6s-ncf7e5L2a/view?usp=sharing

@bottler
Contributor

bottler commented Nov 14, 2021

I think you need to pay close attention to camera conventions in PyTorch3D and the dataset. Issue #18 may be useful.

@nikhilaravi nikhilaravi added the how to How to use PyTorch3D in my project label Nov 15, 2021
@manasburagohain
Author

manasburagohain commented Nov 16, 2021

Hi @bottler and @nikhilaravi. Thanks for the quick response. I do understand that it is a camera convention issue, and I have the following observations regarding that:

  • Both the ground truth for LineMod and PyTorch3D conventions follow the right hand coordinate system.
  • Since the PyTorch3D convention assumes the +ve x-axis points left and the z-axis goes into the screen, I attempted to rotate the ground truth pose by pre-multiplying with a 180° rotation about the y-axis. However, this places the object behind the camera's field of view.
  • I adapted the solution in Issue "How to use SfMPerspectiveCameras for renderer?" #18 to the latest pytorch3d version, but it also does not resolve the misalignment between the ground truth pose and the rendered object.

Any input from your end would be greatly appreciated.

@sailor-z

sailor-z commented Nov 18, 2021

Hi,
May I ask why you divided T by 1000?

I also noticed something strange.
When I was using your code, the rendered result looked good apart from the pose issue.
image
But I got this strange result after slightly changing the code:
image

cameras = PerspectiveCameras(
    focal_length=(focal_length,),
    principal_point=(principal_point,),
    image_size=(image_size,),
    device=device,
    in_ndc=False)

raster_settings = RasterizationSettings(
    image_size=image_size,
    blur_radius=0.0,
    faces_per_pixel=1,
    perspective_correct=True
)

lights = AmbientLights(device=device)

phong_renderer = MeshRendererWithDepth(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)#,HardPhongShader
)

image, depth = phong_renderer(meshes_world=meshes[id], R=R, T=T.squeeze(-1)/1000, max_faces_per_bin=meshes[id].faces_packed().shape[0])

I have no idea what the difference is.

@manasburagohain
Author

manasburagohain commented Nov 18, 2021

The division by 1000 is for changing the scale. The vertices for the mesh are defined in meters while the translation vector is defined in millimeters.

The weird rendering is due to overflowing the max_faces_per_bin limit in the rasterization settings. Setting it to the total number of faces in the mesh fixed it for me.

@sailor-z

sailor-z commented Nov 19, 2021

Thanks a lot! It seems like we cannot perform a batched rendering because max_faces_per_bin has to be specified for each mesh independently.

@manasburagohain
Author

@sailor-z Maybe you can set the attribute to be the max faces across the whole batch which would allow it to render correctly.
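For what it's worth, a minimal sketch of that idea (hedged: `max_faces_across_batch` is a hypothetical helper name, and it assumes each element exposes a pytorch3d-style `faces_packed()` returning a tensor of shape `(F, 3)`):

```python
# Sketch: pick the largest face count across the batch so that no single
# mesh overflows the rasterizer's per-bin face budget. Assumes `meshes`
# is an iterable of objects with a faces_packed() method whose result
# has a .shape attribute of the form (num_faces, 3).
def max_faces_across_batch(meshes):
    return max(m.faces_packed().shape[0] for m in meshes)
```

The resulting value can then be passed once as `max_faces_per_bin` when building the `RasterizationSettings` for the whole batch.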

Were you able to figure out the conventions issue with the renderer?

@sailor-z

@manasburagohain When I set the attribute like this

image, depth = phong_renderer(meshes_world=meshes[id], R=R, T=T.squeeze(-1)/1000, max_faces_per_bin=meshes[id].faces_packed().shape[0]),

it doesn't work. I have used the solution in #18 with pytorch3d 0.6.0, but the pose issue still exists.

@wangg12

wangg12 commented Nov 23, 2021

Hey @manasburagohain , I think you can just change one line:

R = RT[:3, :3].t().reshape(1, 3, 3)

This is simply because PyTorch3D treats points as row vectors and post-multiplies them by the rotation (x' = x @ R), whereas OpenCV pre-multiplies column vectors (x' = R @ x). So the rotation and transformation matrices in PyTorch3D are transposed compared with those in OpenCV.
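The transpose relationship can be checked with a small NumPy sketch, independent of the renderer (the rotation here is just an example matrix, not a LineMod pose):

```python
import numpy as np

# OpenCV-style pose: column-vector convention, x_cam = R_cv @ x + t
R_cv = np.array([[0.0, -1.0, 0.0],
                 [1.0,  0.0, 0.0],
                 [0.0,  0.0, 1.0]])   # example: 90 deg rotation about z
t = np.array([0.1, 0.2, 0.3])
x = np.array([1.0, 2.0, 3.0])

x_cam_cv = R_cv @ x + t

# PyTorch3D-style: row-vector convention, x_cam = x @ R_p3d + T,
# which agrees with the OpenCV result exactly when R_p3d = R_cv.T
R_p3d = R_cv.T
x_cam_p3d = x @ R_p3d + t

assert np.allclose(x_cam_cv, x_cam_p3d)
```

Note this only covers the transpose; the additional 180° rotation handling PyTorch3D's +x-left/+y-up screen convention is applied separately (the Rz matrix earlier in this thread).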

Shown below is what I obtained after this change (I additionally changed the background color for rendering as black, my full code is available here https://github.com/wangg12/pytorch3d_render_linemod/blob/master/test.py):
(rendered vs real vs rendered+real)
image

@sailor-z

@wangg12 It works for me. Thanks a lot!

@manasburagohain
Author

@wangg12 Thanks for the help!


5 participants