
Inconsistent Outputs: Pretrained Weights (28 Feb 2025 vs. 12 Jun 2024) #128

@DengKaiCQ


Hi, my friend.

Your metric depth estimation model performs very well on large-scale outdoor sequences such as KITTI. My recent project GigaSLAM uses the UniDepth v2 module and achieves really good performance on such sequences. Recently I noticed that UniDepth v2 has received several updates on both GitHub and the Hugging Face Hub.

However, as mentioned in this link, the latest version of the code and pretrained weights behaves differently from the old version. GitHub commit bebc4b2 on 10 Mar 2025 changed 40 files and about 1.5k lines, which is quite a big update, and I suspect it introduced some unnoticed changes that affect performance.

To elaborate, I tested the latest UniDepth code (8d8cfe4 on GitHub) with the latest pretrained weights (52b349b on Hugging Face) against the older version (code cloned on 04 Nov 2024, most likely bebc4b2 on GitHub, with weights 1d0d3c5 on Hugging Face). Both were run on KITTI Sequence 06 (~1100 images), and the outputs differ significantly.

The following code was used in this small experiment (latest unidepth-v2-vitl14, 8d8cfe4 on GitHub and 52b349b on Hugging Face):

import os
import time
import numpy as np
import torch
from PIL import Image
from unidepth.models import UniDepthV2
from tqdm import tqdm
import matplotlib.pyplot as plt

snapshot_dir = "lpiccinelli/unidepth-v2-vitl14"  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniDepthV2.from_pretrained(snapshot_dir)
model = model.to(device)

fx, fy, cx, cy = 707.0912, 707.0912, 601.8873, 183.1104  # kitti 04-10
intrinsics = torch.tensor([[fx, 0, cx],
                           [0, fy, cy],
                           [0, 0, 1]], dtype=torch.float32).to(device)

image_dir = "/media/deng/Data/KITTIdataset/data_odometry_color/dataset/sequences/06/image_2"
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

min_depths = []
median_depths = []
mean_depths = []
max_depths = []
frame_ids = []

image_files = [f for f in os.listdir(image_dir) if f.endswith(".png")]
image_files.sort()

for i, filename in enumerate(tqdm(image_files, desc="Processing images", unit="image")):
    image_path = os.path.join(image_dir, filename)

    rgb = Image.open(image_path).convert("RGB")
    rgb = torch.from_numpy(np.array(rgb)).permute(2, 0, 1).to(device)  # C, H, W

    begin = time.time()
    predictions = model.infer(rgb, intrinsics)
    depth = predictions["depth"].cpu().detach().numpy()
    end = time.time()
    
    min_depth = depth.min()
    median_depth = np.median(depth)
    mean_depth = depth.mean()
    max_depth = depth.max()
    
    min_depths.append(min_depth)
    median_depths.append(median_depth)
    mean_depths.append(mean_depth)
    max_depths.append(max_depth)
    frame_ids.append(i)
    
    print(f'Frame {i}: inference time: {(end - begin) * 1000:.2f} ms')
    print(f'min: {min_depth:.2f}, median: {median_depth:.2f}, mean: {mean_depth:.2f}, max: {max_depth:.2f}')

global_min_stats = {
    'min': np.min(min_depths),
    'median': np.median(min_depths),
    'mean': np.mean(min_depths),
    'max': np.max(min_depths)
}

global_median_stats = {
    'min': np.min(median_depths),
    'median': np.median(median_depths),
    'mean': np.mean(median_depths),
    'max': np.max(median_depths)
}

global_mean_stats = {
    'min': np.min(mean_depths),
    'median': np.median(mean_depths),
    'mean': np.mean(mean_depths),
    'max': np.max(mean_depths)
}

global_max_stats = {
    'min': np.min(max_depths),
    'median': np.median(max_depths),
    'mean': np.mean(max_depths),
    'max': np.max(max_depths)
}

plt.figure(figsize=(24, 6))

plt.subplot(1, 4, 1)
plt.scatter(frame_ids, min_depths, color='blue', s=10)
plt.xlabel('frame id')
plt.ylabel('Depth Value')
plt.title(f"Minimum Depth\nmin={global_min_stats['min']:.2f} med={global_min_stats['median']:.2f}\nmean={global_min_stats['mean']:.2f} max={global_min_stats['max']:.2f}")
plt.grid(True)

plt.subplot(1, 4, 2)
plt.scatter(frame_ids, median_depths, color='green', s=10)
plt.xlabel('frame id')
plt.ylabel('Depth Value')
plt.title(f"Median Depth\nmin={global_median_stats['min']:.2f} med={global_median_stats['median']:.2f}\nmean={global_median_stats['mean']:.2f} max={global_median_stats['max']:.2f}")
plt.grid(True)

plt.subplot(1, 4, 3)
plt.scatter(frame_ids, mean_depths, color='orange', s=10)
plt.xlabel('frame id')
plt.ylabel('Depth Value')
plt.title(f"Mean Depth\nmin={global_mean_stats['min']:.2f} med={global_mean_stats['median']:.2f}\nmean={global_mean_stats['mean']:.2f} max={global_mean_stats['max']:.2f}")
plt.grid(True)

plt.subplot(1, 4, 4)
plt.scatter(frame_ids, max_depths, color='red', s=10)
plt.xlabel('frame id')
plt.ylabel('Depth Value')
plt.title(f"Maximum Depth\nmin={global_max_stats['min']:.2f} med={global_max_stats['median']:.2f}\nmean={global_max_stats['mean']:.2f} max={global_max_stats['max']:.2f}")
plt.grid(True)

plt.tight_layout()

plot_path = os.path.join(output_dir, 'depth_statistics_4plots.png')
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()

print(f"Depth statistics plot saved to {plot_path}")

The older version (bebc4b2 on GitHub, I guess, and 1d0d3c5 on Hugging Face) uses the same experiment code, changing only the lines below:

...

snapshot_dir = "lpiccinelli/unidepth-v2-vitl14"  
commit_hash = "1d0d3c52f60b5164629d279bb9a7546458e6dcc4"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniDepthV2.from_pretrained(
    snapshot_dir,
    revision=commit_hash
)
model = model.to(device)

...
(the rest remains the same)
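
To make sure later runs cannot silently pick up newer weights, the pinned revision can also be cached locally via huggingface_hub (an optional sketch; only snapshot_download with a repo id and revision is assumed here):

from huggingface_hub import snapshot_download

# Download the pinned weights revision once and keep a local copy,
# so future runs keep using exactly these files.
local_dir = snapshot_download(
    "lpiccinelli/unidepth-v2-vitl14",
    revision="1d0d3c52f60b5164629d279bb9a7546458e6dcc4",
)
print(f"Pinned weights cached at: {local_dir}")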

The images in KITTI Seq 06 are typical car-driving scenes.

The scale of the depth outputs should remain consistent across the entire sequence.
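
A simple way to check this, reusing the median_depths list collected by the script above, is to track each frame's median depth relative to frame 0 (only a rough indicator, since scene content also varies along the sequence):

# Rough scale-consistency check: ratio of each frame's median depth to frame 0's.
# For a metrically consistent model this ratio should hover around 1.0 with some noise,
# not trend steadily downward as the frame id increases.
scale_ratio = np.array(median_depths) / median_depths[0]
print(f"scale ratio: mean={scale_ratio.mean():.3f}, std={scale_ratio.std():.3f}, "
      f"first-to-last change={scale_ratio[-1] - scale_ratio[0]:+.3f}")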

You can see the outputs of this small experiment below, along with the point clouds generated from each depth output and the GT poses.
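
For reference, the point clouds are built by back-projecting each depth map with the camera intrinsics and then transforming the points with the corresponding GT pose; a simplified sketch of that standard back-projection (variable names are illustrative):

import numpy as np

def depth_to_world_points(depth, K, pose):
    """depth: (H, W) metric depth, K: (3, 3) intrinsics, pose: (4, 4) camera-to-world GT pose."""
    H, W = depth.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    x = (u - K[0, 2]) * depth / K[0, 0]   # back-project pixels to camera coordinates
    y = (v - K[1, 2]) * depth / K[1, 1]
    pts_cam = np.stack([x, y, depth, np.ones_like(depth)], axis=-1).reshape(-1, 4)
    pts_world = (pose @ pts_cam.T).T[:, :3]  # move into the world frame using the GT pose
    return pts_world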

  1. Latest version (unidepth-v2-vitl14, 8d8cfe4 on GitHub and 52b349b on Hugging Face):

As the frame_id of the input image sequence increases, you can clearly see that the depth scale decreases. This leads to inconsistent depth scales across different frames within the same sequence. This phenomenon can be seen more intuitively in the point cloud visualizations.

[Images: per-frame depth statistics plots and point cloud visualization for the latest version]

  2. Older version (unidepth-v2-vitl14, bebc4b2 on GitHub and 1d0d3c5 on Hugging Face):

Although the statistics graph appears slightly noisy, the depth scale remains consistent. In the corresponding point cloud built with the GT poses, this version performs very well on the sequence: the cars and buildings are clearly identifiable.

[Images: per-frame depth statistics plots and point cloud visualization for the older version]

Based on this small experiment, it appears that there are some undocumented differences between the older and newer versions of UniDepth v2.

I wonder if you could shed some light on the specific changes that might affect the consistency of the depth scale across frames. I'm not sure whether this is intended behavior or just an edge case in my setup. Were there any known modifications to the scale estimation, the training procedure, or the post-processing logic that could explain it? For now, a quick workaround is to use the older version of UniDepth v2.

To sum up, the scale consistency of 1d0d3c5 works perfectly for robotics applications, while the new version 52b349b exhibits scale drift that may break downstream tasks.

I really appreciate your work on UniDepth and look forward to your insights!
